#include #include #include #include #include #include #include "y.tab.h" #include "ast.h" #include "stack.h" extern FILE *yyin; Stack *stack; Node *rootNode; typedef struct StructFieldMapValue { char *name; LLVMValueRef value; LLVMValueRef valuePointer; uint32_t index; uint8_t needsWrite; uint8_t needsRead; } StructFieldMapValue; typedef struct StructFieldMap { LLVMValueRef structPointer; StructFieldMapValue *fields; uint32_t fieldCount; } StructFieldMap; StructFieldMap *structFieldMaps; uint32_t structFieldMapCount; static void AddStruct(LLVMValueRef wStructPointer) { structFieldMaps = realloc(structFieldMaps, sizeof(StructFieldMap) * (structFieldMapCount + 1)); structFieldMaps[structFieldMapCount].structPointer = wStructPointer; structFieldMaps[structFieldMapCount].fields = NULL; structFieldMaps[structFieldMapCount].fieldCount = 0; structFieldMapCount += 1; } static void AddStructFieldName(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index) { uint32_t i, fieldCount; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { fieldCount = structFieldMaps[i].fieldCount; structFieldMaps[i].fields = realloc(structFieldMaps[i].fields, sizeof(StructFieldMapValue) * (fieldCount + 1)); structFieldMaps[i].fields[fieldCount].name = strdup(name); structFieldMaps[i].fields[fieldCount].value = NULL; structFieldMaps[i].fields[fieldCount].valuePointer = NULL; structFieldMaps[i].fields[fieldCount].index = index; structFieldMaps[i].fields[fieldCount].needsWrite = 0; structFieldMaps[i].fields[fieldCount].needsRead = 1; structFieldMaps[i].fieldCount += 1; break; } } } static LLVMValueRef CheckStructFieldAndLoad(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name) { uint32_t i, j; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) { if (strcmp(structFieldMaps[i].fields[j].name, name) == 0) { if (structFieldMaps[i].fields[j].needsRead) { char *ptrName = strdup(name); strcat(ptrName, "_ptr"); LLVMValueRef elementPointer = LLVMBuildStructGEP( builder, wStructPointer, structFieldMaps[i].fields[j].index, ptrName ); free(ptrName); structFieldMaps[i].fields[j].value = LLVMBuildLoad(builder, elementPointer, name); structFieldMaps[i].fields[j].valuePointer = elementPointer; structFieldMaps[i].fields[j].needsRead = 0; } return structFieldMaps[i].fields[j].value; } } } } return NULL; } static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value) { uint32_t i, j; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) { if (structFieldMaps[i].fields[j].value == value) { structFieldMaps[i].fields[j].needsWrite = 1; break; } } } } } static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value) { uint32_t i, j; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) { if (structFieldMaps[i].fields[j].value == value) { return structFieldMaps[i].fields[j].valuePointer; } } } } return NULL; } static void RemoveStruct(LLVMBuilderRef builder, LLVMValueRef wStructPointer) { uint32_t i, j; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) { if (structFieldMaps[i].fields[j].needsWrite) { LLVMBuildStore( builder, structFieldMaps[i].fields[j].value, structFieldMaps[i].fields[j].valuePointer ); } } free(structFieldMaps[i].fields); structFieldMaps[i].fields = NULL; structFieldMaps[i].fieldCount = 0; break; } } } typedef struct IdentifierMapValue { char *name; LLVMValueRef value; } IdentifierMapValue; IdentifierMapValue *namedVariables; uint32_t namedVariableCount; static LLVMValueRef CompileExpression( LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ); static void AddNamedVariable(char *name, LLVMValueRef variable) { IdentifierMapValue mapValue; mapValue.name = name; mapValue.value = variable; namedVariables = realloc(namedVariables, sizeof(IdentifierMapValue) * (namedVariableCount + 1)); namedVariables[namedVariableCount] = mapValue; namedVariableCount += 1; } static LLVMValueRef FindVariableByName(LLVMBuilderRef builder, LLVMValueRef wStructValue, char *name) { uint32_t i, j; LLVMValueRef searchResult; /* first, search scoped vars */ for (i = 0; i < namedVariableCount; i += 1) { if (strcmp(namedVariables[i].name, name) == 0) { return namedVariables[i].value; } } /* if none exist, search struct vars */ searchResult = CheckStructFieldAndLoad(builder, wStructValue, name); if (searchResult == NULL) { fprintf(stderr, "Identifier not found!"); } return searchResult; } typedef struct CustomTypeMap { LLVMTypeRef type; char *name; } CustomTypeMap; CustomTypeMap *customTypes; uint32_t customTypeCount; static void RegisterCustomType(LLVMTypeRef type, char *name) { customTypes = realloc(customTypes, sizeof(CustomType) * (customTypeCount + 1)); customTypes[customTypeCount].type = type; customTypes[customTypeCount].name = strdup(name); customTypeCount += 1; } static LLVMTypeRef LookupCustomType(char *name) { uint32_t i; for (i = 0; i < customTypeCount; i += 1) { if (strcmp(customTypes[i].name, name) == 0) { return customTypes[i].type; } } return NULL; } static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) { case Int: return LLVMInt64Type(); case UInt: return LLVMInt64Type(); case Bool: return LLVMInt1Type(); case Void: return LLVMVoidType(); } fprintf(stderr, "Unrecognized type!"); return NULL; } static LLVMValueRef CompileNumber( Node *numberExpression ) { return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); } static LLVMValueRef CompileBinaryExpression( LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *binaryExpression ) { LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]); LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]); switch (binaryExpression->operator.binaryOperator) { case Add: return LLVMBuildAdd(builder, left, right, "tmp"); case Subtract: return LLVMBuildSub(builder, left, right, "tmp"); case Multiply: return LLVMBuildMul(builder, left, right, "tmp"); } return NULL; } static LLVMValueRef CompileFunctionCallExpression( LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression ) { uint32_t i; uint32_t argumentCount = expression->children[1]->childCount; LLVMValueRef args[argumentCount]; for (i = 0; i < argumentCount; i += 1) { args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]); } return LLVMBuildCall(builder, FindVariableByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp"); } static LLVMValueRef CompileExpression( LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *expression ) { LLVMValueRef var; switch (expression->syntaxKind) { case BinaryExpression: return CompileBinaryExpression(wStructValue, builder, function, expression); case FunctionCallExpression: return CompileFunctionCallExpression(wStructValue, builder, function, expression); case Identifier: return FindVariableByName(builder, wStructValue, expression->value.string); case Number: return CompileNumber(expression); } fprintf(stderr, "Unknown expression kind!\n"); return NULL; } static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMBuildRet(builder, CompileExpression(wStructValue, builder, function, returnStatemement->children[0])); } static void CompileReturnVoid(LLVMBuilderRef builder) { LLVMBuildRetVoid(builder); } static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef fieldPointer; LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]); LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]); MarkStructFieldForWrite(wStructValue, identifier); } static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration) { char *variableName = variableDeclaration->children[1]->value.string; LLVMValueRef variable; if (variableDeclaration->children[0]->type == CustomType) { char *customTypeName = variableDeclaration->children[0]->children[0]->value.string; variable = LLVMBuildAlloca(builder, LookupCustomType(customTypeName), variableName); } else { variable = LLVMBuildAlloca(builder, WraithTypeToLLVMType(variableDeclaration->children[0]->type), variableName); } AddNamedVariable(variableName, variable); } static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) { switch (statement->syntaxKind) { case Assignment: CompileAssignment(wStructValue, builder, function, statement); return 0; case Declaration: CompileFunctionVariableDeclaration(builder, statement); return 0; case Return: CompileReturn(wStructValue, builder, function, statement); return 1; case ReturnVoid: CompileReturnVoid(builder); return 1; } fprintf(stderr, "Unknown statement kind!\n"); return 0; } static void CompileFunction( LLVMModuleRef module, LLVMTypeRef wStructPointerType, Node **fieldDeclarations, uint32_t fieldDeclarationCount, Node *functionDeclaration ) { uint32_t i; uint8_t hasReturn = 0; Node *functionSignature = functionDeclaration->children[0]; Node *functionBody = functionDeclaration->children[1]; uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */ LLVMTypeRef paramTypes[argumentCount]; paramTypes[0] = wStructPointerType; for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); } LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type); LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0); LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType); LLVMValueRef wStructPointer = LLVMGetParam(function, 0); for (i = 0; i < functionSignature->children[2]->childCount; i += 1) { LLVMValueRef argument = LLVMGetParam(function, i + 1); AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument); } LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMPositionBuilderAtEnd(builder, entry); /* FIXME: replace this with a scope abstraction */ AddStruct(wStructPointer); for (i = 0; i < fieldDeclarationCount; i += 1) { AddStructFieldName(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i); } for (i = 0; i < functionBody->childCount; i += 1) { hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]); } RemoveStruct(builder, wStructPointer); if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) { LLVMBuildRetVoid(builder); } else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) { fprintf(stderr, "Return statement not provided!"); } } static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node) { uint32_t i; uint32_t fieldCount = 0; uint32_t declarationCount = node->children[1]->childCount; uint8_t packed = 1; LLVMTypeRef types[declarationCount]; Node *currentDeclarationNode; Node *fieldDeclarations[declarationCount]; LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string); LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */ /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) { currentDeclarationNode = node->children[1]->children[i]; switch (currentDeclarationNode->syntaxKind) { case Declaration: /* this is badly named */ types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type); fieldDeclarations[fieldCount] = currentDeclarationNode; fieldCount += 1; break; } } LLVMStructSetBody(wStruct, types, fieldCount, packed); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) { currentDeclarationNode = node->children[1]->children[i]; switch (currentDeclarationNode->syntaxKind) { case FunctionDeclaration: CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode); break; } } RegisterCustomType(wStruct, node->children[0]->value.string); } static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) { uint32_t i; switch (node->syntaxKind) { case StructDeclaration: CompileStruct(module, context, node); break; } for (i = 0; i < node->childCount; i += 1) { Compile(module, context, node->children[i]); } } int main(int argc, char *argv[]) { if (argc < 2) { printf("Please provide a file.\n"); return 1; } namedVariables = NULL; namedVariableCount = 0; structFieldMaps = NULL; structFieldMapCount = 0; customTypes = NULL; customTypeCount = 0; stack = CreateStack(); FILE *fp = fopen(argv[1], "r"); yyin = fp; yyparse(fp, stack); fclose(fp); PrintTree(rootNode, 0); LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); LLVMContextRef context = LLVMGetGlobalContext(); Compile(module, context, rootNode); char *error = NULL; LLVMVerifyModule(module, LLVMAbortProcessAction, &error); LLVMDisposeMessage(error); if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) { fprintf(stderr, "error writing bitcode to file\n"); } return 0; }