From 2708dfbbed0e8e4aa52ff6a469059a5b67681903 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Tue, 20 Apr 2021 19:00:18 -0700 Subject: [PATCH] compile struct functions --- ast.c | 10 ++ ast.h | 3 + compiler.c | 296 ++++++++++++++++++++++++++++++++++++++++++++++------- stack.c | 2 +- wraith.lex | 3 +- wraith.y | 44 +++++--- 6 files changed, 304 insertions(+), 54 deletions(-) diff --git a/ast.c b/ast.c index 6d3b123..daafae9 100644 --- a/ast.c +++ b/ast.c @@ -165,6 +165,15 @@ Node* MakeReturnStatementNode( return node; } +Node* MakeReturnVoidStatementNode() +{ + Node *node = (Node*) malloc(sizeof(Node)); + node->syntaxKind = ReturnVoid; + node->childCount = 0; + node->children = NULL; + return node; +} + Node *MakeFunctionSignatureArgumentsNode( Node **pArgumentNodes, uint32_t argumentCount @@ -278,6 +287,7 @@ static const char* PrimitiveTypeToString(PrimitiveType type) case Int: return "Int"; case UInt: return "UInt"; case Bool: return "Bool"; + case Void: return "Void"; } return "Unknown"; diff --git a/ast.h b/ast.h index 595f53f..a8c80bd 100644 --- a/ast.h +++ b/ast.h @@ -20,6 +20,7 @@ typedef enum Identifier, Number, Return, + ReturnVoid, StatementSequence, StringLiteral, StructDeclaration, @@ -41,6 +42,7 @@ typedef enum typedef enum { + Void, Bool, Int, UInt, @@ -112,6 +114,7 @@ Node* MakeStatementSequenceNode( Node* MakeReturnStatementNode( Node *expressionNode ); +Node* MakeReturnVoidStatementNode(); Node* MakeFunctionSignatureArgumentsNode( Node **pArgumentNodes, uint32_t argumentCount diff --git a/compiler.c b/compiler.c index b3ce40a..56618b3 100644 --- a/compiler.c +++ b/compiler.c @@ -14,6 +14,100 @@ extern FILE *yyin; Stack *stack; Node *rootNode; +typedef struct StructFieldMapValue +{ + char *name; + LLVMValueRef value; + LLVMValueRef valuePointer; + uint8_t needsWrite; +} StructFieldMapValue; + +typedef struct StructFieldMap +{ + LLVMValueRef structPointer; + StructFieldMapValue *fields; + uint32_t fieldCount; +} StructFieldMap; + +StructFieldMap *structFieldMaps; +uint32_t structFieldMapCount; + +static void AddStruct(LLVMValueRef wStructPointer) +{ + structFieldMaps = realloc(structFieldMaps, sizeof(StructFieldMap) * (structFieldMapCount + 1)); + structFieldMaps[structFieldMapCount].structPointer = wStructPointer; + structFieldMaps[structFieldMapCount].fields = NULL; + structFieldMaps[structFieldMapCount].fieldCount = 0; + structFieldMapCount += 1; +} + +static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index) +{ + uint32_t i, fieldCount; + + for (i = 0; i < structFieldMapCount; i += 1) + { + if (structFieldMaps[i].structPointer == wStructPointer) + { + fieldCount = structFieldMaps[i].fieldCount; + + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + wStructPointer, + fieldCount, + "ptr" + ); + + structFieldMaps[i].fields = realloc(structFieldMaps[i].fields, sizeof(StructFieldMapValue) * (fieldCount + 1)); + structFieldMaps[i].fields[fieldCount].name = strdup(name); + structFieldMaps[i].fields[fieldCount].value = LLVMBuildLoad(builder, elementPointer, name); + structFieldMaps[i].fields[fieldCount].valuePointer = elementPointer; + structFieldMaps[i].fields[fieldCount].needsWrite = 0; + structFieldMaps[i].fieldCount += 1; + + break; + } + } +} + +static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value) +{ + uint32_t i, j; + + for (i = 0; i < structFieldMapCount; i += 1) + { + if (structFieldMaps[i].structPointer == wStructPointer) + { + for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) + { + if (structFieldMaps[i].fields[j].value == value) + { + return structFieldMaps[i].fields[j].valuePointer; + } + } + } + } + + return NULL; +} + +static void RemoveStruct(LLVMValueRef wStructPointer) +{ + uint32_t i; + + for (i = 0; i < structFieldMapCount; i += 1) + { + if (structFieldMaps[i].structPointer == wStructPointer) + { + free(structFieldMaps[i].fields); + structFieldMaps[i].fields = NULL; + structFieldMaps[i].fieldCount = 0; + + break; + } + } +} + typedef struct IdentifierMapValue { char *name; @@ -24,7 +118,7 @@ IdentifierMapValue *namedVariables; uint32_t namedVariableCount; static LLVMValueRef CompileExpression( - LLVMModuleRef module, + LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression @@ -42,10 +136,11 @@ static void AddNamedVariable(char *name, LLVMValueRef variable) namedVariableCount += 1; } -static LLVMValueRef FindVariableByName(char *name) +static LLVMValueRef FindVariableByName(LLVMValueRef wStructValue, LLVMBuilderRef builder, char *name) { - uint32_t i; + uint32_t i, j; + /* first, search scoped vars */ for (i = 0; i < namedVariableCount; i += 1) { if (strcmp(namedVariables[i].name, name) == 0) @@ -54,6 +149,22 @@ static LLVMValueRef FindVariableByName(char *name) } } + /* if none exist, search struct vars */ + for (i = 0; i < structFieldMapCount; i += 1) + { + if (structFieldMaps[i].structPointer == wStructValue) + { + for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) + { + if (strcmp(structFieldMaps[i].fields[j].name, name) == 0) + { + return structFieldMaps[i].fields[j].value; + } + } + } + } + + fprintf(stderr, "Identifier not found!"); return NULL; } @@ -66,28 +177,32 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) case UInt: return LLVMInt64Type(); + + case Bool: + return LLVMInt1Type(); + + case Void: + return LLVMVoidType(); } + fprintf(stderr, "Unrecognized type!"); return NULL; } static LLVMValueRef CompileNumber( - LLVMModuleRef module, - LLVMBuilderRef builder, - LLVMValueRef function, Node *numberExpression ) { return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); } static LLVMValueRef CompileBinaryExpression( - LLVMModuleRef module, + LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ) { - LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]); - LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]); + LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]); + LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]); switch (binaryExpression->operator.binaryOperator) { @@ -106,7 +221,7 @@ static LLVMValueRef CompileBinaryExpression( } static LLVMValueRef CompileFunctionCallExpression( - LLVMModuleRef module, + LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression @@ -117,71 +232,113 @@ static LLVMValueRef CompileFunctionCallExpression( for (i = 0; i < argumentCount; i += 1) { - args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]); + args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]); } - return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp"); + return LLVMBuildCall(builder, FindVariableByName(wStructValue, builder, expression->children[0]->value.string), args, argumentCount, "tmp"); } static LLVMValueRef CompileExpression( - LLVMModuleRef module, + LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression ) { + LLVMValueRef var; + switch (expression->syntaxKind) { case BinaryExpression: - return CompileBinaryExpression(module, builder, function, expression); + return CompileBinaryExpression(wStructValue, builder, function, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(module, builder, function, expression); + return CompileFunctionCallExpression(wStructValue, builder, function, expression); case Identifier: - return FindVariableByName(expression->value.string); + return FindVariableByName(wStructValue, builder, expression->value.string); case Number: - return CompileNumber(module, builder, function, expression); + return CompileNumber(expression); } - printf("Error: expected expression\n"); + fprintf(stderr, "Unknown expression kind!\n"); return NULL; } -static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) +static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { - LLVMBuildRet(builder, CompileExpression(module, builder, function, returnStatemement->children[0])); + LLVMBuildRet(builder, CompileExpression(wStructValue, builder, function, returnStatemement->children[0])); } -static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) +static void CompileReturnVoid(LLVMBuilderRef builder) +{ + LLVMBuildRetVoid(builder); +} + +static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) +{ + LLVMValueRef fieldPointer; + LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]); + LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]); + + fieldPointer = GetStructFieldPointer(wStructValue, identifier); + if (fieldPointer != NULL) + { + LLVMBuildStore(builder, result, fieldPointer); + } +} + +static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) { + case Assignment: + CompileAssignment(wStructValue, builder, function, statement); + return 0; + case Return: - CompileReturn(module, builder, function, statement); - break; + CompileReturn(wStructValue, builder, function, statement); + return 1; + + case ReturnVoid: + CompileReturnVoid(builder); + return 1; } + + fprintf(stderr, "Unknown statement kind!\n"); + return 0; } -static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration) -{ +static void CompileFunction( + LLVMModuleRef module, + LLVMTypeRef wStructPointerType, + Node **fieldDeclarations, + uint32_t fieldDeclarationCount, + Node *functionDeclaration +) { uint32_t i; + uint8_t hasReturn = 0; Node *functionSignature = functionDeclaration->children[0]; Node *functionBody = functionDeclaration->children[1]; - LLVMTypeRef paramTypes[functionSignature->children[2]->childCount]; + uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */ + LLVMTypeRef paramTypes[argumentCount]; + + paramTypes[0] = wStructPointerType; for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { - paramTypes[i] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); + paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); } - LLVMTypeRef functionType = LLVMFunctionType(WraithTypeToLLVMType(functionSignature->children[1]->type), paramTypes, functionSignature->children[2]->childCount, 0); - + LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type); + LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0); LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType); + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { - LLVMValueRef argument = LLVMGetParam(function, i); + LLVMValueRef argument = LLVMGetParam(function, i + 1); AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument); } @@ -190,28 +347,91 @@ static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration) LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMPositionBuilderAtEnd(builder, entry); + /* FIXME: replace this with a scope abstraction */ + AddStruct(wStructPointer); + + for (i = 0; i < fieldDeclarationCount; i += 1) + { + AddStructField(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i); + } + for (i = 0; i < functionBody->childCount; i += 1) { - CompileStatement(module, builder, function, functionBody->children[i]); + hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]); } AddNamedVariable(functionSignature->children[0]->value.string, function); + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + RemoveStruct(wStructPointer); } -static void Compile(LLVMModuleRef module, Node *node) +static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node) +{ + uint32_t i; + uint32_t fieldCount = 0; + uint32_t declarationCount = node->children[1]->childCount; + uint8_t packed = 1; + LLVMTypeRef types[declarationCount]; + Node *currentDeclarationNode; + Node *fieldDeclarations[declarationCount]; + + LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string); + LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */ + + /* first, build the structure definition */ + for (i = 0; i < declarationCount; i += 1) + { + currentDeclarationNode = node->children[1]->children[i]; + + switch (currentDeclarationNode->syntaxKind) + { + case Declaration: /* this is badly named */ + types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type); + fieldDeclarations[fieldCount] = currentDeclarationNode; + fieldCount += 1; + break; + } + } + + LLVMStructSetBody(wStruct, types, fieldCount, packed); + + /* now we can wire up the functions */ + for (i = 0; i < declarationCount; i += 1) + { + currentDeclarationNode = node->children[1]->children[i]; + + switch (currentDeclarationNode->syntaxKind) + { + case FunctionDeclaration: + CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode); + break; + } + } +} + +static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) { uint32_t i; switch (node->syntaxKind) { - case FunctionDeclaration: - CompileFunction(module, node); + case StructDeclaration: + CompileStruct(module, context, node); break; } for (i = 0; i < node->childCount; i += 1) { - Compile(module, node->children[i]); + Compile(module, context, node->children[i]); } } @@ -226,6 +446,9 @@ int main(int argc, char *argv[]) namedVariables = NULL; namedVariableCount = 0; + structFieldMaps = NULL; + structFieldMapCount = 0; + stack = CreateStack(); FILE *fp = fopen(argv[1], "r"); @@ -236,8 +459,9 @@ int main(int argc, char *argv[]) PrintTree(rootNode, 0); LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); + LLVMContextRef context = LLVMGetGlobalContext(); - Compile(module, rootNode); + Compile(module, context, rootNode); char *error = NULL; LLVMVerifyModule(module, LLVMAbortProcessAction, &error); diff --git a/stack.c b/stack.c index 9ab52de..1fe7fe0 100644 --- a/stack.c +++ b/stack.c @@ -44,7 +44,7 @@ void AddNode(Stack *stack, Node *statementNode) if (stackFrame->nodeCount == stackFrame->nodeCapacity) { stackFrame->nodeCapacity += 1; - stackFrame->nodes = (Node**) realloc(stackFrame->nodes, stackFrame->nodeCapacity); + stackFrame->nodes = (Node**) realloc(stackFrame->nodes, sizeof(Node*) * stackFrame->nodeCapacity); } stackFrame->nodes[stackFrame->nodeCount] = statementNode; stackFrame->nodeCount += 1; diff --git a/wraith.lex b/wraith.lex index 056eefc..5bfef85 100644 --- a/wraith.lex +++ b/wraith.lex @@ -5,7 +5,7 @@ %option noyywrap %% -[0-9]+ return NUMBER; +"void" return VOID; "int" return INT; "uint" return UINT; "float" return FLOAT; @@ -14,6 +14,7 @@ "bool" return BOOL; "struct" return STRUCT; "return" return RETURN; +[0-9]+ return NUMBER; [a-zA-Z][a-zA-Z0-9]* return ID; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; "+" return PLUS; diff --git a/wraith.y b/wraith.y index d048896..96a53f9 100644 --- a/wraith.y +++ b/wraith.y @@ -21,7 +21,7 @@ extern Node *rootNode; %define api.value.type {struct Node*} -%token NUMBER +%token VOID %token INT %token UINT %token FLOAT @@ -30,6 +30,7 @@ extern Node *rootNode; %token BOOL %token STRUCT %token RETURN +%token NUMBER %token ID %token STRING_LITERAL %token PLUS @@ -65,7 +66,7 @@ extern Node *rootNode; %left LEFT_PAREN RIGHT_PAREN %% -Program : Declarations +Program : TopLevelDeclarations { Node **declarations; Node *declarationSequence; @@ -79,7 +80,11 @@ Program : Declarations rootNode = declarationSequence; } -Type : INT +Type : VOID + { + $$ = MakeTypeNode(Void); + } + | INT { $$ = MakeTypeNode(Int); } @@ -167,6 +172,10 @@ ReturnStatement : RETURN Expression { $$ = MakeReturnStatementNode($2); } + | RETURN + { + $$ = MakeReturnVoidStatementNode(); + } FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN { @@ -208,7 +217,7 @@ Arguments : PrimaryExpression COMMA Arguments | ; -SignatureArguments : VariableDeclaration COMMA VariableDeclarations +SignatureArguments : VariableDeclaration COMMA SignatureArguments { AddNode(stack, $1); } @@ -249,16 +258,7 @@ FunctionDeclaration : FunctionSignature Body $$ = MakeFunctionDeclarationNode($1, $2); } -VariableDeclarations : VariableDeclaration SEMICOLON VariableDeclarations - { - AddNode(stack, $1); - } - | - { - PushStackFrame(stack); - } - -StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGHT_BRACE +StructDeclaration : STRUCT Identifier LEFT_BRACE Declarations RIGHT_BRACE { Node **declarations; Node *declarationSequence; @@ -271,8 +271,9 @@ StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGH PopStackFrame(stack); } -Declaration : StructDeclaration - | FunctionDeclaration + +Declaration : FunctionDeclaration + | VariableDeclaration SEMICOLON ; Declarations : Declaration Declarations @@ -283,4 +284,15 @@ Declarations : Declaration Declarations { PushStackFrame(stack); } + +TopLevelDeclaration : StructDeclaration; + +TopLevelDeclarations : TopLevelDeclaration TopLevelDeclarations + { + AddNode(stack, $1); + } + | + { + PushStackFrame(stack); + } %%