wraith-lang/src/codegen.c

1133 lines
35 KiB
C

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <llvm-c/Core.h>
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>
#include <llvm-c/Object.h>
#include <llvm-c/Transforms/PassManagerBuilder.h>
#include <llvm-c/Transforms/InstCombine.h>
#include <llvm-c/Transforms/Scalar.h>
#include <llvm-c/Transforms/Utils.h>
#include <llvm-c/TargetMachine.h>
#include <llvm-c/Target.h>
#include "ast.h"
/* A named binding visible to the code being compiled.
 * Either `pointer` (storage that FindVariableValue loads from and
 * CompileAssignment stores to) or `value` (an SSA value used directly,
 * e.g. a for-loop iterator phi) is set; FindVariableValue returns `value`
 * when it is non-NULL and otherwise loads through `pointer`. */
typedef struct LocalVariable
{
/* heap copy made by AddLocalVariable; freed by PopScopeFrame */
char *name;
LLVMValueRef pointer;
LLVMValueRef value;
} LocalVariable;
/* NOTE(review): not referenced anywhere in the visible code. */
typedef struct FunctionArgument
{
char *name;
LLVMValueRef value;
} FunctionArgument;
/* One lexical frame: a growable array of bindings. */
typedef struct ScopeFrame
{
LocalVariable *localVariables;
uint32_t localVariableCount;
} ScopeFrame;
/* Stack of frames; scopeStack[scopeStackCount - 1] is the innermost frame,
 * and lookups walk from there outwards. */
typedef struct Scope
{
ScopeFrame *scopeStack;
uint32_t scopeStackCount;
} Scope;
/* The scope used by all compile functions; created in Codegen via CreateScope. */
Scope *scope;
/* A struct field: its name (heap copy) and its position in the LLVM struct
 * body, which AddStructDeclaration assigns in declaration order. */
typedef struct StructTypeField
{
char *name;
uint32_t index;
} StructTypeField;
/* A function registered on a struct via DeclareStructFunction.
 * isStatic != 0 means the function has no implicit struct-pointer first
 * parameter; returnType is consulted by callers (e.g. to detect void). */
typedef struct StructTypeFunction
{
char *name;
LLVMValueRef function;
LLVMTypeRef returnType;
uint8_t isStatic;
} StructTypeFunction;
/* Registry entry for one struct type.  structPointerType is the pointer
 * type to structType that instances and `this` parameters carry; lookups
 * by instance compare against it. */
typedef struct StructTypeDeclaration
{
char *name;
LLVMTypeRef structType;
LLVMTypeRef structPointerType;
StructTypeField *fields;
uint32_t fieldCount;
StructTypeFunction *functions;
uint32_t functionCount;
} StructTypeDeclaration;
/* Global struct registry, grown by AddStructDeclaration and reset in Codegen. */
StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeDeclarationCount;
static Scope* CreateScope()
{
Scope *scope = malloc(sizeof(Scope));
scope->scopeStack = malloc(sizeof(ScopeFrame));
scope->scopeStack[0].localVariableCount = 0;
scope->scopeStack[0].localVariables = NULL;
scope->scopeStackCount = 1;
return scope;
}
/* Grow the scope stack by one and make the new, empty frame the innermost. */
static void PushScopeFrame(Scope *scope)
{
    uint32_t newCount = scope->scopeStackCount + 1;
    ScopeFrame *top;

    scope->scopeStack = realloc(scope->scopeStack, newCount * sizeof(ScopeFrame));
    top = &scope->scopeStack[newCount - 1];
    top->localVariables = NULL;
    top->localVariableCount = 0;
    scope->scopeStackCount = newCount;
}
/* Discard the innermost frame: free every binding name it owns and its
 * binding array, then shrink the scope stack by one. */
static void PopScopeFrame(Scope *scope)
{
    ScopeFrame *top = &scope->scopeStack[scope->scopeStackCount - 1];
    uint32_t k = 0;

    if (top->localVariables != NULL)
    {
        while (k < top->localVariableCount)
        {
            free(top->localVariables[k].name);
            k += 1;
        }
        free(top->localVariables);
    }

    scope->scopeStackCount -= 1;
    scope->scopeStack = realloc(scope->scopeStack, scope->scopeStackCount * sizeof(ScopeFrame));
}
/* Append a binding for `name` to the innermost frame of `scope`.
 * The name is copied; exactly one of pointer/value is normally non-NULL. */
static void AddLocalVariable(
    Scope *scope,
    LLVMValueRef pointer, /* can be NULL */
    LLVMValueRef value, /* can be NULL */
    char *name
) {
    ScopeFrame *frame = scope->scopeStack + (scope->scopeStackCount - 1);
    uint32_t slot = frame->localVariableCount;
    LocalVariable *entry;

    frame->localVariables = realloc(frame->localVariables, (slot + 1) * sizeof(LocalVariable));
    entry = &frame->localVariables[slot];
    entry->name = strdup(name);
    entry->pointer = pointer;
    entry->value = value;
    frame->localVariableCount = slot + 1;
}
/* Map a Wraith primitive type to its LLVM type: Int and UInt are both i64,
 * Bool is i1, Void is void.  Prints a diagnostic and returns NULL for any
 * other value. */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    LLVMTypeRef mapped = NULL;

    switch (type)
    {
        case Int:
        case UInt:
            mapped = LLVMInt64Type();
            break;
        case Bool:
            mapped = LLVMInt1Type();
            break;
        case Void:
            mapped = LLVMVoidType();
            break;
        default:
            break;
    }

    if (mapped == NULL)
    {
        fprintf(stderr, "Unrecognized type!");
    }
    return mapped;
}
static LLVMTypeRef FindStructType(char *name)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
return NULL;
}
/* Emit a GEP to field `name` of the struct that `structPointer` points to.
 *
 * The struct is identified by comparing LLVMTypeOf(structPointer) with the
 * registered structPointerType of each declaration.  The GEP instruction is
 * named "<name>_ptr".  Returns the GEP value, or NULL after a diagnostic on
 * stderr if no registered struct of that pointer type has such a field. */
