/*
 * wraith-lang/compiler.c
 *
 * Wraith compiler driver: parses a source file and emits LLVM bitcode.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <llvm-c/Core.h>
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>
#include "y.tab.h"
#include "ast.h"
#include "stack.h"
/* Input stream read by the generated lexer; main() points it at the source file. */
extern FILE *yyin;
/* Parser work stack, created in main() and passed to yyparse(). */
Stack *stack;
/* Root of the parsed AST. main() prints and compiles it after yyparse();
 * presumably assigned by the parser actions (not visible in this file). */
Node *rootNode;
/*
 * Cache entry for one field of the struct a method is compiled against.
 * A field is loaded lazily on its first read (CheckStructFieldAndLoad) and,
 * if flagged, stored back when the struct scope is closed (RemoveStruct).
 */
typedef struct StructFieldMapValue
{
char *name;                 /* field name; heap copy made by AddStructFieldName */
LLVMValueRef value;         /* cached field value; NULL until the first load */
LLVMValueRef valuePointer;  /* GEP to the field inside the struct; NULL until the first load */
uint32_t index;             /* position of the field in the LLVM struct body */
uint8_t needsWrite;         /* nonzero: RemoveStruct stores `value` through `valuePointer` */
uint8_t needsRead;          /* nonzero: the field has not been loaded yet */
} StructFieldMapValue;
/* All cached fields belonging to one struct pointer (a method's implicit first parameter). */
typedef struct StructFieldMap
{
LLVMValueRef structPointer; /* lookup key: the struct pointer the fields belong to */
StructFieldMapValue *fields;
uint32_t fieldCount;
} StructFieldMap;
/* Growable array of field maps. Entries are appended by AddStruct and emptied,
 * but never removed, by RemoveStruct. */
StructFieldMap *structFieldMaps;
uint32_t structFieldMapCount;
/*
 * Register an empty field map for `wStructPointer` (the implicit struct
 * pointer parameter of the method being compiled).
 *
 * Exits the process if the map array cannot be grown.
 */
static void AddStruct(LLVMValueRef wStructPointer)
{
    /* Grow through a temporary so the existing array is not lost on failure. */
    StructFieldMap *grown = realloc(structFieldMaps, sizeof *grown * (structFieldMapCount + 1));
    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }
    structFieldMaps = grown;
    structFieldMaps[structFieldMapCount].structPointer = wStructPointer;
    structFieldMaps[structFieldMapCount].fields = NULL;
    structFieldMaps[structFieldMapCount].fieldCount = 0;
    structFieldMapCount += 1;
}
/*
 * Append field `name` (at struct position `index`) to the field map of
 * `wStructPointer`. The field starts out unloaded (needsRead = 1) and clean
 * (needsWrite = 0). Nothing happens if no map is registered for the pointer.
 *
 * `name` is copied. `builder` is currently unused; the parameter is kept for
 * interface compatibility. Exits the process on allocation failure.
 */
