add if statement

pull/1/head
cosmonaut 2021-04-28 21:25:25 -07:00
parent 9a97b73c7c
commit a320086038
6 changed files with 84 additions and 6 deletions

View File

@ -36,6 +36,11 @@ struct Program
myStruct.myInt = myInt; myStruct.myInt = myInt;
myStruct.Increment(); myStruct.Increment();
if (myStruct.myInt < 5)
{
myStruct.Increment();
}
return myStruct.myInt; return myStruct.myInt;
} }
} }

View File

@ -17,6 +17,8 @@
"static" return STATIC; "static" return STATIC;
"Reference" return REFERENCE; "Reference" return REFERENCE;
"alloc" return ALLOC; "alloc" return ALLOC;
"if" return IF;
"else" return ELSE;
[0-9]+ return NUMBER; [0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID; [a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;

View File

@ -31,6 +31,8 @@ extern FILE *yyin;
%token STATIC %token STATIC
%token REFERENCE %token REFERENCE
%token ALLOC %token ALLOC
%token IF
%token ELSE
%token NUMBER %token NUMBER
%token ID %token ID
%token STRING_LITERAL %token STRING_LITERAL
@ -178,6 +180,10 @@ BinaryExpression : Expression PLUS Expression
{ {
$$ = MakeBinaryNode(Multiply, $1, $3); $$ = MakeBinaryNode(Multiply, $1, $3);
} }
| Expression LESS_THAN Expression
{
$$ = MakeBinaryNode(LessThan, $1, $3);
}
Expression : PrimaryExpression Expression : PrimaryExpression
| UnaryExpression | UnaryExpression
@ -229,7 +235,23 @@ PartialStatement : FunctionCallExpression
| ReturnStatement | ReturnStatement
; ;
Statement : PartialStatement SEMICOLON; IfStatement : IF LEFT_PAREN Expression RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
{
Node **statements;
Node *statementSequence;
uint32_t statementCount;
statements = GetNodes(stack, &statementCount);
statementSequence = MakeStatementSequenceNode(statements, statementCount);
$$ = MakeIfNode($3, statementSequence);
PopStackFrame(stack);
}
Statement : PartialStatement SEMICOLON
| IfStatement
;
Statements : Statement Statements Statements : Statement Statements
{ {

View File

@ -376,6 +376,19 @@ Node* MakeAllocNode(Node *typeNode)
return node; return node;
} }
Node* MakeIfNode(
Node *expressionNode,
Node *statementSequenceNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfStatement;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = expressionNode;
node->children[1] = statementSequenceNode;
return node;
}
static const char* PrimitiveTypeToString(PrimitiveType type) static const char* PrimitiveTypeToString(PrimitiveType type)
{ {
switch (type) switch (type)

View File

@ -22,6 +22,7 @@ typedef enum
FunctionSignature, FunctionSignature,
FunctionSignatureArguments, FunctionSignatureArguments,
Identifier, Identifier,
IfStatement,
Number, Number,
PrimitiveTypeNode, PrimitiveTypeNode,
ReferenceTypeNode, ReferenceTypeNode,
@ -44,7 +45,8 @@ typedef enum
{ {
Add, Add,
Subtract, Subtract,
Multiply Multiply,
LessThan
} BinaryOperator; } BinaryOperator;
typedef enum typedef enum
@ -174,6 +176,10 @@ Node* MakeAccessExpressionNode(
Node* MakeAllocNode( Node* MakeAllocNode(
Node *typeNode Node *typeNode
); );
Node* MakeIfNode(
Node *expressionNode,
Node *statementSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount); void PrintTree(Node *node, uint32_t tabCount);

View File

@ -423,7 +423,7 @@ static void AddStructVariablesToScope(
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *binaryExpression Node *expression
); );
static LLVMValueRef CompileNumber( static LLVMValueRef CompileNumber(
@ -442,14 +442,17 @@ static LLVMValueRef CompileBinaryExpression(
switch (binaryExpression->operator.binaryOperator) switch (binaryExpression->operator.binaryOperator)
{ {
case Add: case Add:
return LLVMBuildAdd(builder, left, right, "tmp"); return LLVMBuildAdd(builder, left, right, "addResult");
case Subtract: case Subtract:
return LLVMBuildSub(builder, left, right, "tmp"); return LLVMBuildSub(builder, left, right, "subtractResult");
case Multiply: case Multiply:
return LLVMBuildMul(builder, left, right, "tmp"); return LLVMBuildMul(builder, left, right, "multiplyResult");
/* FIXME: need type information for comparison */
case LessThan:
return LLVMBuildICmp(builder, LLVMIntSLT, left, right, "compareResult");
} }
return NULL; return NULL;
@ -575,6 +578,8 @@ static LLVMValueRef CompileExpression(
return NULL; return NULL;
} }
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement) static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
{ {
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
@ -622,6 +627,27 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
AddLocalVariable(scope, variable, variableName); AddLocalVariable(scope, variable, variableName);
} }
static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");
LLVMBuildCondBr(builder, conditional, block, afterCond);
LLVMPositionBuilderAtEnd(builder, block);
for (i = 0; i < ifStatement->children[1]->childCount; i += 1)
{
CompileStatement(builder, function, ifStatement->children[1]->children[i]);
LLVMBuildBr(builder, afterCond);
}
LLVMPositionBuilderAtEnd(builder, afterCond);
}
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{ {
switch (statement->syntaxKind) switch (statement->syntaxKind)
@ -638,6 +664,10 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N
CompileFunctionVariableDeclaration(builder, statement); CompileFunctionVariableDeclaration(builder, statement);
return 0; return 0;
case IfStatement:
CompileIfStatement(builder, function, statement);
return 0;
case Return: case Return:
CompileReturn(builder, statement); CompileReturn(builder, statement);
return 1; return 1;