static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name)
{
    uint32_t i, j;
    LLVMTypeRef structPointerType = LLVMTypeOf(structPointer);

    for (i = 0; i < structTypeDeclarationCount; i += 1)
    {
        StructTypeDeclaration *declaration = &structTypeDeclarations[i];
        if (declaration->structPointerType != structPointerType)
        {
            continue;
        }
        for (j = 0; j < declaration->fieldCount; j += 1)
        {
            if (strcmp(declaration->fields[j].name, name) == 0)
            {
                /* Size the buffer for the "_ptr" suffix explicitly; LLVM copies
                 * the value name, so the buffer can be released right after. */
                size_t ptrNameSize = strlen(name) + sizeof("_ptr");
                char *ptrName = malloc(ptrNameSize);
                LLVMValueRef fieldPointer;
                if (ptrName != NULL)
                {
                    snprintf(ptrName, ptrNameSize, "%s_ptr", name);
                }
                fieldPointer = LLVMBuildStructGEP(
                    builder,
                    structPointer,
                    declaration->fields[j].index,
                    ptrName != NULL ? ptrName : ""
                );
                free(ptrName);
                return fieldPointer;
            }
        }
    }
    fprintf(stderr, "Failed to find struct field pointer!\n");
    return NULL;
}
static LLVMValueRef FindVariablePointer(char *name)
{
int32_t i, j;
for (i = scope->scopeStackCount - 1; i >= 0; i -= 1)
{
for (j = 0; j < scope->scopeStack[i].localVariableCount; j += 1)
{
if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
{
return scope->scopeStack[i].localVariables[j].pointer;
}
}
}
printf("Failed to find variable pointer!");
return NULL;
}
/* Resolve `name` from the innermost frame outwards and produce its value:
 * the stored SSA value if the binding has one, otherwise a load (named
 * `name`) through the binding's pointer.  Prints a diagnostic and returns
 * NULL if `name` is unbound. */
static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
{
    int32_t frameIndex;

    for (frameIndex = scope->scopeStackCount - 1; frameIndex >= 0; frameIndex -= 1)
    {
        ScopeFrame *frame = &scope->scopeStack[frameIndex];
        uint32_t v;
        for (v = 0; v < frame->localVariableCount; v += 1)
        {
            LocalVariable *variable = &frame->localVariables[v];
            if (strcmp(variable->name, name) != 0)
            {
                continue;
            }
            return variable->value != NULL
                ? variable->value
                : LLVMBuildLoad(builder, variable->pointer, name);
        }
    }
    printf("Failed to find variable value!");
    return NULL;
}
/* Register a struct type in the global registry.
 *
 * Field i takes its name from fieldDeclarations[i]->children[1] and gets
 * index i.  The struct starts with no functions; DeclareStructFunction adds
 * them later.  `fields` stays NULL when fieldDeclarationCount is 0. */
static void AddStructDeclaration(
    LLVMTypeRef wStructType,
    LLVMTypeRef wStructPointerType,
    char *name,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount
) {
    uint32_t f;
    StructTypeDeclaration *entry;

    structTypeDeclarations = realloc(structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1));
    entry = &structTypeDeclarations[structTypeDeclarationCount];

    entry->name = strdup(name);
    entry->structType = wStructType;
    entry->structPointerType = wStructPointerType;
    entry->functions = NULL;
    entry->functionCount = 0;
    entry->fields = NULL;
    entry->fieldCount = 0;

    if (fieldDeclarationCount > 0)
    {
        entry->fields = malloc(sizeof(StructTypeField) * fieldDeclarationCount);
        for (f = 0; f < fieldDeclarationCount; f += 1)
        {
            entry->fields[f].name = strdup(fieldDeclarations[f]->children[1]->value.string);
            entry->fields[f].index = f;
        }
        entry->fieldCount = fieldDeclarationCount;
    }

    structTypeDeclarationCount += 1;
}
/* Attach `function` (looked up later by `name`) to the first registered
 * struct whose pointer type is wStructPointerType.  Prints a diagnostic if
 * no such struct is registered. */
static void DeclareStructFunction(
    LLVMTypeRef wStructPointerType,
    LLVMValueRef function,
    LLVMTypeRef returnType,
    uint8_t isStatic,
    char *name
) {
    uint32_t d;

    for (d = 0; d < structTypeDeclarationCount; d += 1)
    {
        StructTypeDeclaration *owner = &structTypeDeclarations[d];
        StructTypeFunction *slot;

        if (owner->structPointerType != wStructPointerType)
        {
            continue;
        }

        owner->functions = realloc(owner->functions, sizeof(StructTypeFunction) * (owner->functionCount + 1));
        slot = &owner->functions[owner->functionCount];
        slot->name = strdup(name);
        slot->function = function;
        slot->returnType = returnType;
        slot->isStatic = isStatic;
        owner->functionCount += 1;
        return;
    }
    fprintf(stderr, "Could not find struct type for function!\n");
}
/* Like FindStructType, but treats a missing struct as an error: prints a
 * diagnostic and returns NULL. */
static LLVMTypeRef LookupCustomType(char *name)
{
    uint32_t d = 0;

    while (d < structTypeDeclarationCount && strcmp(structTypeDeclarations[d].name, name) != 0)
    {
        d += 1;
    }

    if (d < structTypeDeclarationCount)
    {
        return structTypeDeclarations[d].structType;
    }
    fprintf(stderr, "Could not find struct type!\n");
    return NULL;
}
/* Translate a type AST node into an LLVM type.  The node's first child
 * selects the kind: primitive (WraithTypeToLLVMType), custom (registered
 * struct type by name) or reference (pointer to the recursively resolved
 * inner type).  Prints a diagnostic and returns NULL otherwise. */
