From a320086038b5b0b12f0407ab92fe2ef1f637720a Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 28 Apr 2021 21:25:25 -0700 Subject: [PATCH] add if statement --- example.w | 5 +++++ generators/wraith.lex | 2 ++ generators/wraith.y | 24 +++++++++++++++++++++++- src/ast.c | 13 +++++++++++++ src/ast.h | 8 +++++++- src/codegen.c | 38 ++++++++++++++++++++++++++++++++++---- 6 files changed, 84 insertions(+), 6 deletions(-) diff --git a/example.w b/example.w index 2adb74e..21f7736 100644 --- a/example.w +++ b/example.w @@ -36,6 +36,11 @@ struct Program myStruct.myInt = myInt; myStruct.Increment(); + if (myStruct.myInt < 5) + { + myStruct.Increment(); + } + return myStruct.myInt; } } diff --git a/generators/wraith.lex b/generators/wraith.lex index e571f31..b870eb3 100644 --- a/generators/wraith.lex +++ b/generators/wraith.lex @@ -17,6 +17,8 @@ "static" return STATIC; "Reference" return REFERENCE; "alloc" return ALLOC; +"if" return IF; +"else" return ELSE; [0-9]+ return NUMBER; [a-zA-Z][a-zA-Z0-9]* return ID; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; diff --git a/generators/wraith.y b/generators/wraith.y index bde907e..67e6f48 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -31,6 +31,8 @@ extern FILE *yyin; %token STATIC %token REFERENCE %token ALLOC +%token IF +%token ELSE %token NUMBER %token ID %token STRING_LITERAL @@ -178,6 +180,10 @@ BinaryExpression : Expression PLUS Expression { $$ = MakeBinaryNode(Multiply, $1, $3); } + | Expression LESS_THAN Expression + { + $$ = MakeBinaryNode(LessThan, $1, $3); + } Expression : PrimaryExpression | UnaryExpression @@ -229,7 +235,23 @@ PartialStatement : FunctionCallExpression | ReturnStatement ; -Statement : PartialStatement SEMICOLON; +IfStatement : IF LEFT_PAREN Expression RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE + { + Node **statements; + Node *statementSequence; + uint32_t statementCount; + + statements = GetNodes(stack, &statementCount); + statementSequence = MakeStatementSequenceNode(statements, statementCount); + + $$ = MakeIfNode($3, statementSequence); + + PopStackFrame(stack); + } + +Statement : PartialStatement SEMICOLON + | IfStatement + ; Statements : Statement Statements { diff --git a/src/ast.c b/src/ast.c index ae2d254..23f87e0 100644 --- a/src/ast.c +++ b/src/ast.c @@ -376,6 +376,19 @@ Node* MakeAllocNode(Node *typeNode) return node; } +Node* MakeIfNode( + Node *expressionNode, + Node *statementSequenceNode +) { + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = IfStatement; + node->childCount = 2; + node->children = (Node**) malloc(sizeof(Node*)); + node->children[0] = expressionNode; + node->children[1] = statementSequenceNode; + return node; +} + static const char* PrimitiveTypeToString(PrimitiveType type) { switch (type) diff --git a/src/ast.h b/src/ast.h index 51fe040..6f31bde 100644 --- a/src/ast.h +++ b/src/ast.h @@ -22,6 +22,7 @@ typedef enum FunctionSignature, FunctionSignatureArguments, Identifier, + IfStatement, Number, PrimitiveTypeNode, ReferenceTypeNode, @@ -44,7 +45,8 @@ typedef enum { Add, Subtract, - Multiply + Multiply, + LessThan } BinaryOperator; typedef enum @@ -174,6 +176,10 @@ Node* MakeAccessExpressionNode( Node* MakeAllocNode( Node *typeNode ); +Node* MakeIfNode( + Node *expressionNode, + Node *statementSequenceNode +); void PrintTree(Node *node, uint32_t tabCount); diff --git a/src/codegen.c b/src/codegen.c index 563eb6c..874e24f 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -423,7 +423,7 @@ static void AddStructVariablesToScope( static LLVMValueRef CompileExpression( LLVMBuilderRef builder, - Node *binaryExpression + Node *expression ); static LLVMValueRef CompileNumber( @@ -442,14 +442,17 @@ static LLVMValueRef CompileBinaryExpression( switch (binaryExpression->operator.binaryOperator) { case Add: - return LLVMBuildAdd(builder, left, right, "tmp"); + return LLVMBuildAdd(builder, left, right, "addResult"); case Subtract: - return LLVMBuildSub(builder, left, right, "tmp"); + return LLVMBuildSub(builder, left, right, "subtractResult"); case Multiply: - return LLVMBuildMul(builder, left, right, "tmp"); + return LLVMBuildMul(builder, left, right, "multiplyResult"); + /* FIXME: need type information for comparison */ + case LessThan: + return LLVMBuildICmp(builder, LLVMIntSLT, left, right, "compareResult"); } return NULL; @@ -575,6 +578,8 @@ static LLVMValueRef CompileExpression( return NULL; } +static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement); + static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement) { LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); @@ -622,6 +627,27 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var AddLocalVariable(scope, variable, variableName); } +static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) +{ + uint32_t i; + LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]); + + LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); + LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); + + LLVMBuildCondBr(builder, conditional, block, afterCond); + + LLVMPositionBuilderAtEnd(builder, block); + + for (i = 0; i < ifStatement->children[1]->childCount; i += 1) + { + CompileStatement(builder, function, ifStatement->children[1]->children[i]); + LLVMBuildBr(builder, afterCond); + } + + LLVMPositionBuilderAtEnd(builder, afterCond); +} + static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) @@ -638,6 +664,10 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N CompileFunctionVariableDeclaration(builder, statement); return 0; + case IfStatement: + CompileIfStatement(builder, function, statement); + return 0; + case Return: CompileReturn(builder, statement); return 1;