static void AddStructFieldName(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{
    uint32_t i, fieldCount;
    (void)builder;
    for (i = 0; i < structFieldMapCount; i += 1)
    {
        if (structFieldMaps[i].structPointer == wStructPointer)
        {
            StructFieldMapValue *grown;
            char *nameCopy;
            fieldCount = structFieldMaps[i].fieldCount;
            /* Grow through a temporary so the existing array is not lost on failure. */
            grown = realloc(structFieldMaps[i].fields, sizeof *grown * (fieldCount + 1));
            nameCopy = strdup(name);
            if (grown == NULL || nameCopy == NULL)
            {
                fprintf(stderr, "Out of memory!\n");
                exit(EXIT_FAILURE);
            }
            structFieldMaps[i].fields = grown;
            grown[fieldCount].name = nameCopy;
            grown[fieldCount].value = NULL;
            grown[fieldCount].valuePointer = NULL;
            grown[fieldCount].index = index;
            grown[fieldCount].needsWrite = 0;
            grown[fieldCount].needsRead = 1;
            structFieldMaps[i].fieldCount += 1;
            break;
        }
    }
}
/*
 * Return the current value of field `name` of the struct behind
 * `wStructPointer`.
 *
 * On the first lookup of a field, this emits (at the builder's current
 * position) a GEP named "<name>_ptr" to the field and a load from it. The
 * loaded value and the GEP are cached; later lookups return the cached value
 * without emitting code.
 *
 * Returns NULL when no map is registered for the pointer or it has no field
 * called `name`. Exits the process on allocation failure.
 */
static LLVMValueRef CheckStructFieldAndLoad(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name)
{
    uint32_t i, j;
    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];
        if (map->structPointer != wStructPointer)
        {
            continue;
        }
        for (j = 0; j < map->fieldCount; j += 1)
        {
            StructFieldMapValue *field = &map->fields[j];
            if (strcmp(field->name, name) != 0)
            {
                continue;
            }
            if (field->needsRead)
            {
                /* Size the buffer for name + "_ptr" + NUL. The previous
                 * strdup()+strcat() wrote the suffix past the end of the copy. */
                size_t ptrNameSize = strlen(name) + sizeof "_ptr";
                char *ptrName = malloc(ptrNameSize);
                LLVMValueRef elementPointer;
                if (ptrName == NULL)
                {
                    fprintf(stderr, "Out of memory!\n");
                    exit(EXIT_FAILURE);
                }
                snprintf(ptrName, ptrNameSize, "%s_ptr", name);
                elementPointer = LLVMBuildStructGEP(builder, wStructPointer, field->index, ptrName);
                free(ptrName);
                field->value = LLVMBuildLoad(builder, elementPointer, name);
                field->valuePointer = elementPointer;
                field->needsRead = 0;
            }
            return field->value;
        }
    }
    return NULL;
}
/*
 * In every field map keyed by `wStructPointer`, flag the first field whose
 * cached value is `value` so that RemoveStruct stores it back.
 */
static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex;
    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        StructFieldMap *map = &structFieldMaps[mapIndex];
        uint32_t fieldIndex = 0;
        if (map->structPointer != wStructPointer)
        {
            continue;
        }
        /* Advance to the first field holding `value`, if there is one. */
        while (fieldIndex < map->fieldCount && map->fields[fieldIndex].value != value)
        {
            fieldIndex += 1;
        }
        if (fieldIndex < map->fieldCount)
        {
            map->fields[fieldIndex].needsWrite = 1;
        }
    }
}
/*
 * Return the field pointer (GEP) of the first field, in any map keyed by
 * `wStructPointer`, whose cached value is `value`; NULL if there is none.
 */
static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex, fieldIndex;
    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        StructFieldMap *map = &structFieldMaps[mapIndex];
        if (map->structPointer != wStructPointer)
        {
            continue;
        }
        for (fieldIndex = 0; fieldIndex < map->fieldCount; fieldIndex += 1)
        {
            StructFieldMapValue *field = &map->fields[fieldIndex];
            if (field->value == value)
            {
                return field->valuePointer;
            }
        }
    }
    return NULL;
}
/*
 * Close the field scope of `wStructPointer`.
 *
 * For each field flagged needsWrite, emits a store of its cached value
 * through its field pointer at the builder's current position. Then releases
 * the field array and the field names. The map entry itself stays in
 * `structFieldMaps`, but it is left empty.
 */