static LLVMTypeRef ResolveType(Node* typeNode)
{
    Node *inner = typeNode->children[0];

    if (IsPrimitiveType(typeNode))
    {
        return WraithTypeToLLVMType(inner->primitiveType);
    }

    switch (inner->syntaxKind)
    {
        case CustomTypeNode:
            return LookupCustomType(inner->value.string);
        case ReferenceTypeNode:
            return LLVMPointerType(ResolveType(inner->children[0]), 0);
        default:
            fprintf(stderr, "Unknown type node!\n");
            return NULL;
    }
}
/* Find function `name` on the struct whose LLVM struct type is structType.
 * On success, writes its return type and static flag through the out
 * parameters and returns the function; otherwise prints a diagnostic and
 * returns NULL, leaving the out parameters untouched. */
static LLVMValueRef LookupFunctionByType(
    LLVMTypeRef structType,
    char *name,
    LLVMTypeRef *pReturnType,
    uint8_t *pStatic
) {
    uint32_t d, f;

    for (d = 0; d < structTypeDeclarationCount; d += 1)
    {
        StructTypeDeclaration *owner = &structTypeDeclarations[d];
        if (owner->structType != structType)
        {
            continue;
        }
        for (f = 0; f < owner->functionCount; f += 1)
        {
            StructTypeFunction *candidate = &owner->functions[f];
            if (strcmp(candidate->name, name) == 0)
            {
                *pReturnType = candidate->returnType;
                *pStatic = candidate->isStatic;
                return candidate->function;
            }
        }
    }
    fprintf(stderr, "Could not find struct function!\n");
    return NULL;
}
/* Same as LookupFunctionByType, but matches the struct by its pointer type
 * (the type carried by instances).  Prints a diagnostic and returns NULL,
 * leaving the out parameters untouched, when nothing matches. */
static LLVMValueRef LookupFunctionByPointerType(
    LLVMTypeRef structPointerType,
    char *name,
    LLVMTypeRef *pReturnType,
    uint8_t *pStatic
) {
    uint32_t d = 0;

    while (d < structTypeDeclarationCount)
    {
        StructTypeDeclaration *owner = &structTypeDeclarations[d];
        if (owner->structPointerType == structPointerType)
        {
            uint32_t f = 0;
            while (f < owner->functionCount)
            {
                StructTypeFunction *candidate = &owner->functions[f];
                if (!strcmp(candidate->name, name))
                {
                    *pReturnType = candidate->returnType;
                    *pStatic = candidate->isStatic;
                    return candidate->function;
                }
                f += 1;
            }
        }
        d += 1;
    }
    fprintf(stderr, "Could not find struct function!\n");
    return NULL;
}
/* Find function `name` on the struct that `structPointer` points to, using
 * the pointer's LLVM type as the key (see LookupFunctionByPointerType). */
static LLVMValueRef LookupFunctionByInstance(
    LLVMValueRef structPointer,
    char *name,
    LLVMTypeRef *pReturnType,
    uint8_t *pStatic
) {
    LLVMTypeRef instancePointerType = LLVMTypeOf(structPointer);
    LLVMValueRef found = LookupFunctionByPointerType(instancePointerType, name, pReturnType, pStatic);
    return found;
}
/* Bind every field of the struct that `structPointer` points to as a local
 * variable in the innermost scope frame.  For each field a GEP named
 * "<field>_ptr" is emitted, and the field name is bound to that pointer, so
 * member-function bodies can read and assign fields by bare name. */
static void AddStructVariablesToScope(
    LLVMBuilderRef builder,
    LLVMValueRef structPointer
) {
    uint32_t i, j;
    LLVMTypeRef structPointerType = LLVMTypeOf(structPointer);

    for (i = 0; i < structTypeDeclarationCount; i += 1)
    {
        StructTypeDeclaration *declaration = &structTypeDeclarations[i];
        if (declaration->structPointerType != structPointerType)
        {
            continue;
        }
        for (j = 0; j < declaration->fieldCount; j += 1)
        {
            char *fieldName = declaration->fields[j].name;
            /* strdup(fieldName) has no room for the suffix; size it explicitly.
             * LLVM copies the value name, so the buffer is freed right away. */
            size_t ptrNameSize = strlen(fieldName) + sizeof("_ptr");
            char *ptrName = malloc(ptrNameSize);
            LLVMValueRef elementPointer;

            if (ptrName != NULL)
            {
                snprintf(ptrName, ptrNameSize, "%s_ptr", fieldName);
            }
            elementPointer = LLVMBuildStructGEP(
                builder,
                structPointer,
                declaration->fields[j].index,
                ptrName != NULL ? ptrName : ""
            );
            free(ptrName);

            AddLocalVariable(scope, elementPointer, NULL, fieldName);
        }
    }
}
/* Forward declaration: CompileBinaryExpression and
 * CompileFunctionCallExpression below recurse into CompileExpression,
 * which is defined after them. */
static LLVMValueRef CompileExpression(
LLVMBuilderRef builder,
Node *expression
);
/* Turn a numeric literal node into an i64 constant (SignExtend flag 0). */
static LLVMValueRef CompileNumber(
    Node *numberExpression
) {
    LLVMTypeRef i64 = LLVMInt64Type();
    unsigned long long literal = numberExpression->value.number;
    return LLVMConstInt(i64, literal, 0);
}
/* Emit the literal's text as a global string constant and return a pointer
 * to its first character. */
static LLVMValueRef CompileString(
    LLVMBuilderRef builder,
    Node *stringExpression
) {
    char *text = stringExpression->value.string;
    LLVMValueRef globalString = LLVMBuildGlobalStringPtr(builder, text, "stringConstant");
    return globalString;
}
/* Compile both operands (left first), then the integer operation selected
 * by the node's binary operator.  Returns NULL for an unhandled operator.
 * FIXME: need type information for comparison; comparisons and Mod are
 * currently always signed. */
