From b5d256251eaf4673bee11620f7f1004503718a17 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Tue, 20 Apr 2021 10:47:40 -0700 Subject: [PATCH] support function calls --- ast.c | 38 +++++++++++++++++++++++++++++++++-- ast.h | 13 +++++++++++- compiler.c | 49 ++++++++++++++++++++++++++++++++++++--------- stack.c | 51 +++++++++++++--------------------------------- stack.h | 18 ++++++----------- wraith.lex | 6 +++--- wraith.y | 59 +++++++++++++++++++++++++++++++++++++++++------------- 7 files changed, 155 insertions(+), 79 deletions(-) diff --git a/ast.c b/ast.c index a8634c8..6d3b123 100644 --- a/ast.c +++ b/ast.c @@ -26,6 +26,8 @@ const char* SyntaxKindString(SyntaxKind syntaxKind) case Comment: return "Comment"; case Declaration: return "Declaration"; case DeclarationSequence: return "DeclarationSequence"; + case FunctionArgumentSequence: return "FunctionArgumentSequence"; + case FunctionCallExpression: return "FunctionCallExpression"; case FunctionDeclaration: return "FunctionDeclaration"; case FunctionSignature: return "FunctionSignature"; case FunctionSignatureArguments: return "FunctionSignatureArguments"; @@ -239,6 +241,36 @@ Node* MakeDeclarationSequenceNode( return node; } +Node *MakeFunctionArgumentSequenceNode( + Node **pArgumentNodes, + uint32_t argumentCount +) { + int32_t i; + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = FunctionArgumentSequence; + node->childCount = argumentCount; + node->children = (Node**) malloc(sizeof(Node*) * node->childCount); + for (i = argumentCount - 1; i >= 0; i -= 1) + { + node->children[argumentCount - 1 - i] = pArgumentNodes[i]; + } + return node; +} + +Node* MakeFunctionCallExpressionNode( + Node *identifierNode, + Node *argumentSequenceNode +) { + int32_t i; + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = FunctionCallExpression; + node->children = (Node**) malloc(sizeof(Node*) * 2); + node->childCount = 2; + node->children[0] = identifierNode; + node->children[1] = argumentSequenceNode; + return node; +} + static const char* PrimitiveTypeToString(PrimitiveType type) { switch (type) @@ -262,6 +294,10 @@ static void PrintBinaryOperator(BinaryOperator expression) case Subtract: printf("-"); break; + + case Multiply: + printf("*"); + break; } } @@ -308,5 +344,3 @@ void PrintTree(Node *node, uint32_t tabCount) PrintTree(node->children[i], tabCount + 1); } } - - diff --git a/ast.h b/ast.h index 6dd42ad..595f53f 100644 --- a/ast.h +++ b/ast.h @@ -12,6 +12,8 @@ typedef enum DeclarationSequence, Expression, ForLoop, + FunctionArgumentSequence, + FunctionCallExpression, FunctionDeclaration, FunctionSignature, FunctionSignatureArguments, @@ -33,7 +35,8 @@ typedef enum typedef enum { Add, - Subtract + Subtract, + Multiply } BinaryOperator; typedef enum @@ -130,6 +133,14 @@ Node* MakeDeclarationSequenceNode( Node **pNodes, uint32_t nodeCount ); +Node *MakeFunctionArgumentSequenceNode( + Node **pArgumentNodes, + uint32_t argumentCount +); +Node* MakeFunctionCallExpressionNode( + Node *identifierNode, + Node *argumentSequenceNode +); void PrintTree(Node *node, uint32_t tabCount); diff --git a/compiler.c b/compiler.c index 98b4e30..b3ce40a 100644 --- a/compiler.c +++ b/compiler.c @@ -14,13 +14,13 @@ extern FILE *yyin; Stack *stack; Node *rootNode; -typedef struct VariableMapValue +typedef struct IdentifierMapValue { char *name; - LLVMValueRef variable; -} VariableMapValue; + LLVMValueRef value; +} IdentifierMapValue; -VariableMapValue *namedVariables; +IdentifierMapValue *namedVariables; uint32_t namedVariableCount; static LLVMValueRef CompileExpression( @@ -32,11 +32,11 @@ static LLVMValueRef CompileExpression( static void AddNamedVariable(char *name, LLVMValueRef variable) { - VariableMapValue mapValue; + IdentifierMapValue mapValue; mapValue.name = name; - mapValue.variable = variable; + mapValue.value = variable; - namedVariables = realloc(namedVariables, namedVariableCount + 1); + namedVariables = realloc(namedVariables, sizeof(IdentifierMapValue) * (namedVariableCount + 1)); namedVariables[namedVariableCount] = mapValue; namedVariableCount += 1; @@ -50,7 +50,7 @@ static LLVMValueRef FindVariableByName(char *name) { if (strcmp(namedVariables[i].name, name) == 0) { - return namedVariables[i].variable; + return namedVariables[i].value; } } @@ -76,8 +76,7 @@ static LLVMValueRef CompileNumber( LLVMBuilderRef builder, LLVMValueRef function, Node *numberExpression -) -{ +) { return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); } @@ -94,11 +93,36 @@ static LLVMValueRef CompileBinaryExpression( { case Add: return LLVMBuildAdd(builder, left, right, "tmp"); + + case Subtract: + return LLVMBuildSub(builder, left, right, "tmp"); + + case Multiply: + return LLVMBuildMul(builder, left, right, "tmp"); + } return NULL; } +static LLVMValueRef CompileFunctionCallExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + LLVMValueRef function, + Node *expression +) { + uint32_t i; + uint32_t argumentCount = expression->children[1]->childCount; + LLVMValueRef args[argumentCount]; + + for (i = 0; i < argumentCount; i += 1) + { + args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]); + } + + return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp"); +} + static LLVMValueRef CompileExpression( LLVMModuleRef module, LLVMBuilderRef builder, @@ -110,6 +134,9 @@ static LLVMValueRef CompileExpression( case BinaryExpression: return CompileBinaryExpression(module, builder, function, expression); + case FunctionCallExpression: + return CompileFunctionCallExpression(module, builder, function, expression); + case Identifier: return FindVariableByName(expression->value.string); @@ -167,6 +194,8 @@ static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration) { CompileStatement(module, builder, function, functionBody->children[i]); } + + AddNamedVariable(functionSignature->children[0]->value.string, function); } static void Compile(LLVMModuleRef module, Node *node) diff --git a/stack.c b/stack.c index 570b1aa..9ab52de 100644 --- a/stack.c +++ b/stack.c @@ -10,12 +10,9 @@ Stack* CreateStack() stack->stackFrames = (StackFrame*) malloc(sizeof(StackFrame) * stack->stackCapacity); for (i = 0; i < stack->stackCapacity; i += 1) { - stack->stackFrames[i].statements = NULL; - stack->stackFrames[i].statementCapacity = 0; - stack->stackFrames[i].statementCount = 0; - stack->stackFrames[i].declarations = NULL; - stack->stackFrames[i].declarationCapacity = 0; - stack->stackFrames[i].declarationCount = 0; + stack->stackFrames[i].nodes = NULL; + stack->stackFrames[i].nodeCapacity = 0; + stack->stackFrames[i].nodeCount = 0; } stack->stackIndex = 0; return stack; @@ -30,12 +27,10 @@ void PushStackFrame(Stack *stack) stack->stackCapacity += 1; stack->stackFrames = (StackFrame*) realloc(stack->stackFrames, sizeof(StackFrame) * stack->stackCapacity); - stack->stackFrames[stack->stackIndex].statementCapacity = 0; - stack->stackFrames[stack->stackIndex].declarationCapacity = 0; + stack->stackFrames[stack->stackIndex].nodeCapacity = 0; } - stack->stackFrames[stack->stackIndex].statementCount = 0; - stack->stackFrames[stack->stackIndex].declarationCount = 0; + stack->stackFrames[stack->stackIndex].nodeCount = 0; } void PopStackFrame(Stack *stack) @@ -43,38 +38,20 @@ void PopStackFrame(Stack *stack) stack->stackIndex -= 1; } -void AddStatement(Stack *stack, Node *statementNode) +void AddNode(Stack *stack, Node *statementNode) { StackFrame *stackFrame = &stack->stackFrames[stack->stackIndex]; - if (stackFrame->statementCount == stackFrame->statementCapacity) + if (stackFrame->nodeCount == stackFrame->nodeCapacity) { - stackFrame->statementCapacity += 1; - stackFrame->statements = (Node**) realloc(stackFrame->statements, stackFrame->statementCapacity); + stackFrame->nodeCapacity += 1; + stackFrame->nodes = (Node**) realloc(stackFrame->nodes, stackFrame->nodeCapacity); } - stackFrame->statements[stackFrame->statementCount] = statementNode; - stackFrame->statementCount += 1; + stackFrame->nodes[stackFrame->nodeCount] = statementNode; + stackFrame->nodeCount += 1; } -Node** GetStatements(Stack *stack, uint32_t *pCount) +Node** GetNodes(Stack *stack, uint32_t *pCount) { - *pCount = stack->stackFrames[stack->stackIndex].statementCount; - return stack->stackFrames[stack->stackIndex].statements; -} - -void AddDeclaration(Stack *stack, Node *declarationNode) -{ - StackFrame *stackFrame = &stack->stackFrames[stack->stackIndex]; - if (stackFrame->declarationCount == stackFrame->declarationCapacity) - { - stackFrame->declarationCapacity += 1; - stackFrame->declarations = (Node**) realloc(stackFrame->declarations, stackFrame->declarationCapacity); - } - stackFrame->declarations[stackFrame->declarationCount] = declarationNode; - stackFrame->declarationCount += 1; -} - -Node** GetDeclarations(Stack *stack, uint32_t *pCount) -{ - *pCount = stack->stackFrames[stack->stackIndex].declarationCount; - return stack->stackFrames[stack->stackIndex].declarations; + *pCount = stack->stackFrames[stack->stackIndex].nodeCount; + return stack->stackFrames[stack->stackIndex].nodes; } diff --git a/stack.h b/stack.h index 3d118b9..09c8892 100644 --- a/stack.h +++ b/stack.h @@ -5,13 +5,9 @@ typedef struct StackFrame { - Node **statements; - uint32_t statementCount; - uint32_t statementCapacity; - - Node **declarations; - uint32_t declarationCount; - uint32_t declarationCapacity; + Node **nodes; + uint32_t nodeCount; + uint32_t nodeCapacity; } StackFrame; typedef struct Stack @@ -25,9 +21,7 @@ Stack* CreateStack(); void PushStackFrame(Stack *stack); void PopStackFrame(Stack *stack); -void AddStatement(Stack *stack, Node *statementNode); -Node** GetStatements(Stack *stack, uint32_t *pCount); -void AddDeclaration(Stack *stack, Node *declarationNode); -Node** GetDeclarations(Stack *stack, uint32_t *pCount); +void AddNode(Stack *stack, Node *statementNode); +Node** GetNodes(Stack *stack, uint32_t *pCount); -#endif /* WRAITH_STACK_H */ \ No newline at end of file +#endif /* WRAITH_STACK_H */ diff --git a/wraith.lex b/wraith.lex index 8137c48..056eefc 100644 --- a/wraith.lex +++ b/wraith.lex @@ -18,9 +18,9 @@ \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; "+" return PLUS; "-" return MINUS; -"*" return MULTIPLY; -"/" return DIVIDE; -"%" return MOD; +"*" return STAR; +"/" return SLASH; +"%" return PERCENT; "<" return LESS_THAN; ">" return GREATER_THAN; "=" return EQUAL; diff --git a/wraith.y b/wraith.y index dee5563..d048896 100644 --- a/wraith.y +++ b/wraith.y @@ -34,9 +34,9 @@ extern Node *rootNode; %token STRING_LITERAL %token PLUS %token MINUS -%token MULTIPLY -%token DIVIDE -%token MOD +%token STAR +%token SLASH +%token PERCENT %token EQUAL %token LESS_THAN %token GREATER_THAN @@ -71,7 +71,7 @@ Program : Declarations Node *declarationSequence; uint32_t declarationCount; - declarations = GetDeclarations(stack, &declarationCount); + declarations = GetNodes(stack, &declarationCount); declarationSequence = MakeDeclarationSequenceNode(declarations, declarationCount); PopStackFrame(stack); @@ -123,6 +123,7 @@ PrimaryExpression : Identifier { $$ = $2; } + | FunctionCallExpression ; UnaryExpression : BANG Expression @@ -138,6 +139,10 @@ BinaryExpression : Expression PLUS Expression { $$ = MakeBinaryNode(Subtract, $1, $3); } + | Expression STAR Expression + { + $$ = MakeBinaryNode(Multiply, $1, $3); + } Expression : PrimaryExpression | UnaryExpression @@ -163,8 +168,20 @@ ReturnStatement : RETURN Expression $$ = MakeReturnStatementNode($2); } +FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN + { + Node **arguments; + uint32_t argumentCount; + + arguments = GetNodes(stack, &argumentCount); + $$ = MakeFunctionCallExpressionNode($1, MakeFunctionArgumentSequenceNode(arguments, argumentCount)); + + PopStackFrame(stack); + } + PartialStatement : AssignmentStatement | VariableDeclaration + | FunctionCallExpression | ReturnStatement ; @@ -172,22 +189,36 @@ Statement : PartialStatement SEMICOLON; Statements : Statement Statements { - AddStatement(stack, $1); + AddNode(stack, $1); } | { PushStackFrame(stack); } -Arguments : VariableDeclaration COMMA VariableDeclarations +Arguments : PrimaryExpression COMMA Arguments { - AddDeclaration(stack, $1); + AddNode(stack, $1); + } + | PrimaryExpression + { + PushStackFrame(stack); + AddNode(stack, $1); + } + | + ; + +SignatureArguments : VariableDeclaration COMMA VariableDeclarations + { + AddNode(stack, $1); } | VariableDeclaration { PushStackFrame(stack); - AddDeclaration(stack, $1); + AddNode(stack, $1); } + | + ; Body : LEFT_BRACE Statements RIGHT_BRACE { @@ -195,19 +226,19 @@ Body : LEFT_BRACE Statements RIGHT_BRACE Node *statementSequence; uint32_t statementCount; - statements = GetStatements(stack, &statementCount); + statements = GetNodes(stack, &statementCount); statementSequence = MakeStatementSequenceNode(statements, statementCount); $$ = MakeStatementSequenceNode(statements, statementCount); PopStackFrame(stack); } -FunctionSignature : Type Identifier LEFT_PAREN Arguments RIGHT_PAREN +FunctionSignature : Type Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN { Node **declarations; uint32_t declarationCount; - declarations = GetDeclarations(stack, &declarationCount); + declarations = GetNodes(stack, &declarationCount); $$ = MakeFunctionSignatureNode($2, $1, MakeFunctionSignatureArgumentsNode(declarations, declarationCount)); PopStackFrame(stack); @@ -220,7 +251,7 @@ FunctionDeclaration : FunctionSignature Body VariableDeclarations : VariableDeclaration SEMICOLON VariableDeclarations { - AddDeclaration(stack, $1); + AddNode(stack, $1); } | { @@ -233,7 +264,7 @@ StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGH Node *declarationSequence; uint32_t declarationCount; - declarations = GetDeclarations(stack, &declarationCount); + declarations = GetNodes(stack, &declarationCount); declarationSequence = MakeDeclarationSequenceNode(declarations, declarationCount); $$ = MakeStructDeclarationNode($2, declarationSequence); @@ -246,7 +277,7 @@ Declaration : StructDeclaration Declarations : Declaration Declarations { - AddDeclaration(stack, $1); + AddNode(stack, $1); } | {