static void RemoveStruct(LLVMBuilderRef builder, LLVMValueRef wStructPointer)
{
    uint32_t i, j;
    for (i = 0; i < structFieldMapCount; i += 1)
    {
        if (structFieldMaps[i].structPointer == wStructPointer)
        {
            for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
            {
                if (structFieldMaps[i].fields[j].needsWrite)
                {
                    LLVMBuildStore(
                        builder,
                        structFieldMaps[i].fields[j].value,
                        structFieldMaps[i].fields[j].valuePointer
                    );
                }
                /* Names were strdup'd in AddStructFieldName; release them too. */
                free(structFieldMaps[i].fields[j].name);
            }
            free(structFieldMaps[i].fields);
            structFieldMaps[i].fields = NULL;
            structFieldMaps[i].fieldCount = 0;
            break;
        }
    }
}
/* Name -> LLVM value binding for function parameters and local variables. */
typedef struct IdentifierMapValue
{
char *name;          /* not copied by AddNamedVariable: points at the caller's string */
LLVMValueRef value;  /* parameter value, or the alloca of a local declaration */
} IdentifierMapValue;
/* Scoped bindings; FindVariableByName searches these before struct fields. */
IdentifierMapValue *namedVariables;
uint32_t namedVariableCount;
/* Forward declaration: expression compilation is mutually recursive with the
 * binary-expression and call helpers defined before it. */
static LLVMValueRef CompileExpression(
LLVMValueRef wStructValue,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
);
/*
 * Append a scoped binding `name` -> `variable` to `namedVariables`.
 *
 * `name` is stored as-is (not copied), so it must outlive the binding.
 * Exits the process if the binding array cannot be grown.
 */
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
    /* Grow through a temporary so the existing array is not lost on failure. */
    IdentifierMapValue *grown = realloc(namedVariables, sizeof *grown * (namedVariableCount + 1));
    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }
    namedVariables = grown;
    namedVariables[namedVariableCount].name = name;
    namedVariables[namedVariableCount].value = variable;
    namedVariableCount += 1;
}
/*
 * Resolve identifier `name`.
 *
 * Scoped bindings (parameters and locals) are searched first, newest
 * binding first, so that a later declaration shadows an earlier one with the
 * same name. If none matches, the name is looked up as a field of the struct
 * behind `wStructValue`; this may emit a load of that field.
 *
 * Returns NULL, after printing a diagnostic, when the name is not found.
 */
static LLVMValueRef FindVariableByName(LLVMBuilderRef builder, LLVMValueRef wStructValue, char *name)
{
    uint32_t i;
    LLVMValueRef searchResult;
    for (i = namedVariableCount; i > 0; i -= 1)
    {
        if (strcmp(namedVariables[i - 1].name, name) == 0)
        {
            return namedVariables[i - 1].value;
        }
    }
    searchResult = CheckStructFieldAndLoad(builder, wStructValue, name);
    if (searchResult == NULL)
    {
        fprintf(stderr, "Identifier not found: %s\n", name);
    }
    return searchResult;
}
/* Name -> LLVM type entry for user-declared (struct) types. */
typedef struct CustomTypeMap
{
LLVMTypeRef type;
char *name;  /* heap copy made by RegisterCustomType */
} CustomTypeMap;
/* Registered custom types: filled by RegisterCustomType, searched by LookupCustomType. */
CustomTypeMap *customTypes;
uint32_t customTypeCount;
/*
 * Record the named type `type` under `name` so that later declarations can
 * refer to it. `name` is copied; the caller keeps ownership of its string.
 *
 * Exits the process on allocation failure.
 */