static LLVMValueRef CompileBinaryExpression(
    LLVMBuilderRef builder,
    Node *binaryExpression
) {
    LLVMValueRef lhs = CompileExpression(builder, binaryExpression->children[0]);
    LLVMValueRef rhs = CompileExpression(builder, binaryExpression->children[1]);
    LLVMValueRef result = NULL;

    switch (binaryExpression->operator.binaryOperator)
    {
        /* arithmetic */
        case Add:
            result = LLVMBuildAdd(builder, lhs, rhs, "addResult");
            break;
        case Subtract:
            result = LLVMBuildSub(builder, lhs, rhs, "subtractResult");
            break;
        case Multiply:
            result = LLVMBuildMul(builder, lhs, rhs, "multiplyResult");
            break;
        case Mod:
            result = LLVMBuildSRem(builder, lhs, rhs, "modResult");
            break;

        /* comparisons */
        case LessThan:
            result = LLVMBuildICmp(builder, LLVMIntSLT, lhs, rhs, "lessThanResult");
            break;
        case GreaterThan:
            result = LLVMBuildICmp(builder, LLVMIntSGT, lhs, rhs, "greaterThanResult");
            break;
        case Equal:
            result = LLVMBuildICmp(builder, LLVMIntEQ, lhs, rhs, "equalResult");
            break;

        /* logic */
        case LogicalOr:
            result = LLVMBuildOr(builder, lhs, rhs, "orResult");
            break;

        default:
            break;
    }
    return result;
}
/* Compile a call of the form `Receiver.Method(arguments)`.
 *
 * expression->children[0] must be an AccessExpression (receiver, method
 * name); expression->children[1] holds the argument expressions.  If the
 * receiver names a registered struct type, the method is looked up on that
 * type.  Otherwise the receiver is resolved as a local variable, and its
 * storage pointer is passed as the implicit first argument of non-static
 * methods.  Returns the call instruction, or NULL after a diagnostic when
 * the callee cannot be resolved.
 *
 * FIXME: only one level of access (`a.b(...)`) is supported; this needs to
 * be recursive on access chains. */
static LLVMValueRef CompileFunctionCallExpression(
    LLVMBuilderRef builder,
    Node *expression
) {
    uint32_t i;
    uint32_t argumentCount = 0;
    Node *callee = expression->children[0];
    Node *argumentList = expression->children[1];
    LLVMValueRef args[argumentList->childCount + 1];
    /* Initialised so a failed lookup is detected rather than read as garbage. */
    LLVMValueRef function = NULL;
    uint8_t isStatic = 0;
    LLVMValueRef structInstance = NULL;
    LLVMTypeRef functionReturnType = NULL;
    LLVMTypeRef typeReference;
    char *returnName = "";
    char *receiverName;
    char *methodName;

    if (callee->syntaxKind != AccessExpression)
    {
        fprintf(stderr, "Failed to find function!\n");
        return NULL;
    }

    receiverName = callee->children[0]->value.string;
    methodName = callee->children[1]->value.string;

    typeReference = FindStructType(receiverName);
    if (typeReference != NULL)
    {
        function = LookupFunctionByType(typeReference, methodName, &functionReturnType, &isStatic);
    }
    else
    {
        structInstance = FindVariablePointer(receiverName);
        /* LookupFunctionByInstance calls LLVMTypeOf, which needs a real value. */
        if (structInstance != NULL)
        {
            function = LookupFunctionByInstance(structInstance, methodName, &functionReturnType, &isStatic);
        }
    }

    if (function == NULL)
    {
        fprintf(stderr, "Failed to find function!\n");
        return NULL;
    }

    if (!isStatic)
    {
        /* A call through the type name has no instance to pass along. */
        if (structInstance == NULL)
        {
            fprintf(stderr, "Cannot call non-static function without an instance!\n");
            return NULL;
        }
        args[argumentCount] = structInstance;
        argumentCount += 1;
    }

    for (i = 0; i < argumentList->childCount; i += 1)
    {
        args[argumentCount] = CompileExpression(builder, argumentList->children[i]);
        argumentCount += 1;
    }

    /* LLVM rejects names on void-typed values, so only name non-void results. */
    if (LLVMGetTypeKind(functionReturnType) != LLVMVoidTypeKind)
    {
        returnName = "callReturn";
    }
    return LLVMBuildCall(builder, function, args, argumentCount, returnName);
}
/* For `variable.field` on the left of an assignment: return a pointer to the
 * field (no load), suitable as the destination of a store. */
static LLVMValueRef CompileAccessExpressionForStore(
    LLVMBuilderRef builder,
    Node *expression
) {
    char *variableName = expression->children[0]->value.string;
    char *fieldName = expression->children[1]->value.string;
    return FindStructFieldPointer(builder, FindVariablePointer(variableName), fieldName);
}
/* For `variable.field` used as a value: compute the field pointer, then load
 * from it (the load is named after the field). */
static LLVMValueRef CompileAccessExpression(
    LLVMBuilderRef builder,
    Node *expression
) {
    char *variableName = expression->children[0]->value.string;
    char *fieldName = expression->children[1]->value.string;
    LLVMValueRef structPointer = FindVariablePointer(variableName);
    LLVMValueRef fieldPointer = FindStructFieldPointer(builder, structPointer, fieldName);

    return LLVMBuildLoad(builder, fieldPointer, fieldName);
}
/* Heap-allocate one value of the type named by the expression's first child
 * and return the resulting pointer. */
static LLVMValueRef CompileAllocExpression(
    LLVMBuilderRef builder,
    Node *expression
) {
    return LLVMBuildMalloc(builder, ResolveType(expression->children[0]), "allocation");
}
/* Compile an expression node to an LLVM value by dispatching on its syntax
 * kind.  Prints "Unknown expression kind!" and returns NULL for any other kind. */
static LLVMValueRef CompileExpression(
    LLVMBuilderRef builder,
    Node *expression
) {
    switch (expression->syntaxKind)
    {
        /* leaves */
        case Number:
            return CompileNumber(expression);
        case StringLiteral:
            return CompileString(builder, expression);
        case Identifier:
            return FindVariableValue(builder, expression->value.string);

        /* composite expressions */
        case BinaryExpression:
            return CompileBinaryExpression(builder, expression);
        case AccessExpression:
            return CompileAccessExpression(builder, expression);
        case FunctionCallExpression:
            return CompileFunctionCallExpression(builder, expression);
        case AllocExpression:
            return CompileAllocExpression(builder, expression);

        default:
            break;
    }
    fprintf(stderr, "Unknown expression kind!\n");
    return NULL;
}
/* Forward declaration: the if / if-else / for compilers below compile their
 * nested statements through CompileStatement, which is defined after them. */
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
/* Emit `ret <expr>` for a return statement and report the function's last
 * basic block. */
