From c2efcbd7d2eadf91f687def707a2fbe54a9a7078 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 22 Apr 2021 00:35:42 -0700 Subject: [PATCH] add static modifier --- ast.c | 36 +++++++++++++++++++++-- ast.h | 10 ++++++- compiler.c | 86 ++++++++++++++++++++++++++++++++++-------------------- wraith.lex | 1 + wraith.y | 14 ++++++++- 5 files changed, 111 insertions(+), 36 deletions(-) diff --git a/ast.c b/ast.c index e41761c..03652d7 100644 --- a/ast.c +++ b/ast.c @@ -30,12 +30,14 @@ const char* SyntaxKindString(SyntaxKind syntaxKind) case FunctionArgumentSequence: return "FunctionArgumentSequence"; case FunctionCallExpression: return "FunctionCallExpression"; case FunctionDeclaration: return "FunctionDeclaration"; + case FunctionModifiers: return "FunctionModifiers"; case FunctionSignature: return "FunctionSignature"; case FunctionSignatureArguments: return "FunctionSignatureArguments"; case Identifier: return "Identifier"; case Number: return "Number"; case Return: return "Return"; case StatementSequence: return "StatementSequence"; + case StaticModifier: return "StaticModifier"; case StringLiteral: return "StringLiteral"; case StructDeclaration: return "StructDeclaration"; case Type: return "Type"; @@ -97,6 +99,34 @@ Node* MakeStringNode( return node; } +Node* MakeStaticNode() +{ + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = StaticModifier; + node->childCount = 0; + return node; +} + +Node* MakeFunctionModifiersNode( + Node **pModifierNodes, + uint32_t modifierCount +) { + uint32_t i; + Node* node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = FunctionModifiers; + node->childCount = modifierCount; + if (modifierCount > 0) + { + node->children = malloc(sizeof(Node*) * node->childCount); + for (i = 0; i < modifierCount; i += 1) + { + node->children[i] = pModifierNodes[i]; + } + } + + return node; +} + Node* MakeUnaryNode( UnaryOperator operator, Node *child @@ -208,15 +238,17 @@ Node *MakeFunctionSignatureArgumentsNode( Node* MakeFunctionSignatureNode( Node *identifierNode, Node* typeNode, - Node* arguments + Node* arguments, + Node* modifiersNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionSignature; - node->childCount = 3; + node->childCount = 4; node->children = (Node**) malloc(sizeof(Node*) * (node->childCount)); node->children[0] = identifierNode; node->children[1] = typeNode; node->children[2] = arguments; + node->children[3] = modifiersNode; return node; } diff --git a/ast.h b/ast.h index 96f22a8..148355d 100644 --- a/ast.h +++ b/ast.h @@ -16,6 +16,7 @@ typedef enum FunctionArgumentSequence, FunctionCallExpression, FunctionDeclaration, + FunctionModifiers, FunctionSignature, FunctionSignatureArguments, Identifier, @@ -23,6 +24,7 @@ typedef enum Return, ReturnVoid, StatementSequence, + StaticModifier, StringLiteral, StructDeclaration, Type, @@ -95,6 +97,11 @@ Node* MakeNumberNode( Node* MakeStringNode( const char *string ); +Node* MakeStaticNode(); +Node* MakeFunctionModifiersNode( + Node **pModifierNodes, + uint32_t modifierCount +); Node* MakeUnaryNode( UnaryOperator operator, Node *child @@ -127,7 +134,8 @@ Node* MakeFunctionSignatureArgumentsNode( Node* MakeFunctionSignatureNode( Node *identifierNode, Node* typeNode, - Node* arguments + Node* arguments, + Node* modifiersNode ); Node* MakeFunctionDeclarationNode( Node* functionSignatureNode, diff --git a/compiler.c b/compiler.c index b6278a7..6c482ae 100644 --- a/compiler.c +++ b/compiler.c @@ -5,10 +5,12 @@ #include #include #include +#include #include #include #include #include +#include #include "y.tab.h" #include "ast.h" @@ -240,7 +242,6 @@ static void AddStructVariables( } static LLVMValueRef CompileExpression( - LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression @@ -306,13 +307,12 @@ static LLVMValueRef CompileNumber( } static LLVMValueRef CompileBinaryExpression( - LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ) { - LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]); - LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]); + LLVMValueRef left = CompileExpression(builder, function, binaryExpression->children[0]); + LLVMValueRef right = CompileExpression(builder, function, binaryExpression->children[1]); switch (binaryExpression->operator.binaryOperator) { @@ -332,7 +332,6 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( - LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression @@ -343,7 +342,7 @@ static LLVMValueRef CompileFunctionCallExpression( for (i = 0; i < argumentCount; i += 1) { - args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]); + args[i] = CompileExpression(builder, function, expression->children[1]->children[i]); } //return LLVMBuildCall(builder, FindVariableValueByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp"); @@ -352,7 +351,6 @@ static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileAccessExpressionForStore( LLVMBuilderRef builder, - LLVMValueRef wStructValue, LLVMValueRef function, Node *expression ) { @@ -364,7 +362,6 @@ static LLVMValueRef CompileAccessExpressionForStore( static LLVMValueRef CompileAccessExpression( LLVMBuilderRef builder, - LLVMValueRef wStructValue, LLVMValueRef function, Node *expression ) { @@ -376,7 +373,6 @@ static LLVMValueRef CompileAccessExpression( } static LLVMValueRef CompileExpression( - LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression @@ -384,13 +380,13 @@ static LLVMValueRef CompileExpression( switch (expression->syntaxKind) { case AccessExpression: - return CompileAccessExpression(builder, wStructValue, function, expression); + return CompileAccessExpression(builder, function, expression); case BinaryExpression: - return CompileBinaryExpression(wStructValue, builder, function, expression); + return CompileBinaryExpression(builder, function, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(wStructValue, builder, function, expression); + return CompileFunctionCallExpression(builder, function, expression); case Identifier: return FindVariableValue(builder, expression->value.string); @@ -403,10 +399,9 @@ static LLVMValueRef CompileExpression( return NULL; } -/* FIXME: we need a scope structure */ -static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) +static void CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { - LLVMValueRef expression = CompileExpression(wStructValue, builder, function, returnStatemement->children[0]); + LLVMValueRef expression = CompileExpression(builder, function, returnStatemement->children[0]); LLVMBuildRet(builder, expression); } @@ -415,13 +410,13 @@ static void CompileReturnVoid(LLVMBuilderRef builder) LLVMBuildRetVoid(builder); } -static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) +static void CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { - LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]); + LLVMValueRef result = CompileExpression(builder, function, assignmentStatement->children[1]); LLVMValueRef identifier; if (assignmentStatement->children[0]->syntaxKind == AccessExpression) { - identifier = CompileAccessExpressionForStore(builder, wStructValue, function, assignmentStatement->children[0]); + identifier = CompileAccessExpressionForStore(builder, function, assignmentStatement->children[0]); } else if (assignmentStatement->children[0]->syntaxKind == Identifier) { @@ -458,12 +453,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var AddLocalVariable(scope, variable, variableName); } -static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) +static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) { case Assignment: - CompileAssignment(wStructValue, builder, function, statement); + CompileAssignment(builder, function, statement); return 0; case Declaration: @@ -471,7 +466,7 @@ static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builde return 0; case Return: - CompileReturn(wStructValue, builder, function, statement); + CompileReturn(builder, function, statement); return 1; case ReturnVoid: @@ -492,31 +487,52 @@ static void CompileFunction( ) { uint32_t i; uint8_t hasReturn = 0; + uint8_t isStatic = 0; Node *functionSignature = functionDeclaration->children[0]; Node *functionBody = functionDeclaration->children[1]; - uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */ - LLVMTypeRef paramTypes[argumentCount]; + uint32_t argumentCount = functionSignature->children[2]->childCount; + LLVMTypeRef paramTypes[argumentCount + 1]; + uint32_t paramIndex = 0; + + if (functionSignature->children[3]->childCount > 0) + { + for (i = 0; i < functionSignature->children[3]->childCount; i += 1) + { + if (functionSignature->children[3]->children[i]->syntaxKind == StaticModifier) + { + isStatic = 1; + break; + } + } + } + + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } PushScopeFrame(scope); - paramTypes[0] = wStructPointerType; - for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { - paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); + paramTypes[paramIndex] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); + paramIndex += 1; } LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type); - LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0); + LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMPositionBuilderAtEnd(builder, entry); - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - - AddStructVariables(builder, wStructPointer); + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariables(builder, wStructPointer); + } for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { @@ -531,7 +547,7 @@ static void CompileFunction( for (i = 0; i < functionBody->childCount; i += 1) { - hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]); + hasReturn |= CompileStatement(builder, function, functionBody->children[i]); } if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) @@ -645,7 +661,8 @@ int main(int argc, char *argv[]) char *error = NULL; LLVMVerifyModule(module, LLVMAbortProcessAction, &error); - LLVMDisposeMessage(error); + + LLVMSetTarget(module, LLVM_DEFAULT_TARGET_TRIPLE); LLVMPassManagerRef passManager = LLVMCreatePassManager(); LLVMAddInstructionCombiningPass(passManager); @@ -663,6 +680,11 @@ int main(int argc, char *argv[]) fprintf(stderr, "error writing bitcode to file\n"); } + LLVMMemoryBufferRef memoryBuffer = LLVMWriteBitcodeToMemoryBuffer(module); + LLVMCreateBinary(memoryBuffer, context, &error); + LLVMDisposeMessage(error); + LLVMDisposeMemoryBuffer(memoryBuffer); + LLVMPassManagerBuilderDispose(passManagerBuilder); LLVMDisposePassManager(passManager); LLVMDisposeModule(module); diff --git a/wraith.lex b/wraith.lex index 5bfef85..2c7d539 100644 --- a/wraith.lex +++ b/wraith.lex @@ -14,6 +14,7 @@ "bool" return BOOL; "struct" return STRUCT; "return" return RETURN; +"static" return STATIC; [0-9]+ return NUMBER; [a-zA-Z][a-zA-Z0-9]* return ID; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; diff --git a/wraith.y b/wraith.y index 0cec2a5..57415dc 100644 --- a/wraith.y +++ b/wraith.y @@ -30,6 +30,7 @@ extern Node *rootNode; %token BOOL %token STRUCT %token RETURN +%token STATIC %token NUMBER %token ID %token STRING_LITERAL @@ -265,7 +266,18 @@ FunctionSignature : Type Identifier LEFT_PAREN SignatureArguments RIGHT_PA uint32_t declarationCount; declarations = GetNodes(stack, &declarationCount); - $$ = MakeFunctionSignatureNode($2, $1, MakeFunctionSignatureArgumentsNode(declarations, declarationCount)); + $$ = MakeFunctionSignatureNode($2, $1, MakeFunctionSignatureArgumentsNode(declarations, declarationCount), MakeFunctionModifiersNode(NULL, 0)); + + PopStackFrame(stack); + } + | STATIC Type Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN + { + Node **declarations; + uint32_t declarationCount; + Node *modifier = MakeStaticNode(); + + declarations = GetNodes(stack, &declarationCount); + $$ = MakeFunctionSignatureNode($3, $2, MakeFunctionSignatureArgumentsNode(declarations, declarationCount), MakeFunctionModifiersNode(&modifier, 1)); PopStackFrame(stack); }