forked from cosmonaut/wraith-lang
initial for loop range implementation
parent
c728dd6b8c
commit
62f42e47b9
|
@ -0,0 +1,19 @@
|
|||
struct Program
|
||||
{
|
||||
static Main(): int
|
||||
{
|
||||
sum: int;
|
||||
|
||||
sum = 0;
|
||||
|
||||
for (i in [1..1000])
|
||||
{
|
||||
if ((i % 3 == 0) || (i % 5 == 0))
|
||||
{
|
||||
sum = sum + i;
|
||||
}
|
||||
}
|
||||
|
||||
return sum;
|
||||
}
|
||||
}
|
|
@ -19,6 +19,8 @@
|
|||
"alloc" return ALLOC;
|
||||
"if" return IF;
|
||||
"else" return ELSE;
|
||||
"in" return IN;
|
||||
"for" return FOR;
|
||||
[0-9]+ return NUMBER;
|
||||
[a-zA-Z][a-zA-Z0-9]* return ID;
|
||||
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;
|
||||
|
|
|
@ -32,6 +32,8 @@ extern FILE *yyin;
|
|||
%token ALLOC
|
||||
%token IF
|
||||
%token ELSE
|
||||
%token IN
|
||||
%token FOR
|
||||
%token NUMBER
|
||||
%token ID
|
||||
%token STRING_LITERAL
|
||||
|
@ -65,10 +67,10 @@ extern FILE *yyin;
|
|||
|
||||
%define parse.error verbose
|
||||
|
||||
%left GREATER_THAN LESS_THAN
|
||||
%left GREATER_THAN LESS_THAN EQUAL
|
||||
%left PLUS MINUS
|
||||
%left STAR
|
||||
%left BANG
|
||||
%left STAR PERCENT
|
||||
%left BANG BAR
|
||||
%left LEFT_PAREN RIGHT_PAREN
|
||||
|
||||
%%
|
||||
|
@ -139,10 +141,12 @@ AccessExpression : Identifier POINT AccessExpression
|
|||
$$ = $1;
|
||||
}
|
||||
|
||||
PrimaryExpression : NUMBER
|
||||
Number : NUMBER
|
||||
{
|
||||
$$ = MakeNumberNode(yytext);
|
||||
}
|
||||
|
||||
PrimaryExpression : Number
|
||||
| STRING
|
||||
{
|
||||
$$ = MakeStringNode(yytext);
|
||||
|
@ -172,6 +176,10 @@ BinaryExpression : Expression PLUS Expression
|
|||
{
|
||||
$$ = MakeBinaryNode(Multiply, $1, $3);
|
||||
}
|
||||
| Expression PERCENT Expression
|
||||
{
|
||||
$$ = MakeBinaryNode(Mod, $1, $3);
|
||||
}
|
||||
| Expression LESS_THAN Expression
|
||||
{
|
||||
$$ = MakeBinaryNode(LessThan, $1, $3);
|
||||
|
@ -180,6 +188,14 @@ BinaryExpression : Expression PLUS Expression
|
|||
{
|
||||
$$ = MakeBinaryNode(GreaterThan, $1, $3);
|
||||
}
|
||||
| Expression EQUAL EQUAL Expression
|
||||
{
|
||||
$$ = MakeBinaryNode(Equal, $1, $4);
|
||||
}
|
||||
| Expression BAR BAR Expression
|
||||
{
|
||||
$$ = MakeBinaryNode(LogicalOr, $1, $4);
|
||||
}
|
||||
|
||||
Expression : BinaryExpression
|
||||
| UnaryExpression
|
||||
|
@ -240,8 +256,14 @@ Conditional : IfStatement
|
|||
$$ = MakeIfElseNode($1, $3);
|
||||
}
|
||||
|
||||
ForStatement : FOR LEFT_PAREN Identifier IN LEFT_BRACKET Number POINT POINT Number RIGHT_BRACKET RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
|
||||
{
|
||||
$$ = MakeForLoopNode($3, $6, $9, $13);
|
||||
}
|
||||
|
||||
Statement : PartialStatement SEMICOLON
|
||||
| Conditional
|
||||
| ForStatement
|
||||
;
|
||||
|
||||
Statements : Statement
|
||||
|
|
17
src/ast.c
17
src/ast.c
|
@ -440,6 +440,23 @@ Node* MakeIfElseNode(
|
|||
return node;
|
||||
}
|
||||
|
||||
Node* MakeForLoopNode(
|
||||
Node *identifierNode,
|
||||
Node *startNumberNode,
|
||||
Node *endNumberNode,
|
||||
Node *statementSequenceNode
|
||||
) {
|
||||
Node* node = (Node*) malloc(sizeof(Node));
|
||||
node->syntaxKind = ForLoop;
|
||||
node->childCount = 4;
|
||||
node->children = (Node**) malloc(sizeof(Node*) * 4);
|
||||
node->children[0] = identifierNode;
|
||||
node->children[1] = startNumberNode;
|
||||
node->children[2] = endNumberNode;
|
||||
node->children[3] = statementSequenceNode;
|
||||
return node;
|
||||
}
|
||||
|
||||
static const char* PrimitiveTypeToString(PrimitiveType type)
|
||||
{
|
||||
switch (type)
|
||||
|
|
11
src/ast.h
11
src/ast.h
|
@ -47,8 +47,11 @@ typedef enum
|
|||
Add,
|
||||
Subtract,
|
||||
Multiply,
|
||||
Mod,
|
||||
Equal,
|
||||
LessThan,
|
||||
GreaterThan
|
||||
GreaterThan,
|
||||
LogicalOr
|
||||
} BinaryOperator;
|
||||
|
||||
typedef enum
|
||||
|
@ -200,6 +203,12 @@ Node* MakeIfElseNode(
|
|||
Node *ifNode,
|
||||
Node *statementSequenceNode
|
||||
);
|
||||
Node* MakeForLoopNode(
|
||||
Node *identifierNode,
|
||||
Node *startNumberNode,
|
||||
Node *endNumberNode,
|
||||
Node *statementSequenceNode
|
||||
);
|
||||
|
||||
void PrintTree(Node *node, uint32_t tabCount);
|
||||
|
||||
|
|
154
src/codegen.c
154
src/codegen.c
|
@ -19,6 +19,7 @@ typedef struct LocalVariable
|
|||
{
|
||||
char *name;
|
||||
LLVMValueRef pointer;
|
||||
LLVMValueRef value;
|
||||
} LocalVariable;
|
||||
|
||||
typedef struct FunctionArgument
|
||||
|
@ -111,14 +112,19 @@ static void PopScopeFrame(Scope *scope)
|
|||
scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
|
||||
}
|
||||
|
||||
static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
|
||||
{
|
||||
static void AddLocalVariable(
|
||||
Scope *scope,
|
||||
LLVMValueRef pointer, /* can be NULL */
|
||||
LLVMValueRef value, /* can be NULL */
|
||||
char *name
|
||||
) {
|
||||
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
|
||||
uint32_t index = scopeFrame->localVariableCount;
|
||||
|
||||
scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
|
||||
scopeFrame->localVariables[index].name = strdup(name);
|
||||
scopeFrame->localVariables[index].pointer = pointer;
|
||||
scopeFrame->localVariables[index].value = value;
|
||||
|
||||
scopeFrame->localVariableCount += 1;
|
||||
}
|
||||
|
@ -220,7 +226,14 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
|
|||
{
|
||||
if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
|
||||
{
|
||||
return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
|
||||
if (scope->scopeStack[i].localVariables[j].value != NULL)
|
||||
{
|
||||
return scope->scopeStack[i].localVariables[j].value;
|
||||
}
|
||||
else
|
||||
{
|
||||
return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -414,6 +427,7 @@ static void AddStructVariablesToScope(
|
|||
AddLocalVariable(
|
||||
scope,
|
||||
elementPointer,
|
||||
NULL,
|
||||
structTypeDeclarations[i].fields[j].name
|
||||
);
|
||||
}
|
||||
|
@ -456,6 +470,15 @@ static LLVMValueRef CompileBinaryExpression(
|
|||
|
||||
case GreaterThan:
|
||||
return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult");
|
||||
|
||||
case Mod:
|
||||
return LLVMBuildSRem(builder, left, right, "modResult");
|
||||
|
||||
case Equal:
|
||||
return LLVMBuildICmp(builder, LLVMIntEQ, left, right, "equalResult");
|
||||
|
||||
case LogicalOr:
|
||||
return LLVMBuildOr(builder, left, right, "orResult");
|
||||
}
|
||||
|
||||
return NULL;
|
||||
|
@ -581,20 +604,22 @@ static LLVMValueRef CompileExpression(
|
|||
return NULL;
|
||||
}
|
||||
|
||||
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
|
||||
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
|
||||
|
||||
static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
|
||||
static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
|
||||
{
|
||||
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
|
||||
LLVMBuildRet(builder, expression);
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
}
|
||||
|
||||
static void CompileReturnVoid(LLVMBuilderRef builder)
|
||||
static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function)
|
||||
{
|
||||
LLVMBuildRetVoid(builder);
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
}
|
||||
|
||||
static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
|
||||
static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
|
||||
{
|
||||
LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
|
||||
LLVMValueRef identifier;
|
||||
|
@ -609,14 +634,16 @@ static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
|
|||
else
|
||||
{
|
||||
printf("Identifier not found!");
|
||||
return;
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
}
|
||||
|
||||
LLVMBuildStore(builder, result, identifier);
|
||||
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
}
|
||||
|
||||
/* FIXME: path for reference types */
|
||||
static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration)
|
||||
static LLVMBasicBlockRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
|
||||
{
|
||||
LLVMValueRef variable;
|
||||
char *variableName = variableDeclaration->children[1]->value.string;
|
||||
|
@ -627,10 +654,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
|
|||
|
||||
free(ptrName);
|
||||
|
||||
AddLocalVariable(scope, variable, variableName);
|
||||
AddLocalVariable(scope, variable, NULL, variableName);
|
||||
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
}
|
||||
|
||||
static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
|
||||
static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
|
||||
{
|
||||
uint32_t i;
|
||||
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
|
||||
|
@ -649,9 +678,11 @@ static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, No
|
|||
|
||||
LLVMBuildBr(builder, afterCond);
|
||||
LLVMPositionBuilderAtEnd(builder, afterCond);
|
||||
|
||||
return afterCond;
|
||||
}
|
||||
|
||||
static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
|
||||
static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
|
||||
{
|
||||
uint32_t i;
|
||||
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]);
|
||||
|
@ -686,45 +717,102 @@ static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function
|
|||
}
|
||||
|
||||
LLVMBuildBr(builder, afterCond);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, afterCond);
|
||||
|
||||
return afterCond;
|
||||
}
|
||||
|
||||
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
|
||||
static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement)
|
||||
{
|
||||
uint32_t i;
|
||||
LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry");
|
||||
LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
|
||||
LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
|
||||
LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
|
||||
char *iteratorVariableName = forLoopStatement->children[0]->value.string;
|
||||
|
||||
PushScopeFrame(scope);
|
||||
|
||||
LLVMBuildBr(builder, entryBlock);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, entryBlock);
|
||||
LLVMBuildBr(builder, checkBlock);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, checkBlock);
|
||||
LLVMValueRef iteratorValue = LLVMBuildPhi(builder, LLVMInt64Type(), iteratorVariableName);
|
||||
AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, bodyBlock);
|
||||
LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next");
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, checkBlock);
|
||||
|
||||
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
|
||||
LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
|
||||
|
||||
LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, bodyBlock);
|
||||
|
||||
LLVMBasicBlockRef lastBlock;
|
||||
for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1)
|
||||
{
|
||||
lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]);
|
||||
}
|
||||
|
||||
LLVMBuildBr(builder, checkBlock);
|
||||
|
||||
LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
|
||||
|
||||
LLVMValueRef incomingValues[2];
|
||||
incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
|
||||
incomingValues[1] = nextValue;
|
||||
|
||||
LLVMBasicBlockRef incomingBlocks[2];
|
||||
incomingBlocks[0] = entryBlock;
|
||||
incomingBlocks[1] = lastBlock;
|
||||
|
||||
LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, 2);
|
||||
|
||||
LLVMPositionBuilderAtEnd(builder, afterLoopBlock);
|
||||
|
||||
PopScopeFrame(scope);
|
||||
|
||||
return afterLoopBlock;
|
||||
}
|
||||
|
||||
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
|
||||
{
|
||||
switch (statement->syntaxKind)
|
||||
{
|
||||
case Assignment:
|
||||
CompileAssignment(builder, statement);
|
||||
return 0;
|
||||
return CompileAssignment(builder, function, statement);
|
||||
|
||||
case Declaration:
|
||||
return CompileFunctionVariableDeclaration(builder, function, statement);
|
||||
|
||||
case ForLoop:
|
||||
return CompileForLoopStatement(builder, function, statement);
|
||||
|
||||
case FunctionCallExpression:
|
||||
CompileFunctionCallExpression(builder, statement);
|
||||
return 0;
|
||||
|
||||
case Declaration:
|
||||
CompileFunctionVariableDeclaration(builder, statement);
|
||||
return 0;
|
||||
return LLVMGetLastBasicBlock(function);
|
||||
|
||||
case IfStatement:
|
||||
CompileIfStatement(builder, function, statement);
|
||||
return 0;
|
||||
return CompileIfStatement(builder, function, statement);
|
||||
|
||||
case IfElseStatement:
|
||||
CompileIfElseStatement(builder, function, statement);
|
||||
return 0;
|
||||
return CompileIfElseStatement(builder, function, statement);
|
||||
|
||||
case Return:
|
||||
CompileReturn(builder, statement);
|
||||
return 1;
|
||||
return CompileReturn(builder, function, statement);
|
||||
|
||||
case ReturnVoid:
|
||||
CompileReturnVoid(builder);
|
||||
return 1;
|
||||
return CompileReturnVoid(builder, function);
|
||||
}
|
||||
|
||||
fprintf(stderr, "Unknown statement kind!\n");
|
||||
return 0;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
static void CompileFunction(
|
||||
|
@ -800,14 +888,16 @@ static void CompileFunction(
|
|||
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
|
||||
LLVMBuildStore(builder, argument, argumentCopy);
|
||||
free(ptrName);
|
||||
AddLocalVariable(scope, argumentCopy, functionSignature->children[2]->children[i]->children[1]->value.string);
|
||||
AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string);
|
||||
}
|
||||
|
||||
for (i = 0; i < functionBody->childCount; i += 1)
|
||||
{
|
||||
hasReturn |= CompileStatement(builder, function, functionBody->children[i]);
|
||||
CompileStatement(builder, function, functionBody->children[i]);
|
||||
}
|
||||
|
||||
hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
|
||||
|
||||
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
|
||||
{
|
||||
LLVMBuildRetVoid(builder);
|
||||
|
|
Loading…
Reference in New Issue