static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
    LLVMBuildRet(builder, CompileExpression(builder, returnStatemement->children[0]));
    return LLVMGetLastBasicBlock(function);
}
/* Emit `ret void` and report the function's last basic block. */
static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function)
{
    LLVMBasicBlockRef lastBlock;

    LLVMBuildRetVoid(builder);
    lastBlock = LLVMGetLastBasicBlock(function);
    return lastBlock;
}
/* FIXME: path for reference types */
/* Emit an alloca (named "<name>_ptr") for a local variable declaration at
 * the builder's current position.  The variable is bound in the innermost
 * scope frame, and the alloca is returned so callers can store an
 * initial value into it. */
static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
{
    LLVMValueRef variable;
    char *variableName = variableDeclaration->children[1]->value.string;
    /* strdup(variableName) leaves no room for "_ptr", so size the buffer
     * explicitly; LLVM copies value names, so it is freed straight after. */
    size_t ptrNameSize = strlen(variableName) + sizeof("_ptr");
    char *ptrName = malloc(ptrNameSize);

    if (ptrName != NULL)
    {
        snprintf(ptrName, ptrNameSize, "%s_ptr", variableName);
    }
    variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->children[0]), ptrName != NULL ? ptrName : "");
    free(ptrName);

    AddLocalVariable(scope, variable, NULL, variableName);
    return variable;
}
/* Compile `target = value`.  The right-hand side is compiled first.  The
 * target may be a field access, an existing variable, or a new declaration
 * (whose alloca is emitted after the value).  Any other target kind prints
 * a diagnostic and nothing is stored.  Returns the function's last block. */
static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    Node *target = assignmentStatement->children[0];
    LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
    LLVMValueRef destination = NULL;

    switch (target->syntaxKind)
    {
        case AccessExpression:
            destination = CompileAccessExpressionForStore(builder, target);
            break;
        case Identifier:
            destination = FindVariablePointer(target->value.string);
            break;
        case Declaration:
            destination = CompileFunctionVariableDeclaration(builder, function, target);
            break;
        default:
            printf("Identifier not found!");
            return LLVMGetLastBasicBlock(function);
    }

    LLVMBuildStore(builder, result, destination);
    return LLVMGetLastBasicBlock(function);
}
/* Compile `if (cond) { body }`.
 *
 * Branches to "ifBlock" when the condition holds, otherwise to "afterCond".
 * The body falls through to "afterCond" unless it already ended in a
 * terminator (e.g. a return).  Leaves the builder at the end of
 * "afterCond" and returns that block. */
static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{
    uint32_t i;
    LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
    LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
    LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");

    LLVMBuildCondBr(builder, conditional, block, afterCond);

    LLVMPositionBuilderAtEnd(builder, block);
    for (i = 0; i < ifStatement->children[1]->childCount; i += 1)
    {
        CompileStatement(builder, function, ifStatement->children[1]->children[i]);
    }
    /* A basic block may end in only one terminator: if the body already
     * returned, adding a fall-through branch would make the IR invalid.
     * Check the block the builder is in now, since nested statements may
     * have moved it. */
    if (LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(builder)) == NULL)
    {
        LLVMBuildBr(builder, afterCond);
    }

    LLVMPositionBuilderAtEnd(builder, afterCond);
    return afterCond;
}
/* Compile an if/else statement.
 *
 * ifElseStatement->children[0] is the `if` part: its children[0] is the
 * condition and its children[1] the then-body.  ifElseStatement->children[1]
 * is either a StatementSequence (the else-body) or a single statement
 * (presumably an `else if`; TODO confirm against the parser).  Each arm
 * falls through to "afterCond" unless it already ended in a terminator such
 * as a return.  Leaves the builder at the end of "afterCond" and returns it. */
static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
{
    uint32_t i;
    Node *ifNode = ifElseStatement->children[0];
    Node *elseNode = ifElseStatement->children[1];
    LLVMValueRef conditional = CompileExpression(builder, ifNode->children[0]);
    LLVMBasicBlockRef ifBlock = LLVMAppendBasicBlock(function, "ifBlock");
    LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlock(function, "elseBlock");
    LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");

    LLVMBuildCondBr(builder, conditional, ifBlock, elseBlock);

    /* then-arm */
    LLVMPositionBuilderAtEnd(builder, ifBlock);
    for (i = 0; i < ifNode->children[1]->childCount; i += 1)
    {
        CompileStatement(builder, function, ifNode->children[1]->children[i]);
    }
    /* Only one terminator per block: skip the branch if the arm returned. */
    if (LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(builder)) == NULL)
    {
        LLVMBuildBr(builder, afterCond);
    }

    /* else-arm */
    LLVMPositionBuilderAtEnd(builder, elseBlock);
    if (elseNode->syntaxKind == StatementSequence)
    {
        for (i = 0; i < elseNode->childCount; i += 1)
        {
            CompileStatement(builder, function, elseNode->children[i]);
        }
    }
    else
    {
        CompileStatement(builder, function, elseNode);
    }
    if (LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(builder)) == NULL)
    {
        LLVMBuildBr(builder, afterCond);
    }

    LLVMPositionBuilderAtEnd(builder, afterCond);
    return afterCond;
}
/* Compile a counted for-loop.
 *
 * forLoopStatement->children: [0] iterator declaration (type, name),
 * [1] start literal, [2] inclusive end literal, [3] body statements.
 *
 * Block layout:
 *   current -> loopEntry -> loopCheck --(iter <= end)--> loopBody ... -> loopCheck
 *                                     \--(otherwise)---> afterLoop
 * The iterator is a phi in loopCheck, bound as a value-only variable in a
 * fresh scope frame.  It starts at the start literal (from loopEntry) and
 * takes the incremented value on the back edge.  Leaves the builder at the
 * end of afterLoop and returns that block.
 * FIXME: the bound check is unsigned (ULE) and the literals are i64,
 * regardless of the iterator's declared type. */
