add if statement

pull/1/head
cosmonaut 2021-04-28 21:25:25 -07:00
parent 9a97b73c7c
commit a320086038
6 changed files with 84 additions and 6 deletions

View File

@ -36,6 +36,11 @@ struct Program
myStruct.myInt = myInt;
myStruct.Increment();
if (myStruct.myInt < 5)
{
myStruct.Increment();
}
return myStruct.myInt;
}
}

View File

@ -17,6 +17,8 @@
"static" return STATIC;
"Reference" return REFERENCE;
"alloc" return ALLOC;
"if" return IF;
"else" return ELSE;
[0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;

View File

@ -31,6 +31,8 @@ extern FILE *yyin;
%token STATIC
%token REFERENCE
%token ALLOC
%token IF
%token ELSE
%token NUMBER
%token ID
%token STRING_LITERAL
@ -178,6 +180,10 @@ BinaryExpression : Expression PLUS Expression
{
$$ = MakeBinaryNode(Multiply, $1, $3);
}
| Expression LESS_THAN Expression
{
$$ = MakeBinaryNode(LessThan, $1, $3);
}
Expression : PrimaryExpression
| UnaryExpression
@ -229,7 +235,23 @@ PartialStatement : FunctionCallExpression
| ReturnStatement
;
Statement : PartialStatement SEMICOLON;
IfStatement : IF LEFT_PAREN Expression RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
{
Node **statements;
Node *statementSequence;
uint32_t statementCount;
statements = GetNodes(stack, &statementCount);
statementSequence = MakeStatementSequenceNode(statements, statementCount);
$$ = MakeIfNode($3, statementSequence);
PopStackFrame(stack);
}
Statement : PartialStatement SEMICOLON
| IfStatement
;
Statements : Statement Statements
{

View File

@ -376,6 +376,19 @@ Node* MakeAllocNode(Node *typeNode)
return node;
}
Node* MakeIfNode(
Node *expressionNode,
Node *statementSequenceNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfStatement;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = expressionNode;
node->children[1] = statementSequenceNode;
return node;
}
static const char* PrimitiveTypeToString(PrimitiveType type)
{
switch (type)

View File

@ -22,6 +22,7 @@ typedef enum
FunctionSignature,
FunctionSignatureArguments,
Identifier,
IfStatement,
Number,
PrimitiveTypeNode,
ReferenceTypeNode,
@ -44,7 +45,8 @@ typedef enum
{
Add,
Subtract,
Multiply
Multiply,
LessThan
} BinaryOperator;
typedef enum
@ -174,6 +176,10 @@ Node* MakeAccessExpressionNode(
Node* MakeAllocNode(
Node *typeNode
);
Node* MakeIfNode(
Node *expressionNode,
Node *statementSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount);

View File

@ -423,7 +423,7 @@ static void AddStructVariablesToScope(
static LLVMValueRef CompileExpression(
LLVMBuilderRef builder,
Node *binaryExpression
Node *expression
);
static LLVMValueRef CompileNumber(
@ -442,14 +442,17 @@ static LLVMValueRef CompileBinaryExpression(
switch (binaryExpression->operator.binaryOperator)
{
case Add:
return LLVMBuildAdd(builder, left, right, "tmp");
return LLVMBuildAdd(builder, left, right, "addResult");
case Subtract:
return LLVMBuildSub(builder, left, right, "tmp");
return LLVMBuildSub(builder, left, right, "subtractResult");
case Multiply:
return LLVMBuildMul(builder, left, right, "tmp");
return LLVMBuildMul(builder, left, right, "multiplyResult");
/* FIXME: need type information for comparison */
case LessThan:
return LLVMBuildICmp(builder, LLVMIntSLT, left, right, "compareResult");
}
return NULL;
@ -575,6 +578,8 @@ static LLVMValueRef CompileExpression(
return NULL;
}
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
{
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
@ -622,6 +627,27 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
AddLocalVariable(scope, variable, variableName);
}
static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");
LLVMBuildCondBr(builder, conditional, block, afterCond);
LLVMPositionBuilderAtEnd(builder, block);
for (i = 0; i < ifStatement->children[1]->childCount; i += 1)
{
CompileStatement(builder, function, ifStatement->children[1]->children[i]);
LLVMBuildBr(builder, afterCond);
}
LLVMPositionBuilderAtEnd(builder, afterCond);
}
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
switch (statement->syntaxKind)
@ -638,6 +664,10 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N
CompileFunctionVariableDeclaration(builder, statement);
return 0;
case IfStatement:
CompileIfStatement(builder, function, statement);
return 0;
case Return:
CompileReturn(builder, statement);
return 1;