static void RegisterCustomType(LLVMTypeRef type, char *name)
{
    /* Size elements by the map entry itself. The previous sizeof(CustomType)
     * measured an unrelated name and under-allocated the array. */
    CustomTypeMap *grown = realloc(customTypes, sizeof *grown * (customTypeCount + 1));
    char *nameCopy;
    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }
    customTypes = grown;
    nameCopy = strdup(name);
    if (nameCopy == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }
    customTypes[customTypeCount].type = type;
    customTypes[customTypeCount].name = nameCopy;
    customTypeCount += 1;
}
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
for (i = 0; i < customTypeCount; i += 1)
{
if (strcmp(customTypes[i].name, name) == 0)
{
return customTypes[i].type;
}
}
return NULL;
}
/*
 * Map a Wraith primitive type to its LLVM type (global context).
 * Int and UInt both map to i64, Bool to i1 and Void to void.
 * Any other value prints a diagnostic and yields NULL.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    /* LLVM integer types carry no signedness, so both integer kinds share i64. */
    if (type == Int || type == UInt)
    {
        return LLVMInt64Type();
    }
    if (type == Bool)
    {
        return LLVMInt1Type();
    }
    if (type == Void)
    {
        return LLVMVoidType();
    }
    fprintf(stderr, "Unrecognized type!");
    return NULL;
}
/* Emit a numeric literal as a zero-extended (not sign-extended) i64 constant. */
static LLVMValueRef CompileNumber(
    Node *numberExpression
) {
    LLVMTypeRef literalType = LLVMInt64Type();
    unsigned long long literal = numberExpression->value.number;
    return LLVMConstInt(literalType, literal, 0);
}
/*
 * Compile both operands (left first, then right) and combine them with the
 * node's arithmetic operator. Returns NULL for an operator other than
 * Add, Subtract or Multiply.
 */
static LLVMValueRef CompileBinaryExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *binaryExpression
) {
    Node *lhsNode = binaryExpression->children[0];
    Node *rhsNode = binaryExpression->children[1];
    LLVMValueRef lhs = CompileExpression(wStructValue, builder, function, lhsNode);
    LLVMValueRef rhs = CompileExpression(wStructValue, builder, function, rhsNode);
    LLVMValueRef emitted = NULL;

    if (binaryExpression->operator.binaryOperator == Add)
    {
        emitted = LLVMBuildAdd(builder, lhs, rhs, "tmp");
    }
    else if (binaryExpression->operator.binaryOperator == Subtract)
    {
        emitted = LLVMBuildSub(builder, lhs, rhs, "tmp");
    }
    else if (binaryExpression->operator.binaryOperator == Multiply)
    {
        emitted = LLVMBuildMul(builder, lhs, rhs, "tmp");
    }
    return emitted;
}
/*
 * Compile a call `name(arg, ...)`.
 *
 * The callee is looked up first among the module's functions (methods are
 * added to the module by CompileFunction but are never registered as named
 * variables). If that fails, it is resolved as an ordinary identifier.
 *
 * When the callee is a function that takes exactly one parameter more than
 * the call supplies, the current struct pointer is passed as that implicit
 * first argument, matching how CompileFunction declares methods.
 *
 * Returns NULL when the callee cannot be resolved.
 *
 * NOTE(review): cached struct field values are neither flushed before the
 * call nor invalidated after it; a callee that touches fields may see or
 * leave stale values.
 */
static LLVMValueRef CompileFunctionCallExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
) {
    uint32_t i;
    char *calleeName = expression->children[0]->value.string;
    Node *argumentList = expression->children[1];
    uint32_t argumentCount = argumentList->childCount;
    uint32_t offset = 0;
    /* One spare slot: it holds the implicit struct pointer when needed, and it
     * keeps the VLA non-empty (a zero-length VLA is undefined behavior). */
    LLVMValueRef args[argumentCount + 1];
    const char *resultName = "tmp";
    LLVMValueRef callee = LLVMGetNamedFunction(LLVMGetGlobalParent(function), calleeName);

    if (callee == NULL)
    {
        callee = FindVariableByName(builder, wStructValue, calleeName);
        if (callee == NULL)
        {
            return NULL;
        }
    }
    if (LLVMIsAFunction(callee) != NULL)
    {
        if (LLVMCountParams(callee) == argumentCount + 1)
        {
            args[0] = wStructValue;
            offset = 1;
        }
        /* LLVM does not allow a name on a void-typed call result. */
        if (LLVMGetTypeKind(LLVMGetReturnType(LLVMGlobalGetValueType(callee))) == LLVMVoidTypeKind)
        {
            resultName = "";
        }
    }
    for (i = 0; i < argumentCount; i += 1)
    {
        args[offset + i] = CompileExpression(wStructValue, builder, function, argumentList->children[i]);
    }
    return LLVMBuildCall(builder, callee, args, argumentCount + offset, resultName);
}
/*
 * Dispatch on the expression's syntax kind and return the compiled LLVM
 * value. An unknown kind prints a diagnostic and yields NULL; a handler
 * returning NULL is passed through without extra output.
 */