static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement)
{
    uint32_t i;
    LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry");
    LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
    LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
    LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
    char *iteratorVariableName = forLoopStatement->children[0]->children[1]->value.string;
    LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->children[0]->children[0]);
    Node *loopBody = forLoopStatement->children[3];

    PushScopeFrame(scope);

    LLVMBuildBr(builder, entryBlock);
    LLVMPositionBuilderAtEnd(builder, entryBlock);
    LLVMBuildBr(builder, checkBlock);

    /* loopCheck: iterator phi, bound test, conditional branch */
    LLVMPositionBuilderAtEnd(builder, checkBlock);
    LLVMValueRef iteratorValue = LLVMBuildPhi(builder, iteratorVariableType, iteratorVariableName);
    AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
    LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
    LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
    LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);

    /* loopBody: compute the next iterator value first so it dominates the
     * back edge however the body branches.  The step is 1; reusing the start
     * literal as the step meant a range starting at 0 never advanced. */
    LLVMPositionBuilderAtEnd(builder, bodyBlock);
    LLVMValueRef nextValue = LLVMBuildAdd(
        builder,
        iteratorValue,
        LLVMConstInt(iteratorVariableType, 1, 0),
        "next"
    );
    for (i = 0; i < loopBody->childCount; i += 1)
    {
        CompileStatement(builder, function, loopBody->children[i]);
    }

    /* The back edge leaves from wherever the body finished, even an empty
     * body.  If the body already ended in a terminator (e.g. a return), there
     * is no back edge, and the phi must not list that block as incoming. */
    LLVMBasicBlockRef backEdgeBlock = LLVMGetInsertBlock(builder);
    uint8_t hasBackEdge = LLVMGetBasicBlockTerminator(backEdgeBlock) == NULL;
    if (hasBackEdge)
    {
        LLVMBuildBr(builder, checkBlock);
    }

    LLVMValueRef incomingValues[2];
    LLVMBasicBlockRef incomingBlocks[2];
    incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
    incomingBlocks[0] = entryBlock;
    incomingValues[1] = nextValue;
    incomingBlocks[1] = backEdgeBlock;
    LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, hasBackEdge ? 2 : 1);

    LLVMPositionBuilderAtEnd(builder, afterLoopBlock);
    PopScopeFrame(scope);
    return afterLoopBlock;
}
/* Compile one statement by dispatching on its syntax kind.
 * Returns the basic block reported by the statement compiler (for
 * declarations and bare calls, the function's last block), or NULL after a
 * diagnostic for an unknown statement kind. */
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    switch (statement->syntaxKind)
    {
        /* control flow */
        case ForLoop:
            return CompileForLoopStatement(builder, function, statement);
        case IfStatement:
            return CompileIfStatement(builder, function, statement);
        case IfElseStatement:
            return CompileIfElseStatement(builder, function, statement);
        case Return:
            return CompileReturn(builder, function, statement);
        case ReturnVoid:
            return CompileReturnVoid(builder, function);

        /* straight-line statements */
        case Assignment:
            return CompileAssignment(builder, function, statement);
        case Declaration:
            CompileFunctionVariableDeclaration(builder, function, statement);
            break;
        case FunctionCallExpression:
            CompileFunctionCallExpression(builder, statement);
            break;

        default:
            fprintf(stderr, "Unknown statement kind!\n");
            return NULL;
    }
    return LLVMGetLastBasicBlock(function);
}
/* Compile one member function of a struct.
 *
 * The LLVM function is named "<parentStructName>_<name>" and is registered
 * on the struct via DeclareStructFunction.  Unless the signature's modifier
 * list contains a StaticModifier, the function takes a pointer to the struct
 * as an implicit first parameter, and the struct's fields are bound as
 * locals (AddStructVariablesToScope).  Each declared parameter is copied
 * into an alloca so the body can assign to it.  A void function whose body
 * does not end in a terminator gets an implicit `ret void`.
 *
 * fieldDeclarations / fieldDeclarationCount are not used here. */
