From acb6c6192269e28c271b5ededb07c7c000209386 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 22 Apr 2021 17:19:35 -0700 Subject: [PATCH] function calls --- CMakeLists.txt | 2 +- compiler.c | 165 +++++++++++++++++++++++++++++++++++++++---------- wraith.y | 11 +++- 3 files changed, 141 insertions(+), 37 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a042da7..bb42a92 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,7 @@ find_package(LLVM) include_directories(${CMAKE_SOURCE_DIR}) -BISON_TARGET(Parser wraith.y ${CMAKE_CURRENT_BINARY_DIR}/y.tab.c COMPILE_FLAGS -d) +BISON_TARGET(Parser wraith.y ${CMAKE_CURRENT_BINARY_DIR}/y.tab.c COMPILE_FLAGS "-d -v -t") FLEX_TARGET(Scanner wraith.lex ${CMAKE_CURRENT_BINARY_DIR}/lex.yy.c) ADD_FLEX_BISON_DEPENDENCY(Scanner Parser) diff --git a/compiler.c b/compiler.c index cd66052..7531728 100644 --- a/compiler.c +++ b/compiler.c @@ -53,12 +53,22 @@ typedef struct StructTypeField uint32_t index; } StructTypeField; +typedef struct StructTypeFunction +{ + char *name; + LLVMValueRef function; + LLVMTypeRef returnType; + uint8_t isStatic; +} StructTypeFunction; + typedef struct StructTypeFieldDeclaration { LLVMTypeRef structType; StructTypeField *fields; uint32_t fieldCount; + StructTypeFunction *functions; + uint32_t functionCount; } StructTypeFieldDeclaration; StructTypeFieldDeclaration *structTypeFieldDeclarations; @@ -198,6 +208,8 @@ static void AddStructDeclaration( structTypeFieldDeclarations[index].structType = wStructType; structTypeFieldDeclarations[index].fields = NULL; structTypeFieldDeclarations[index].fieldCount = 0; + structTypeFieldDeclarations[index].functions = NULL; + structTypeFieldDeclarations[index].functionCount = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { @@ -210,7 +222,63 @@ static void AddStructDeclaration( structTypeFieldDeclarationCount += 1; } -static void AddStructVariables( +static void DeclareStructFunction( + LLVMTypeRef wStructType, + LLVMValueRef function, + LLVMTypeRef returnType, + uint8_t isStatic, + char *name +) { + uint32_t i, index; + + for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + { + if (structTypeFieldDeclarations[i].structType == wStructType) + { + index = structTypeFieldDeclarations[i].functionCount; + structTypeFieldDeclarations[i].functions = realloc(structTypeFieldDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeFieldDeclarations[i].functionCount + 1)); + structTypeFieldDeclarations[i].functions[index].name = strdup(name); + structTypeFieldDeclarations[i].functions[index].function = function; + structTypeFieldDeclarations[i].functions[index].returnType = returnType; + structTypeFieldDeclarations[i].functions[index].isStatic = isStatic; + structTypeFieldDeclarations[i].functionCount += 1; + + return; + } + } + + fprintf(stderr, "Could not find struct type for function!"); +} + +static LLVMValueRef LookupFunction( + LLVMValueRef structPointer, + char *name, + LLVMTypeRef *pReturnType, + uint8_t *pStatic +) { + uint32_t i, j; + + for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + { + if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer)) + { + for (j = 0; j < structTypeFieldDeclarations[i].functionCount; j += 1) + { + if (strcmp(structTypeFieldDeclarations[i].functions[j].name, name) == 0) + { + *pReturnType = structTypeFieldDeclarations[i].functions[j].returnType; + *pStatic = structTypeFieldDeclarations[i].functions[j].isStatic; + return structTypeFieldDeclarations[i].functions[j].function; + } + } + } + } + + fprintf(stderr, "Could not find struct function!"); + return NULL; +} + +static void AddStructVariablesToScope( LLVMBuilderRef builder, LLVMValueRef structPointer ) { @@ -244,7 +312,6 @@ static void AddStructVariables( static LLVMValueRef CompileExpression( LLVMBuilderRef builder, - LLVMValueRef function, Node *binaryExpression ); @@ -309,11 +376,10 @@ static LLVMValueRef CompileNumber( static LLVMValueRef CompileBinaryExpression( LLVMBuilderRef builder, - LLVMValueRef function, Node *binaryExpression ) { - LLVMValueRef left = CompileExpression(builder, function, binaryExpression->children[0]); - LLVMValueRef right = CompileExpression(builder, function, binaryExpression->children[1]); + LLVMValueRef left = CompileExpression(builder, binaryExpression->children[0]); + LLVMValueRef right = CompileExpression(builder, binaryExpression->children[1]); switch (binaryExpression->operator.binaryOperator) { @@ -334,25 +400,51 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( LLVMBuilderRef builder, - LLVMValueRef function, Node *expression ) { uint32_t i; - uint32_t argumentCount = expression->children[1]->childCount; + uint32_t argumentCount = 0; LLVMValueRef args[argumentCount]; + LLVMValueRef function; + uint8_t isStatic; + LLVMValueRef structInstance; + LLVMTypeRef functionReturnType; + char *returnName = ""; - for (i = 0; i < argumentCount; i += 1) + /* FIXME: this needs to be recursive on access chains */ + if (expression->children[0]->syntaxKind == AccessExpression) { - args[i] = CompileExpression(builder, function, expression->children[1]->children[i]); + structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); + function = LookupFunction(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); + } + else + { + fprintf(stderr, "Failed to find function!\n"); + return NULL; } - //return LLVMBuildCall(builder, FindVariableValueByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp"); - return NULL; + if (!isStatic) + { + args[argumentCount] = structInstance; + argumentCount += 1; + } + + for (i = 0; i < expression->children[1]->childCount; i += 1) + { + args[argumentCount] = CompileExpression(builder, expression->children[1]->children[i]); + argumentCount += 1; + } + + if (LLVMGetTypeKind(functionReturnType) != LLVMVoidTypeKind) + { + returnName = "callReturn"; + } + + return LLVMBuildCall(builder, function, args, argumentCount, returnName); } static LLVMValueRef CompileAccessExpressionForStore( LLVMBuilderRef builder, - LLVMValueRef function, Node *expression ) { Node *accessee = expression->children[0]; @@ -363,7 +455,6 @@ static LLVMValueRef CompileAccessExpressionForStore( static LLVMValueRef CompileAccessExpression( LLVMBuilderRef builder, - LLVMValueRef function, Node *expression ) { Node *accessee = expression->children[0]; @@ -375,19 +466,18 @@ static LLVMValueRef CompileAccessExpression( static LLVMValueRef CompileExpression( LLVMBuilderRef builder, - LLVMValueRef function, Node *expression ) { switch (expression->syntaxKind) { case AccessExpression: - return CompileAccessExpression(builder, function, expression); + return CompileAccessExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(builder, function, expression); + return CompileBinaryExpression(builder, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(builder, function, expression); + return CompileFunctionCallExpression(builder, expression); case Identifier: return FindVariableValue(builder, expression->value.string); @@ -400,9 +490,9 @@ static LLVMValueRef CompileExpression( return NULL; } -static void CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) +static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement) { - LLVMValueRef expression = CompileExpression(builder, function, returnStatemement->children[0]); + LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); LLVMBuildRet(builder, expression); } @@ -411,13 +501,13 @@ static void CompileReturnVoid(LLVMBuilderRef builder) LLVMBuildRetVoid(builder); } -static void CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) +static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement) { - LLVMValueRef result = CompileExpression(builder, function, assignmentStatement->children[1]); + LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]); LLVMValueRef identifier; if (assignmentStatement->children[0]->syntaxKind == AccessExpression) { - identifier = CompileAccessExpressionForStore(builder, function, assignmentStatement->children[0]); + identifier = CompileAccessExpressionForStore(builder, assignmentStatement->children[0]); } else if (assignmentStatement->children[0]->syntaxKind == Identifier) { @@ -459,7 +549,11 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N switch (statement->syntaxKind) { case Assignment: - CompileAssignment(builder, function, statement); + CompileAssignment(builder, statement); + return 0; + + case FunctionCallExpression: + CompileFunctionCallExpression(builder, statement); return 0; case Declaration: @@ -467,7 +561,7 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N return 0; case Return: - CompileReturn(builder, function, statement); + CompileReturn(builder, statement); return 1; case ReturnVoid: @@ -531,6 +625,8 @@ static void CompileFunction( LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); free(functionName); + DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->children[0]->value.string); + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMPositionBuilderAtEnd(builder, entry); @@ -538,7 +634,7 @@ static void CompileFunction( if (!isStatic) { LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariables(builder, wStructPointer); + AddStructVariablesToScope(builder, wStructPointer); } for (i = 0; i < functionSignature->children[2]->childCount; i += 1) @@ -626,16 +722,16 @@ static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) { uint32_t i; - switch (node->syntaxKind) - { - case StructDeclaration: - CompileStruct(module, context, node); - break; - } - for (i = 0; i < node->childCount; i += 1) { - Compile(module, context, node->children[i]); + if (node->children[i]->syntaxKind == StructDeclaration) + { + CompileStruct(module, context, node->children[i]); + } + else + { + fprintf(stderr, "top level declarations that are not structs are forbidden!\n"); + } } } @@ -647,6 +743,9 @@ int main(int argc, char *argv[]) return 1; } + extern int yydebug; + yydebug = 1; + scope = CreateScope(); structTypeFieldDeclarations = NULL; diff --git a/wraith.y b/wraith.y index 57415dc..ffc2fe0 100644 --- a/wraith.y +++ b/wraith.y @@ -62,6 +62,8 @@ extern Node *rootNode; %parse-param { FILE* fp } { Stack *stack } +%define parse.error verbose + %left PLUS MINUS %left BANG %left LEFT_PAREN RIGHT_PAREN @@ -195,7 +197,7 @@ ReturnStatement : RETURN Expression $$ = MakeReturnVoidStatementNode(); } -FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN +FunctionCallExpression : AccessExpression LEFT_PAREN Arguments RIGHT_PAREN { Node **arguments; uint32_t argumentCount; @@ -206,9 +208,9 @@ FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN PopStackFrame(stack); } -PartialStatement : AssignmentStatement +PartialStatement : FunctionCallExpression + | AssignmentStatement | VariableDeclaration - | FunctionCallExpression | ReturnStatement ; @@ -233,6 +235,9 @@ Arguments : PrimaryExpression COMMA Arguments AddNode(stack, $1); } | + { + PushStackFrame(stack); + } ; SignatureArguments : VariableDeclaration COMMA SignatureArguments