static LLVMValueRef CompileExpression(
    LLVMValueRef wStructValue,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
) {
    LLVMValueRef compiled = NULL;
    if (expression->syntaxKind == BinaryExpression)
    {
        compiled = CompileBinaryExpression(wStructValue, builder, function, expression);
    }
    else if (expression->syntaxKind == FunctionCallExpression)
    {
        compiled = CompileFunctionCallExpression(wStructValue, builder, function, expression);
    }
    else if (expression->syntaxKind == Identifier)
    {
        compiled = FindVariableByName(builder, wStructValue, expression->value.string);
    }
    else if (expression->syntaxKind == Number)
    {
        compiled = CompileNumber(expression);
    }
    else
    {
        fprintf(stderr, "Unknown expression kind!\n");
    }
    return compiled;
}
/* Compile the returned expression (the statement's only child) and emit `ret` with it. */
static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
    Node *returnedExpression = returnStatemement->children[0];
    LLVMValueRef returnedValue = CompileExpression(wStructValue, builder, function, returnedExpression);
    LLVMBuildRet(builder, returnedValue);
}
/* Terminate the current block with `ret void`; the created instruction is not needed. */
static void CompileReturnVoid(LLVMBuilderRef builder)
{
    (void) LLVMBuildRetVoid(builder);
}
/*
 * Compile `target = value` where the target is a field of the struct behind
 * `wStructValue`.
 *
 * The right-hand side is compiled first. The target is then resolved exactly
 * as a read would resolve it, which loads the field and records its pointer
 * if this is its first use. The field's cached value is replaced by the
 * result and flagged, so RemoveStruct stores it back. Previously the result
 * was discarded and the old value was written back.
 *
 * NOTE(review): targets that resolve to a parameter or local are left
 * unchanged, as before.
 */
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    uint32_t i, j;
    Node *targetNode = assignmentStatement->children[0];
    char *targetName = (targetNode->syntaxKind == Identifier) ? targetNode->value.string : NULL;
    LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
    LLVMValueRef target = CompileExpression(wStructValue, builder, function, targetNode);

    if (result == NULL || target == NULL)
    {
        return; /* nothing valid to assign */
    }
    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];
        if (map->structPointer != wStructValue)
        {
            continue;
        }
        for (j = 0; j < map->fieldCount; j += 1)
        {
            StructFieldMapValue *field = &map->fields[j];
            /* Match the name as well as the value: after `a = b`, two fields
             * share b's value, and a value-only match could update the wrong one. */
            if (field->value == target &&
                (targetName == NULL || strcmp(field->name, targetName) == 0))
            {
                field->value = result;
                field->needsWrite = 1;
                return;
            }
        }
    }
}
/*
 * Compile a local variable declaration: allocate stack storage (alloca)
 * named after the variable and bind the name to that alloca.
 *
 * A custom (struct) type is resolved by name via LookupCustomType; otherwise
 * the primitive type is mapped with WraithTypeToLLVMType. If the type is
 * unknown or void, a diagnostic is printed and nothing is emitted (an alloca
 * of NULL or void is invalid).
 */
