diff --git a/euler001.w b/euler001.w new file mode 100644 index 0000000..2c5c49e --- /dev/null +++ b/euler001.w @@ -0,0 +1,19 @@ +struct Program +{ + static Main(): int + { + sum: int; + + sum = 0; + + for (i in [1..1000]) + { + if ((i % 3 == 0) || (i % 5 == 0)) + { + sum = sum + i; + } + } + + return sum; + } +} diff --git a/generators/wraith.lex b/generators/wraith.lex index b870eb3..ee84f7b 100644 --- a/generators/wraith.lex +++ b/generators/wraith.lex @@ -19,6 +19,8 @@ "alloc" return ALLOC; "if" return IF; "else" return ELSE; +"in" return IN; +"for" return FOR; [0-9]+ return NUMBER; [a-zA-Z][a-zA-Z0-9]* return ID; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; diff --git a/generators/wraith.y b/generators/wraith.y index 831bb76..fc3e0ab 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -32,6 +32,8 @@ extern FILE *yyin; %token ALLOC %token IF %token ELSE +%token IN +%token FOR %token NUMBER %token ID %token STRING_LITERAL @@ -65,10 +67,10 @@ extern FILE *yyin; %define parse.error verbose -%left GREATER_THAN LESS_THAN +%left GREATER_THAN LESS_THAN EQUAL %left PLUS MINUS -%left STAR -%left BANG +%left STAR PERCENT +%left BANG BAR %left LEFT_PAREN RIGHT_PAREN %% @@ -139,10 +141,12 @@ AccessExpression : Identifier POINT AccessExpression $$ = $1; } -PrimaryExpression : NUMBER +Number : NUMBER { $$ = MakeNumberNode(yytext); } + +PrimaryExpression : Number | STRING { $$ = MakeStringNode(yytext); @@ -172,6 +176,10 @@ BinaryExpression : Expression PLUS Expression { $$ = MakeBinaryNode(Multiply, $1, $3); } + | Expression PERCENT Expression + { + $$ = MakeBinaryNode(Mod, $1, $3); + } | Expression LESS_THAN Expression { $$ = MakeBinaryNode(LessThan, $1, $3); @@ -180,6 +188,14 @@ BinaryExpression : Expression PLUS Expression { $$ = MakeBinaryNode(GreaterThan, $1, $3); } + | Expression EQUAL EQUAL Expression + { + $$ = MakeBinaryNode(Equal, $1, $4); + } + | Expression BAR BAR Expression + { + $$ = MakeBinaryNode(LogicalOr, $1, $4); + } Expression : BinaryExpression | UnaryExpression @@ -240,8 +256,14 @@ Conditional : IfStatement $$ = MakeIfElseNode($1, $3); } +ForStatement : FOR LEFT_PAREN Identifier IN LEFT_BRACKET Number POINT POINT Number RIGHT_BRACKET RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE + { + $$ = MakeForLoopNode($3, $6, $9, $13); + } + Statement : PartialStatement SEMICOLON | Conditional + | ForStatement ; Statements : Statement diff --git a/src/ast.c b/src/ast.c index b736938..24bdeb9 100644 --- a/src/ast.c +++ b/src/ast.c @@ -440,6 +440,23 @@ Node* MakeIfElseNode( return node; } +Node* MakeForLoopNode( + Node *identifierNode, + Node *startNumberNode, + Node *endNumberNode, + Node *statementSequenceNode +) { + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = ForLoop; + node->childCount = 4; + node->children = (Node**) malloc(sizeof(Node*) * 4); + node->children[0] = identifierNode; + node->children[1] = startNumberNode; + node->children[2] = endNumberNode; + node->children[3] = statementSequenceNode; + return node; +} + static const char* PrimitiveTypeToString(PrimitiveType type) { switch (type) diff --git a/src/ast.h b/src/ast.h index ac65995..8e49ace 100644 --- a/src/ast.h +++ b/src/ast.h @@ -47,8 +47,11 @@ typedef enum Add, Subtract, Multiply, + Mod, + Equal, LessThan, - GreaterThan + GreaterThan, + LogicalOr } BinaryOperator; typedef enum @@ -200,6 +203,12 @@ Node* MakeIfElseNode( Node *ifNode, Node *statementSequenceNode ); +Node* MakeForLoopNode( + Node *identifierNode, + Node *startNumberNode, + Node *endNumberNode, + Node *statementSequenceNode +); void PrintTree(Node *node, uint32_t tabCount); diff --git a/src/codegen.c b/src/codegen.c index 4719317..29f48c2 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -19,6 +19,7 @@ typedef struct LocalVariable { char *name; LLVMValueRef pointer; + LLVMValueRef value; } LocalVariable; typedef struct FunctionArgument @@ -111,14 +112,19 @@ static void PopScopeFrame(Scope *scope) scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } -static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name) -{ +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name +) { ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; uint32_t index = scopeFrame->localVariableCount; scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); scopeFrame->localVariables[index].name = strdup(name); scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; scopeFrame->localVariableCount += 1; } @@ -220,7 +226,14 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) { if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0) { - return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name); + if (scope->scopeStack[i].localVariables[j].value != NULL) + { + return scope->scopeStack[i].localVariables[j].value; + } + else + { + return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name); + } } } } @@ -414,6 +427,7 @@ static void AddStructVariablesToScope( AddLocalVariable( scope, elementPointer, + NULL, structTypeDeclarations[i].fields[j].name ); } @@ -456,6 +470,15 @@ static LLVMValueRef CompileBinaryExpression( case GreaterThan: return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult"); + + case Mod: + return LLVMBuildSRem(builder, left, right, "modResult"); + + case Equal: + return LLVMBuildICmp(builder, LLVMIntEQ, left, right, "equalResult"); + + case LogicalOr: + return LLVMBuildOr(builder, left, right, "orResult"); } return NULL; @@ -581,20 +604,22 @@ static LLVMValueRef CompileExpression( return NULL; } -static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement); +static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement); -static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement) +static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); LLVMBuildRet(builder, expression); + return LLVMGetLastBasicBlock(function); } -static void CompileReturnVoid(LLVMBuilderRef builder) +static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function) { LLVMBuildRetVoid(builder); + return LLVMGetLastBasicBlock(function); } -static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement) +static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]); LLVMValueRef identifier; @@ -609,14 +634,16 @@ static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement) else { printf("Identifier not found!"); - return; + return LLVMGetLastBasicBlock(function); } LLVMBuildStore(builder, result, identifier); + + return LLVMGetLastBasicBlock(function); } /* FIXME: path for reference types */ -static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration) +static LLVMBasicBlockRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration) { LLVMValueRef variable; char *variableName = variableDeclaration->children[1]->value.string; @@ -627,10 +654,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var free(ptrName); - AddLocalVariable(scope, variable, variableName); + AddLocalVariable(scope, variable, NULL, variableName); + + return LLVMGetLastBasicBlock(function); } -static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) +static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]); @@ -649,9 +678,11 @@ static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, No LLVMBuildBr(builder, afterCond); LLVMPositionBuilderAtEnd(builder, afterCond); + + return afterCond; } -static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) +static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]); @@ -686,45 +717,102 @@ static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function } LLVMBuildBr(builder, afterCond); - LLVMPositionBuilderAtEnd(builder, afterCond); + + return afterCond; } -static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) +static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) +{ + uint32_t i; + LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry"); + LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck"); + LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody"); + LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop"); + char *iteratorVariableName = forLoopStatement->children[0]->value.string; + + PushScopeFrame(scope); + + LLVMBuildBr(builder, entryBlock); + + LLVMPositionBuilderAtEnd(builder, entryBlock); + LLVMBuildBr(builder, checkBlock); + + LLVMPositionBuilderAtEnd(builder, checkBlock); + LLVMValueRef iteratorValue = LLVMBuildPhi(builder, LLVMInt64Type(), iteratorVariableName); + AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName); + + LLVMPositionBuilderAtEnd(builder, bodyBlock); + LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next"); + + LLVMPositionBuilderAtEnd(builder, checkBlock); + + LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]); + LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare"); + + LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock); + + LLVMPositionBuilderAtEnd(builder, bodyBlock); + + LLVMBasicBlockRef lastBlock; + for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1) + { + lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]); + } + + LLVMBuildBr(builder, checkBlock); + + LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock)); + + LLVMValueRef incomingValues[2]; + incomingValues[0] = CompileNumber(forLoopStatement->children[1]); + incomingValues[1] = nextValue; + + LLVMBasicBlockRef incomingBlocks[2]; + incomingBlocks[0] = entryBlock; + incomingBlocks[1] = lastBlock; + + LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, 2); + + LLVMPositionBuilderAtEnd(builder, afterLoopBlock); + + PopScopeFrame(scope); + + return afterLoopBlock; +} + +static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) { case Assignment: - CompileAssignment(builder, statement); - return 0; + return CompileAssignment(builder, function, statement); + + case Declaration: + return CompileFunctionVariableDeclaration(builder, function, statement); + + case ForLoop: + return CompileForLoopStatement(builder, function, statement); case FunctionCallExpression: CompileFunctionCallExpression(builder, statement); - return 0; - - case Declaration: - CompileFunctionVariableDeclaration(builder, statement); - return 0; + return LLVMGetLastBasicBlock(function); case IfStatement: - CompileIfStatement(builder, function, statement); - return 0; + return CompileIfStatement(builder, function, statement); case IfElseStatement: - CompileIfElseStatement(builder, function, statement); - return 0; + return CompileIfElseStatement(builder, function, statement); case Return: - CompileReturn(builder, statement); - return 1; + return CompileReturn(builder, function, statement); case ReturnVoid: - CompileReturnVoid(builder); - return 1; + return CompileReturnVoid(builder, function); } fprintf(stderr, "Unknown statement kind!\n"); - return 0; + return NULL; } static void CompileFunction( @@ -800,14 +888,16 @@ static void CompileFunction( LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMBuildStore(builder, argument, argumentCopy); free(ptrName); - AddLocalVariable(scope, argumentCopy, functionSignature->children[2]->children[i]->children[1]->value.string); + AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string); } for (i = 0; i < functionBody->childCount; i += 1) { - hasReturn |= CompileStatement(builder, function, functionBody->children[i]); + CompileStatement(builder, function, functionBody->children[i]); } + hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) { LLVMBuildRetVoid(builder);