static void CompileFunction(
    LLVMModuleRef module,
    char *parentStructName,
    LLVMTypeRef wStructPointerType,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount,
    Node *functionDeclaration
) {
    uint32_t i;
    uint8_t hasReturn = 0;
    uint8_t isStatic = 0;
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    Node *parameters = functionSignature->children[2];
    Node *modifiers = functionSignature->children[3];
    char *baseName = functionSignature->children[0]->value.string;
    uint32_t argumentCount = parameters->childCount;
    LLVMTypeRef paramTypes[argumentCount + 1];
    uint32_t paramIndex = 0;
    LLVMTypeRef returnType;
    LLVMTypeRef functionType;
    LLVMValueRef function;
    LLVMBasicBlockRef entry;
    LLVMBuilderRef builder;
    size_t functionNameSize;
    char *functionName;

    (void)fieldDeclarations;
    (void)fieldDeclarationCount;

    for (i = 0; i < modifiers->childCount; i += 1)
    {
        if (modifiers->children[i]->syntaxKind == StaticModifier)
        {
            isStatic = 1;
            break;
        }
    }

    if (!isStatic)
    {
        paramTypes[paramIndex] = wStructPointerType;
        paramIndex += 1;
    }

    PushScopeFrame(scope);

    /* ResolveType handles primitive, struct and reference types, so
     * parameters and return values are no longer limited to primitives. */
    for (i = 0; i < argumentCount; i += 1)
    {
        paramTypes[paramIndex] = ResolveType(parameters->children[i]->children[0]);
        paramIndex += 1;
    }
    returnType = ResolveType(functionSignature->children[1]);
    functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0);

    /* "<struct>_<function>": room for both parts, the '_' and the NUL.
     * LLVM copies the name, so the buffer is freed right after. */
    functionNameSize = strlen(parentStructName) + strlen(baseName) + 2;
    functionName = malloc(functionNameSize);
    if (functionName == NULL)
    {
        fprintf(stderr, "Out of memory while naming function!\n");
        PopScopeFrame(scope);
        return;
    }
    snprintf(functionName, functionNameSize, "%s_%s", parentStructName, baseName);
    function = LLVMAddFunction(module, functionName, functionType);
    free(functionName);

    DeclareStructFunction(wStructPointerType, function, returnType, isStatic, baseName);

    entry = LLVMAppendBasicBlock(function, "entry");
    builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    if (!isStatic)
    {
        AddStructVariablesToScope(builder, LLVMGetParam(function, 0));
    }

    /* Copy each parameter into a "<name>_ptr" alloca so it is assignable. */
    for (i = 0; i < argumentCount; i += 1)
    {
        char *argumentName = parameters->children[i]->children[1]->value.string;
        size_t ptrNameSize = strlen(argumentName) + sizeof("_ptr");
        char *ptrName = malloc(ptrNameSize);
        LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
        LLVMValueRef argumentCopy;

        if (ptrName != NULL)
        {
            snprintf(ptrName, ptrNameSize, "%s_ptr", argumentName);
        }
        argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName != NULL ? ptrName : "");
        free(ptrName);
        LLVMBuildStore(builder, argument, argumentCopy);
        AddLocalVariable(scope, argumentCopy, NULL, argumentName);
    }

    for (i = 0; i < functionBody->childCount; i += 1)
    {
        CompileStatement(builder, function, functionBody->children[i]);
    }

    /* Inspect the block the builder finished in: nested control flow can
     * append blocks after it, so the function's last block may be a
     * different, already-terminated one. */
    hasReturn = LLVMGetBasicBlockTerminator(LLVMGetInsertBlock(builder)) != NULL;
    if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
    {
        LLVMBuildRetVoid(builder);
    }
    else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
    {
        fprintf(stderr, "Return statement not provided!");
    }

    PopScopeFrame(scope);
    LLVMDisposeBuilder(builder);
}
/* Compile one struct declaration.
 *
 * node->children[0] holds the struct name and node->children[1] its
 * members.  Declaration members become the fields of a packed LLVM struct
 * (LLVMStructSetBody's Packed flag), in member order.  Once the struct is
 * registered with AddStructDeclaration, every FunctionDeclaration member is
 * compiled with CompileFunction.  Other member kinds are ignored. */
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;
    uint32_t fieldCount = 0;
    uint32_t declarationCount = node->children[1]->childCount;
    uint8_t packed = 1;
    /* +1 keeps both VLA lengths positive for a struct with no members;
     * a zero-length VLA is undefined behaviour. */
    LLVMTypeRef types[declarationCount + 1];
    Node *fieldDeclarations[declarationCount + 1];
    Node *currentDeclarationNode;
    char *structName = node->children[0]->value.string;
    LLVMTypeRef wStructType;
    LLVMTypeRef wStructPointerType;

    PushScopeFrame(scope);
    wStructType = LLVMStructCreateNamed(context, structName);
    wStructPointerType = LLVMPointerType(wStructType, 0); /* FIXME: is this address space correct? */

    /* first, build the structure definition */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];
        if (currentDeclarationNode->syntaxKind == Declaration) /* this is badly named */
        {
            types[fieldCount] = ResolveType(currentDeclarationNode->children[0]);
            fieldDeclarations[fieldCount] = currentDeclarationNode;
            fieldCount += 1;
        }
    }
    LLVMStructSetBody(wStructType, types, fieldCount, packed);
    AddStructDeclaration(wStructType, wStructPointerType, structName, fieldDeclarations, fieldCount);

    /* now we can wire up the functions */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];
        if (currentDeclarationNode->syntaxKind == FunctionDeclaration)
        {
            CompileFunction(module, structName, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
        }
    }
    PopScopeFrame(scope);
}
/* Compile every top-level declaration of the program node.  Only struct
 * declarations are allowed; anything else is reported and skipped. */
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;

    for (i = 0; i < node->childCount; i += 1)
    {
        Node *declaration = node->children[i];
        if (declaration->syntaxKind != StructDeclaration)
        {
            fprintf(stderr, "top level declarations that are not structs are forbidden!\n");
            continue;
        }
        CompileStruct(module, context, declaration);
    }
}
/* TODO: move this to some kind of standard library file? */
/* Register the built-in "Console" struct and its static PrintLine function.
 *
 * Declares the external C `printf` (i32, variadic) and defines a helper
 * LLVM function "printLine" that calls printf with its fixed argument and
 * then prints "\n".  printLine is registered on Console as "PrintLine".
 * NOTE(review): printLine is declared variadic, but LLVMGetParams yields
 * only its fixed parameter, so extra arguments are not forwarded to printf.
 * Forwarding them would need va_list handling (e.g. vprintf). */