static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration)
{
    Node *typeNode = variableDeclaration->children[0];
    char *variableName = variableDeclaration->children[1]->value.string;
    LLVMTypeRef variableType;
    LLVMValueRef variable;

    if (typeNode->type == CustomType)
    {
        variableType = LookupCustomType(typeNode->children[0]->value.string);
    }
    else
    {
        variableType = WraithTypeToLLVMType(typeNode->type);
    }
    if (variableType == NULL || LLVMGetTypeKind(variableType) == LLVMVoidTypeKind)
    {
        fprintf(stderr, "Invalid type for variable %s!\n", variableName);
        return;
    }
    variable = LLVMBuildAlloca(builder, variableType, variableName);
    AddNamedVariable(variableName, variable);
}
/*
 * Compile one statement of a function body.
 * Returns 1 if the statement emitted a return (a block terminator), else 0.
 * An unknown statement kind prints a diagnostic and returns 0.
 */
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    uint8_t terminates = 0;
    if (statement->syntaxKind == Assignment)
    {
        CompileAssignment(wStructValue, builder, function, statement);
    }
    else if (statement->syntaxKind == Declaration)
    {
        CompileFunctionVariableDeclaration(builder, statement);
    }
    else if (statement->syntaxKind == Return)
    {
        CompileReturn(wStructValue, builder, function, statement);
        terminates = 1;
    }
    else if (statement->syntaxKind == ReturnVoid)
    {
        CompileReturnVoid(builder);
        terminates = 1;
    }
    else
    {
        fprintf(stderr, "Unknown statement kind!\n");
    }
    return terminates;
}
/*
 * Compile one struct method into an LLVM function named after it.
 *
 * The struct pointer is the implicit first parameter; the declared parameters
 * follow it. The struct's fields are registered with the field cache so the
 * body can read and assign them, and pending field writes are stored back
 * before the function returns.
 *
 * Parameters and locals of this function are dropped from scope once its body
 * is compiled, so they cannot shadow fields of (or leak values into) later
 * functions.
 */
