From 506ee9ecad0f6f2a9a12480ea6367cca591623f8 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 2 Jun 2021 12:33:01 -0700 Subject: [PATCH] structure for system calls --- generators/wraith.lex | 1 + generators/wraith.y | 10 ++++ generic.w | 2 + src/ast.c | 13 +++++ src/ast.h | 10 ++++ src/codegen.c | 113 ++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 149 insertions(+) diff --git a/generators/wraith.lex b/generators/wraith.lex index f288c6e..44df0a7 100644 --- a/generators/wraith.lex +++ b/generators/wraith.lex @@ -40,6 +40,7 @@ ";" return SEMICOLON; ":" return COLON; "?" return QUESTION; +"@" return AT; "(" return LEFT_PAREN; ")" return RIGHT_PAREN; "[" return LEFT_BRACKET; diff --git a/generators/wraith.y b/generators/wraith.y index 3176d60..b3d4d57 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -54,6 +54,7 @@ extern FILE *yyin; %token SEMICOLON %token COLON %token QUESTION +%token AT %token LEFT_PAREN %token RIGHT_PAREN %token LEFT_BRACE @@ -141,6 +142,11 @@ AccessExpression : Identifier POINT AccessExpression $$ = $1; } +SystemCallExpression : AT Identifier + { + $$ = $2; + } + Number : NUMBER { $$ = MakeNumberNode(yytext); @@ -234,6 +240,10 @@ FunctionCallExpression : AccessExpression LEFT_PAREN Arguments RIGHT_PAREN { $$ = MakeFunctionCallExpressionNode($1, $3); } + | SystemCallExpression LEFT_PAREN Arguments RIGHT_PAREN + { + $$ = MakeSystemCallExpressionNode($1, $3); + } PartialStatement : FunctionCallExpression | AssignmentStatement diff --git a/generic.w b/generic.w index 73fb981..52e9d3c 100644 --- a/generic.w +++ b/generic.w @@ -13,6 +13,8 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); + addr: uint = @malloc(y); + @free(addr); return x; } } diff --git a/src/ast.c b/src/ast.c index fddf74d..7901bbc 100644 --- a/src/ast.c +++ b/src/ast.c @@ -67,6 +67,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "StringLiteral"; case StructDeclaration: return "StructDeclaration"; + case SystemCall: + return "SystemCall"; case Type: return "Type"; case UnaryExpression: @@ -426,6 +428,17 @@ Node *MakeFunctionCallExpressionNode( return node; } +Node *MakeSystemCallExpressionNode( + Node *identifierNode, + Node *argumentSequenceNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = SystemCall; + node->systemCall.identifier = identifierNode; + node->systemCall.argumentSequence = argumentSequenceNode; + return node; +} + Node *MakeAccessExpressionNode(Node *accessee, Node *accessor) { Node *node = (Node *)malloc(sizeof(Node)); diff --git a/src/ast.h b/src/ast.h index 3d36cf6..977262b 100644 --- a/src/ast.h +++ b/src/ast.h @@ -44,6 +44,7 @@ typedef enum StaticModifier, StringLiteral, StructDeclaration, + SystemCall, Type, UnaryExpression } SyntaxKind; @@ -287,6 +288,12 @@ struct Node Node *declarationSequence; } structDeclaration; + struct + { + Node *identifier; + Node *argumentSequence; + } systemCall; + struct { Node *typeNode; @@ -349,6 +356,9 @@ Node *MakeEmptyFunctionArgumentSequenceNode(); Node *MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode); +Node *MakeSystemCallExpressionNode( + Node *identifierNode, + Node *argumentSequenceNode); Node *MakeAccessExpressionNode(Node *accessee, Node *accessor); Node *MakeAllocNode(Node *typeNode); Node *MakeIfNode(Node *expressionNode, Node *statementSequenceNode); diff --git a/src/codegen.c b/src/codegen.c index dd3bf73..dcd1d53 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -112,6 +112,16 @@ typedef struct StructTypeDeclaration StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; +typedef struct SystemFunction +{ + char *name; + LLVMTypeRef functionType; + LLVMValueRef function; +} SystemFunction; + +SystemFunction *systemFunctions; +uint32_t systemFunctionCount; + /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( LLVMModuleRef module, @@ -262,6 +272,36 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) } } +static void AddSystemFunction( + char *name, + LLVMTypeRef functionType, + LLVMValueRef function) +{ + systemFunctions = realloc( + systemFunctions, + sizeof(SystemFunction) * (systemFunctionCount + 1)); + systemFunctions[systemFunctionCount].name = strdup(name); + systemFunctions[systemFunctionCount].functionType = functionType; + systemFunctions[systemFunctionCount].function = function; + systemFunctionCount += 1; +} + +static SystemFunction *LookupSystemFunction(char *name) +{ + uint32_t i; + + for (i = 0; i < systemFunctionCount; i += 1) + { + if (strcmp(name, systemFunctions[i].name) == 0) + { + return &systemFunctions[i]; + } + } + + fprintf(stderr, "%s %s %s", "System function", name, "not found!"); + return NULL; +} + static void AddLocalVariable( Scope *scope, LLVMValueRef pointer, /* can be NULL */ @@ -1133,6 +1173,51 @@ static LLVMValueRef CompileFunctionCallExpression( return LLVMBuildCall(builder, function, args, argumentCount, returnName); } +static LLVMValueRef CompileSystemCallExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *systemCallExpression) +{ + uint32_t i; + uint32_t argumentCount = systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.count; + LLVMValueRef args[argumentCount]; + char *returnName = ""; + + for (i = 0; i < systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.count; + i += 1) + { + args[i] = CompileExpression( + module, + builder, + systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.sequence[i]); + } + + SystemFunction *systemFunction = LookupSystemFunction( + systemCallExpression->systemCall.identifier->identifier.name); + + if (systemFunction == NULL) + { + fprintf(stderr, "System function not found!"); + return NULL; + } + + if (LLVMGetTypeKind(LLVMGetReturnType(systemFunction->functionType)) != + LLVMVoidTypeKind) + { + returnName = "callReturn"; + } + + return LLVMBuildCall( + builder, + systemFunction->function, + args, + argumentCount, + returnName); +} + static LLVMValueRef CompileAccessExpressionForStore( LLVMBuilderRef builder, Node *accessExpression) @@ -1196,6 +1281,9 @@ static LLVMValueRef CompileExpression( case StringLiteral: return CompileString(builder, expression); + + case SystemCall: + return CompileSystemCallExpression(module, builder, expression); } fprintf(stderr, "Unknown expression kind!\n"); @@ -1516,6 +1604,10 @@ static LLVMBasicBlockRef CompileStatement( case ReturnVoid: return CompileReturnVoid(builder, function); + + case SystemCall: + CompileSystemCallExpression(module, builder, statement); + return LLVMGetLastBasicBlock(function); } fprintf(stderr, "Unknown statement kind!\n"); @@ -1812,6 +1904,24 @@ static void RegisterLibraryFunctions( LLVMInt8Type(), 1, "PrintLine"); + + LLVMTypeRef mallocArg = LLVMInt64Type(); + LLVMTypeRef mallocFunctionType = + LLVMFunctionType(LLVMInt64Type(), &mallocArg, 1, 0); + LLVMValueRef mallocFunction = + LLVMAddFunction(module, "malloc", mallocFunctionType); + LLVMSetLinkage(mallocFunction, LLVMExternalLinkage); + + AddSystemFunction("malloc", mallocFunctionType, mallocFunction); + + LLVMTypeRef freeArg = LLVMInt64Type(); + LLVMTypeRef freeFunctionType = + LLVMFunctionType(LLVMVoidType(), &freeArg, 1, 0); + LLVMValueRef freeFunction = + LLVMAddFunction(module, "free", freeFunctionType); + LLVMSetLinkage(freeFunction, LLVMExternalLinkage); + + AddSystemFunction("free", freeFunctionType, freeFunction); } int Codegen(Node *node, uint32_t optimizationLevel) @@ -1821,6 +1931,9 @@ int Codegen(Node *node, uint32_t optimizationLevel) structTypeDeclarations = NULL; structTypeDeclarationCount = 0; + systemFunctions = NULL; + systemFunctionCount = 0; + LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); LLVMContextRef context = LLVMGetGlobalContext();