static void RegisterLibraryFunctions(LLVMModuleRef module, LLVMContextRef context)
{
    LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console");
    LLVMTypeRef structPointerType = LLVMPointerType(structType, 0);
    AddStructDeclaration(structType, structPointerType, "Console", NULL, 0);

    LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0);
    LLVMTypeRef printfFunctionType = LLVMFunctionType(LLVMInt32Type(), &printfArg, 1, 1);
    LLVMValueRef printfFunction = LLVMAddFunction(module, "printf", printfFunctionType);
    LLVMSetLinkage(printfFunction, LLVMExternalLinkage);

    LLVMTypeRef printLineReturnType = LLVMInt32Type();
    LLVMTypeRef printLineFunctionType = LLVMFunctionType(printLineReturnType, &printfArg, 1, 1);
    LLVMValueRef printLineFunction = LLVMAddFunction(module, "printLine", printLineFunctionType);

    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(printLineFunction, "entry");
    LLVMPositionBuilderAtEnd(builder, entry);
    LLVMValueRef newLine = LLVMBuildGlobalStringPtr(builder, "\n", "newline");
    LLVMValueRef printParams[LLVMCountParams(printLineFunction)];
    LLVMGetParams(printLineFunction, printParams);
    LLVMValueRef stringPrint = LLVMBuildCall(builder, printfFunction, printParams, LLVMCountParams(printLineFunction), "printfCall");
    LLVMValueRef newlinePrint = LLVMBuildCall(builder, printfFunction, &newLine, 1, "printNewLine");
    /* NOTE(review): this returns the bitwise AND of the two printf return
     * values; confirm whether a sum (total characters written) was intended. */
    LLVMBuildRet(builder, LLVMBuildAnd(builder, stringPrint, newlinePrint, "and"));
    LLVMDisposeBuilder(builder);

    /* Record the function's real return type (i32), not i8, so the registry
     * agrees with the LLVM function type. */
    DeclareStructFunction(structPointerType, printLineFunction, printLineReturnType, 1, "PrintLine");
}
/* Compile the program AST `node` end to end.
 *
 * Steps:
 *   1. Build LLVM IR for the program.
 *   2. Add a C-level `main` that returns the result of `Program_Main`.
 *   3. Verify the module.
 *   4. Run the module pass pipeline at `optimizationLevel`.
 *   5. Write "test.bc" (bitcode) and "test.o" (native object file) for
 *      LLVM_DEFAULT_TARGET_TRIPLE to the current directory.
 *
 * Returns EXIT_SUCCESS, or EXIT_FAILURE after a diagnostic on stderr.
 * LLVM resources created here are released on every path. */
int Codegen(Node *node, uint32_t optimizationLevel)
{
    int status = EXIT_FAILURE;
    char *error = NULL;
    char *cpu = "generic";
    char *features = "";
    LLVMModuleRef module;
    LLVMContextRef context;
    LLVMBuilderRef builder;
    LLVMTypeRef mainFunctionType;
    LLVMValueRef mainFunction;
    LLVMValueRef wraithMainFunction;
    LLVMValueRef mainResult;
    LLVMBasicBlockRef entry;
    LLVMTargetRef target;
    LLVMPassManagerRef passManager = NULL;
    LLVMPassManagerBuilderRef passManagerBuilder = NULL;
    LLVMTargetMachineRef targetMachine = NULL;

    scope = CreateScope();
    structTypeDeclarations = NULL;
    structTypeDeclarationCount = 0;

    module = LLVMModuleCreateWithName("my_module");
    context = LLVMGetGlobalContext();
    RegisterLibraryFunctions(module, context);
    Compile(module, context, node);

    /* add main call */
    wraithMainFunction = LLVMGetNamedFunction(module, "Program_Main");
    if (wraithMainFunction == NULL)
    {
        fprintf(stderr, "Could not find Program.Main!\n");
        goto cleanup;
    }
    builder = LLVMCreateBuilder();
    mainFunctionType = LLVMFunctionType(LLVMInt64Type(), NULL, 0, 0);
    mainFunction = LLVMAddFunction(module, "main", mainFunctionType);
    entry = LLVMAppendBasicBlock(mainFunction, "entry");
    LLVMPositionBuilderAtEnd(builder, entry);
    mainResult = LLVMBuildCall(builder, wraithMainFunction, NULL, 0, "result");
    LLVMBuildRet(builder, mainResult);
    LLVMDisposeBuilder(builder);

    /* verify: LLVMReturnStatusAction (not Abort) so a broken module reaches
     * the error path below instead of killing the process.  The message is
     * allocated even on success and must be released. */
    if (LLVMVerifyModule(module, LLVMReturnStatusAction, &error) != 0)
    {
        fprintf(stderr, "%s\n", error);
        goto cleanup;
    }
    LLVMDisposeMessage(error);
    error = NULL;

    /* prepare to emit assembly */
    LLVMInitializeNativeTarget();
    LLVMInitializeAllTargetInfos();
    LLVMInitializeAllTargets();
    LLVMInitializeAllTargetMCs();
    LLVMInitializeAllAsmParsers();
    LLVMInitializeAllAsmPrinters();

    LLVMSetTarget(module, LLVM_DEFAULT_TARGET_TRIPLE);
    if (LLVMGetTargetFromTriple(LLVM_DEFAULT_TARGET_TRIPLE, &target, &error) != 0)
    {
        fprintf(stderr, "Failed to get target!\n");
        fprintf(stderr, "%s\n", error);
        goto cleanup;
    }

    passManager = LLVMCreatePassManager();
    passManagerBuilder = LLVMPassManagerBuilderCreate();
    LLVMPassManagerBuilderSetOptLevel(passManagerBuilder, optimizationLevel);
    LLVMPassManagerBuilderPopulateModulePassManager(passManagerBuilder, passManager);
    LLVMRunPassManager(passManager, module);

    if (LLVMWriteBitcodeToFile(module, "test.bc") != 0)
    {
        fprintf(stderr, "error writing bitcode to file\n");
        goto cleanup;
    }

    targetMachine = LLVMCreateTargetMachine(
        target,
        LLVM_DEFAULT_TARGET_TRIPLE,
        cpu,
        features,
        LLVMCodeGenLevelDefault,
        LLVMRelocDefault,
        LLVMCodeModelDefault
    );
    if (LLVMTargetMachineEmitToFile(targetMachine, module, "test.o", LLVMObjectFile, &error) != 0)
    {
        fprintf(stderr, "Failed to emit machine code!\n");
        fprintf(stderr, "%s\n", error);
        goto cleanup;
    }

    status = EXIT_SUCCESS;

cleanup:
    /* LLVMDisposeMessage frees its argument; error is NULL when unset. */
    LLVMDisposeMessage(error);
    if (targetMachine != NULL)
    {
        LLVMDisposeTargetMachine(targetMachine);
    }
    if (passManagerBuilder != NULL)
    {
        LLVMPassManagerBuilderDispose(passManagerBuilder);
    }
    if (passManager != NULL)
    {
        LLVMDisposePassManager(passManager);
    }
    LLVMDisposeModule(module);
    return status;
}