static void CompileFunction(
    LLVMModuleRef module,
    LLVMTypeRef wStructPointerType,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount,
    Node *functionDeclaration
) {
    uint32_t i;
    uint8_t hasReturn = 0;
    uint32_t scopeStart = namedVariableCount;
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    Node *parameters = functionSignature->children[2];
    uint32_t argumentCount = parameters->childCount + 1; /* struct is implicit argument */
    LLVMTypeRef paramTypes[argumentCount];
    LLVMValueRef terminator;

    paramTypes[0] = wStructPointerType;
    for (i = 0; i < parameters->childCount; i += 1)
    {
        paramTypes[i + 1] = WraithTypeToLLVMType(parameters->children[i]->children[0]->type);
    }
    LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
    LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
    LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);
    LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
    for (i = 0; i < parameters->childCount; i += 1)
    {
        AddNamedVariable(parameters->children[i]->children[1]->value.string, LLVMGetParam(function, i + 1));
    }

    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    /* FIXME: replace this with a scope abstraction */
    AddStruct(wStructPointer);
    for (i = 0; i < fieldDeclarationCount; i += 1)
    {
        AddStructFieldName(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
    }
    for (i = 0; i < functionBody->childCount; i += 1)
    {
        hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
    }

    /* If a return already terminated the current block, emit the field
     * write-backs in front of it. Stores placed after a terminator are
     * rejected by the LLVM verifier. */
    terminator = LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(builder));
    if (terminator != NULL)
    {
        LLVMPositionBuilderBefore(builder, terminator);
    }
    RemoveStruct(builder, wStructPointer);

    if (!hasReturn)
    {
        if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind)
        {
            LLVMBuildRetVoid(builder);
        }
        else
        {
            fprintf(stderr, "Return statement not provided!");
        }
    }

    LLVMDisposeBuilder(builder);
    /* Drop this function's parameters and locals from scope. */
    namedVariableCount = scopeStart;
}
/*
 * Compile a struct declaration.
 *
 * Field declarations become the body of a packed, named LLVM struct. The
 * type is registered as a custom type, and then every method is compiled
 * with the struct pointer as its implicit first parameter.
 *
 * Fields may have primitive or previously registered custom types. If a
 * field type cannot be resolved, a diagnostic is printed and no methods are
 * compiled.
 */
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;
    uint32_t fieldCount = 0;
    uint8_t packed = 1;
    char *structName = node->children[0]->value.string;
    Node *declarations = node->children[1];
    uint32_t declarationCount = declarations->childCount;
    /* +1 keeps both VLAs non-empty: a zero-length VLA is undefined behavior. */
    LLVMTypeRef types[declarationCount + 1];
    Node *fieldDeclarations[declarationCount + 1];
    Node *currentDeclarationNode;
    LLVMTypeRef wStruct = LLVMStructCreateNamed(context, structName);
    LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */

    /* first, build the structure definition */
    for (i = 0; i < declarationCount; i += 1)
    {
        Node *typeNode;
        currentDeclarationNode = declarations->children[i];
        if (currentDeclarationNode->syntaxKind != Declaration) /* "Declaration" here means a field */
        {
            continue;
        }
        typeNode = currentDeclarationNode->children[0];
        if (typeNode->type == CustomType)
        {
            types[fieldCount] = LookupCustomType(typeNode->children[0]->value.string);
        }
        else
        {
            types[fieldCount] = WraithTypeToLLVMType(typeNode->type);
        }
        if (types[fieldCount] == NULL)
        {
            fprintf(stderr, "Unknown type for field %s of struct %s!\n",
                currentDeclarationNode->children[1]->value.string, structName);
            return;
        }
        fieldDeclarations[fieldCount] = currentDeclarationNode;
        fieldCount += 1;
    }
    LLVMStructSetBody(wStruct, types, fieldCount, packed);

    /* Register before compiling methods so they can declare locals of this type. */
    RegisterCustomType(wStruct, structName);

    /* now we can wire up the functions */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = declarations->children[i];
        if (currentDeclarationNode->syntaxKind == FunctionDeclaration)
        {
            CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
        }
    }
}
/*
 * Walk the AST depth-first. Struct declarations are the only construct
 * compiled here. Every node's children are visited regardless of kind, so
 * struct declarations at any depth are reached.
 */
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t childIndex = 0;
    if (node->syntaxKind == StructDeclaration)
    {
        CompileStruct(module, context, node);
    }
    while (childIndex < node->childCount)
    {
        Compile(module, context, node->children[childIndex]);
        childIndex += 1;
    }
}
/*
 * Compiler entry point.
 *
 * Parses the file named by argv[1], prints its AST, compiles it into an LLVM
 * module, verifies the module (aborting the process if verification fails),
 * and writes the bitcode to "test.bc".
 *
 * Returns 0 on success. Returns 1 when no file is given, the file cannot be
 * opened, parsing fails, or the bitcode cannot be written.
 */
int main(int argc, char *argv[])
{
    FILE *fp;
    LLVMModuleRef module;
    LLVMContextRef context;
    char *error = NULL;
    int parseResult;
    int exitCode = 0;

    if (argc < 2)
    {
        printf("Please provide a file.\n");
        return 1;
    }
    namedVariables = NULL;
    namedVariableCount = 0;
    structFieldMaps = NULL;
    structFieldMapCount = 0;
    customTypes = NULL;
    customTypeCount = 0;
    stack = CreateStack();

    fp = fopen(argv[1], "r");
    if (fp == NULL)
    {
        fprintf(stderr, "Could not open file %s\n", argv[1]);
        return 1;
    }
    yyin = fp;
    parseResult = yyparse(fp, stack);
    fclose(fp);
    /* A failed parse may leave no usable tree; do not print or compile it. */
    if (parseResult != 0 || rootNode == NULL)
    {
        fprintf(stderr, "Parsing failed.\n");
        return 1;
    }
    PrintTree(rootNode, 0);

    module = LLVMModuleCreateWithName("my_module");
    context = LLVMGetGlobalContext();
    Compile(module, context, rootNode);
    LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
    LLVMDisposeMessage(error);
    if (LLVMWriteBitcodeToFile(module, "test.bc") != 0)
    {
        fprintf(stderr, "error writing bitcode to file\n");
        exitCode = 1;
    }
    LLVMDisposeModule(module);
    return exitCode;
}