#include #include #include #include #include #include #include "y.tab.h" #include "ast.h" #include "stack.h" extern FILE *yyin; Stack *stack; Node *rootNode; typedef struct VariableMapValue { char *name; LLVMValueRef variable; } VariableMapValue; VariableMapValue *namedVariables; uint32_t namedVariableCount; static LLVMValueRef CompileExpression( LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ); static void AddNamedVariable(char *name, LLVMValueRef variable) { VariableMapValue mapValue; mapValue.name = name; mapValue.variable = variable; namedVariables = realloc(namedVariables, namedVariableCount + 1); namedVariables[namedVariableCount] = mapValue; namedVariableCount += 1; } static LLVMValueRef FindVariableByName(char *name) { uint32_t i; for (i = 0; i < namedVariableCount; i += 1) { if (strcmp(namedVariables[i].name, name) == 0) { return namedVariables[i].variable; } } return NULL; } static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) { case Int: return LLVMInt64Type(); case UInt: return LLVMInt64Type(); } return NULL; } static LLVMValueRef CompileNumber( LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *numberExpression ) { return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); } static LLVMValueRef CompileBinaryExpression( LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ) { LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]); LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]); switch (binaryExpression->operator.binaryOperator) { case Add: return LLVMBuildAdd(builder, left, right, "tmp"); } return NULL; } static LLVMValueRef CompileExpression( LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *expression ) { switch (expression->syntaxKind) { case BinaryExpression: return CompileBinaryExpression(module, builder, function, expression); case Identifier: return FindVariableByName(expression->value.string); case Number: return CompileNumber(module, builder, function, expression); } printf("Error: expected expression\n"); return NULL; } static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMBuildRet(builder, CompileExpression(module, builder, function, returnStatemement->children[0])); } static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) { case Return: CompileReturn(module, builder, function, statement); break; } } static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration) { uint32_t i; Node *functionSignature = functionDeclaration->children[0]; Node *functionBody = functionDeclaration->children[1]; LLVMTypeRef paramTypes[functionSignature->children[2]->childCount]; for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { paramTypes[i] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); } LLVMTypeRef functionType = LLVMFunctionType(WraithTypeToLLVMType(functionSignature->children[1]->type), paramTypes, functionSignature->children[2]->childCount, 0); LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType); for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { LLVMValueRef argument = LLVMGetParam(function, i); AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument); } LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMPositionBuilderAtEnd(builder, entry); for (i = 0; i < functionBody->childCount; i += 1) { CompileStatement(module, builder, function, functionBody->children[i]); } } static void Compile(LLVMModuleRef module, Node *node) { uint32_t i; switch (node->syntaxKind) { case FunctionDeclaration: CompileFunction(module, node); break; } for (i = 0; i < node->childCount; i += 1) { Compile(module, node->children[i]); } } int main(int argc, char *argv[]) { if (argc < 2) { printf("Please provide a file.\n"); return 1; } namedVariables = NULL; namedVariableCount = 0; stack = CreateStack(); FILE *fp = fopen(argv[1], "r"); yyin = fp; yyparse(fp, stack); fclose(fp); PrintTree(rootNode, 0); LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); Compile(module, rootNode); char *error = NULL; LLVMVerifyModule(module, LLVMAbortProcessAction, &error); LLVMDisposeMessage(error); if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) { fprintf(stderr, "error writing bitcode to file\n"); } return 0; }