530 lines
15 KiB
C
530 lines
15 KiB
C
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
|
|
#include <llvm-c/Core.h>
|
|
#include <llvm-c/Analysis.h>
|
|
#include <llvm-c/BitWriter.h>
|
|
|
|
#include "y.tab.h"
|
|
#include "ast.h"
|
|
#include "stack.h"
|
|
|
|
extern FILE *yyin;
|
|
Stack *stack;
|
|
Node *rootNode;
|
|
|
|
typedef struct StructFieldMapValue
|
|
{
|
|
char *name;
|
|
LLVMValueRef value;
|
|
LLVMValueRef valuePointer;
|
|
uint32_t index;
|
|
uint8_t needsWrite;
|
|
uint8_t needsRead;
|
|
} StructFieldMapValue;
|
|
|
|
typedef struct StructFieldMap
|
|
{
|
|
LLVMValueRef structPointer;
|
|
StructFieldMapValue *fields;
|
|
uint32_t fieldCount;
|
|
} StructFieldMap;
|
|
|
|
/* Growable table of tracked struct pointers; AddStruct appends to it. */
StructFieldMap *structFieldMaps;

/* Number of entries currently held in structFieldMaps. */
uint32_t structFieldMapCount;
|
|
|
|
/*
 * Starts tracking a struct pointer.
 *
 * Appends an entry with no fields for wStructPointer to structFieldMaps.
 * The table is grown by one element; if that allocation fails the process
 * is terminated, because the original table pointer must not be overwritten
 * with NULL and the caller has no way to observe a failure.
 */
static void AddStruct(LLVMValueRef wStructPointer)
{
    StructFieldMap *entry;
    StructFieldMap *grown = realloc(structFieldMaps, sizeof(StructFieldMap) * (structFieldMapCount + 1));

    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }

    structFieldMaps = grown;

    entry = &structFieldMaps[structFieldMapCount];
    entry->structPointer = wStructPointer;
    entry->fields = NULL;
    entry->fieldCount = 0;

    structFieldMapCount += 1;
}
|
|
|
|
/*
 * Registers a named field on an already tracked struct pointer.
 *
 * builder        - not used by this function; kept for the existing interface.
 * wStructPointer - struct pointer previously passed to AddStruct.
 * name           - field name; a private copy is stored.
 * index          - field index later used for the struct GEP.
 *
 * The new field starts unloaded (needsRead = 1) and clean (needsWrite = 0).
 * Only the first table entry matching wStructPointer is extended; if no
 * entry matches, nothing happens. Terminates the process when memory for
 * the field array or the name copy cannot be obtained.
 */
static void AddStructFieldName(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{
    uint32_t i;

    (void)builder;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];
        StructFieldMapValue *grown;
        StructFieldMapValue *field;
        char *nameCopy;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        /* Grow through a temporary so the old array survives a failed realloc. */
        grown = realloc(map->fields, sizeof(StructFieldMapValue) * (map->fieldCount + 1));
        if (grown == NULL)
        {
            fprintf(stderr, "Out of memory!\n");
            exit(EXIT_FAILURE);
        }
        map->fields = grown;

        nameCopy = strdup(name);
        if (nameCopy == NULL)
        {
            fprintf(stderr, "Out of memory!\n");
            exit(EXIT_FAILURE);
        }

        field = &map->fields[map->fieldCount];
        field->name = nameCopy;
        field->value = NULL;
        field->valuePointer = NULL;
        field->index = index;
        field->needsWrite = 0;
        field->needsRead = 1;

        map->fieldCount += 1;

        break;
    }
}
|
|
|
|
/*
 * Looks up a field of a tracked struct pointer by name and returns its value.
 *
 * On the first access to a field (needsRead set) a struct GEP named
 * "<name>_ptr" and a load named "<name>" are emitted through `builder`;
 * the resulting value and address are cached on the field and needsRead
 * is cleared. Later accesses return the cached value without emitting code.
 *
 * Returns NULL when no tracked entry for wStructPointer has a field called
 * `name`. Terminates the process if the temporary GEP name cannot be
 * allocated.
 */
static LLVMValueRef CheckStructFieldAndLoad(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name)
{
    static const char pointerSuffix[] = "_ptr";
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        if (structFieldMaps[i].structPointer != wStructPointer)
        {
            continue;
        }

        for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
        {
            StructFieldMapValue *field = &structFieldMaps[i].fields[j];

            if (strcmp(field->name, name) != 0)
            {
                continue;
            }

            if (field->needsRead)
            {
                LLVMValueRef elementPointer;
                size_t nameLength = strlen(name);

                /*
                 * Room for the name, the suffix and the terminator
                 * (sizeof includes the suffix's NUL). A plain strdup(name)
                 * is too small to append to.
                 */
                char *ptrName = malloc(nameLength + sizeof(pointerSuffix));

                if (ptrName == NULL)
                {
                    fprintf(stderr, "Out of memory!\n");
                    exit(EXIT_FAILURE);
                }

                memcpy(ptrName, name, nameLength);
                memcpy(ptrName + nameLength, pointerSuffix, sizeof(pointerSuffix));

                elementPointer = LLVMBuildStructGEP(
                    builder,
                    wStructPointer,
                    field->index,
                    ptrName
                );
                free(ptrName);

                field->value = LLVMBuildLoad(builder, elementPointer, name);
                field->valuePointer = elementPointer;
                field->needsRead = 0;
            }

            return field->value;
        }
    }

    return NULL;
}
|
|
|
|
/*
 * Flags the field whose cached value is `value` as needing a write-back.
 *
 * Every table entry keyed by wStructPointer is visited; within each one the
 * first field whose cached value is identical to `value` gets needsWrite = 1.
 * Nothing happens when no field matches.
 */
static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex;

    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        StructFieldMap *map = &structFieldMaps[mapIndex];
        uint32_t fieldIndex = 0;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        /* Advance to the first field caching `value`, if there is one. */
        while (fieldIndex < map->fieldCount && map->fields[fieldIndex].value != value)
        {
            fieldIndex += 1;
        }

        if (fieldIndex < map->fieldCount)
        {
            map->fields[fieldIndex].needsWrite = 1;
        }
    }
}
|
|
|
|
/*
 * Returns the cached address of the field whose cached value is `value`.
 *
 * Entries keyed by wStructPointer are searched in table order and the first
 * field whose value is identical to `value` wins. Returns NULL when there
 * is no such field.
 */
static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex, fieldIndex;

    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        StructFieldMap *map = &structFieldMaps[mapIndex];

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        for (fieldIndex = 0; fieldIndex < map->fieldCount; fieldIndex += 1)
        {
            StructFieldMapValue *field = &map->fields[fieldIndex];

            if (field->value == value)
            {
                return field->valuePointer;
            }
        }
    }

    return NULL;
}
|
|
|
|
/*
 * Flushes and releases the fields tracked for a struct pointer.
 *
 * For the first table entry keyed by wStructPointer, every field flagged
 * needsWrite gets a store of its cached value to its cached address emitted
 * through `builder`. The field names (owned copies made when the fields
 * were registered) and the field array are then freed and the entry is
 * reset to zero fields. The entry itself stays in the table.
 */
static void RemoveStruct(LLVMBuilderRef builder, LLVMValueRef wStructPointer)
{
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        for (j = 0; j < map->fieldCount; j += 1)
        {
            StructFieldMapValue *field = &map->fields[j];

            if (field->needsWrite)
            {
                LLVMBuildStore(
                    builder,
                    field->value,
                    field->valuePointer
                );
            }

            /* The name was duplicated on registration and is owned here. */
            free(field->name);
            field->name = NULL;
        }

        free(map->fields);
        map->fields = NULL;
        map->fieldCount = 0;

        break;
    }
}
|
|
|
|
typedef struct IdentifierMapValue
|
|
{
|
|
char *name;
|
|
LLVMValueRef value;
|
|
} IdentifierMapValue;
|
|
|
|
/* Growable table of named variables; AddNamedVariable appends to it. */
IdentifierMapValue *namedVariables;

/* Number of entries currently held in namedVariables. */
uint32_t namedVariableCount;
|
|
|
|
static LLVMValueRef CompileExpression(
|
|
LLVMValueRef wStructValue,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *binaryExpression
|
|
);
|
|
|
|
/*
 * Appends a name-to-value binding to the named variable table.
 *
 * name     - stored as given (not copied); it must stay valid for as long
 *            as the table is consulted.
 * variable - LLVM value the name resolves to.
 *
 * Terminates the process if the table cannot be grown; growing through a
 * temporary keeps the existing table intact up to that point.
 */
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
    IdentifierMapValue *grown = realloc(namedVariables, sizeof(IdentifierMapValue) * (namedVariableCount + 1));

    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory!\n");
        exit(EXIT_FAILURE);
    }

    namedVariables = grown;
    namedVariables[namedVariableCount].name = name;
    namedVariables[namedVariableCount].value = variable;

    namedVariableCount += 1;
}
|
|
|
|
/*
 * Resolves an identifier to an LLVM value.
 *
 * Named variables are consulted first, in insertion order, so the earliest
 * binding with a matching name wins. If none matches, the fields of the
 * struct pointer wStructValue are searched, which may emit a GEP and load
 * through `builder`.
 *
 * Returns NULL, after printing a diagnostic to stderr, when the name is
 * unknown in both places.
 */
static LLVMValueRef FindVariableByName(LLVMBuilderRef builder, LLVMValueRef wStructValue, char *name)
{
    uint32_t i;
    LLVMValueRef searchResult;

    /* first, search scoped vars */
    for (i = 0; i < namedVariableCount; i += 1)
    {
        if (strcmp(namedVariables[i].name, name) == 0)
        {
            return namedVariables[i].value;
        }
    }

    /* if none exist, search struct vars */
    searchResult = CheckStructFieldAndLoad(builder, wStructValue, name);

    if (searchResult == NULL)
    {
        fprintf(stderr, "Identifier not found!");
    }

    return searchResult;
}
|
|
|
|
/*
 * Maps a source-language primitive type to its LLVM type.
 *
 * Int and UInt both map to the 64-bit integer type, Bool to the 1-bit
 * integer type and Void to the void type. Any other value prints a
 * diagnostic to stderr and yields NULL.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    switch (type)
    {
        case Int:
        case UInt:
            return LLVMInt64Type();

        case Bool:
            return LLVMInt1Type();

        case Void:
            return LLVMVoidType();
    }

    fprintf(stderr, "Unrecognized type!");
    return NULL;
}
|
|
|
|
static LLVMValueRef CompileNumber(
|
|
Node *numberExpression
|
|
) {
|
|
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
|
|
}
|
|
|
|
static LLVMValueRef CompileBinaryExpression(
|
|
LLVMValueRef wStructValue,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *binaryExpression
|
|
) {
|
|
LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]);
|
|
LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]);
|
|
|
|
switch (binaryExpression->operator.binaryOperator)
|
|
{
|
|
case Add:
|
|
return LLVMBuildAdd(builder, left, right, "tmp");
|
|
|
|
case Subtract:
|
|
return LLVMBuildSub(builder, left, right, "tmp");
|
|
|
|
case Multiply:
|
|
return LLVMBuildMul(builder, left, right, "tmp");
|
|
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
static LLVMValueRef CompileFunctionCallExpression(
|
|
LLVMValueRef wStructValue,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *expression
|
|
) {
|
|
uint32_t i;
|
|
uint32_t argumentCount = expression->children[1]->childCount;
|
|
LLVMValueRef args[argumentCount];
|
|
|
|
for (i = 0; i < argumentCount; i += 1)
|
|
{
|
|
args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]);
|
|
}
|
|
|
|
return LLVMBuildCall(builder, FindVariableByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp");
|
|
}
|
|
|
|
static LLVMValueRef CompileExpression(
|
|
LLVMValueRef wStructValue,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *expression
|
|
) {
|
|
LLVMValueRef var;
|
|
|
|
switch (expression->syntaxKind)
|
|
{
|
|
case BinaryExpression:
|
|
return CompileBinaryExpression(wStructValue, builder, function, expression);
|
|
|
|
case FunctionCallExpression:
|
|
return CompileFunctionCallExpression(wStructValue, builder, function, expression);
|
|
|
|
case Identifier:
|
|
return FindVariableByName(builder, wStructValue, expression->value.string);
|
|
|
|
case Number:
|
|
return CompileNumber(expression);
|
|
}
|
|
|
|
fprintf(stderr, "Unknown expression kind!\n");
|
|
return NULL;
|
|
}
|
|
|
|
/*
 * Emits a value-returning `ret`: compiles children[0] of the return
 * statement and returns the resulting value from the current function.
 */
static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
    Node *valueNode = returnStatemement->children[0];
    LLVMValueRef returnValue = CompileExpression(wStructValue, builder, function, valueNode);

    LLVMBuildRet(builder, returnValue);
}
|
|
|
|
/* Emits a `ret void` at the builder's current position. */
static void CompileReturnVoid(LLVMBuilderRef builder)
{
    (void)LLVMBuildRetVoid(builder);
}
|
|
|
|
/*
 * Compiles an assignment statement.
 *
 * The right-hand side (children[1]) is compiled first, then the target
 * (children[0]). The struct field whose cached value is the target's value
 * is flagged for write-back.
 *
 * NOTE(review): the right-hand side value is not stored or cached here.
 * Only the needsWrite flag is set, and the write-back in RemoveStruct stores
 * the field's cached value -- the one that was loaded, not `result`. Confirm
 * how the assigned value is meant to reach the field.
 */
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
    LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]);

    /* Compiled for its emitted instructions only; see the note above. */
    (void)result;

    MarkStructFieldForWrite(wStructValue, identifier);
}
|
|
|
|
/*
 * Compiles one statement and reports whether it was a return.
 *
 * Returns 1 for Return and ReturnVoid statements, 0 for assignments.
 * Unknown statement kinds print a diagnostic to stderr and also return 0.
 */
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    if (statement->syntaxKind == Assignment)
    {
        CompileAssignment(wStructValue, builder, function, statement);
        return 0;
    }

    if (statement->syntaxKind == Return)
    {
        CompileReturn(wStructValue, builder, function, statement);
        return 1;
    }

    if (statement->syntaxKind == ReturnVoid)
    {
        CompileReturnVoid(builder);
        return 1;
    }

    fprintf(stderr, "Unknown statement kind!\n");
    return 0;
}
|
|
|
|
static void CompileFunction(
|
|
LLVMModuleRef module,
|
|
LLVMTypeRef wStructPointerType,
|
|
Node **fieldDeclarations,
|
|
uint32_t fieldDeclarationCount,
|
|
Node *functionDeclaration
|
|
) {
|
|
uint32_t i;
|
|
uint8_t hasReturn = 0;
|
|
Node *functionSignature = functionDeclaration->children[0];
|
|
Node *functionBody = functionDeclaration->children[1];
|
|
uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */
|
|
LLVMTypeRef paramTypes[argumentCount];
|
|
|
|
paramTypes[0] = wStructPointerType;
|
|
|
|
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
|
|
{
|
|
paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type);
|
|
}
|
|
|
|
LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
|
|
LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
|
|
LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);
|
|
|
|
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
|
|
|
|
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
|
|
{
|
|
LLVMValueRef argument = LLVMGetParam(function, i + 1);
|
|
AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument);
|
|
}
|
|
|
|
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
|
|
|
|
LLVMBuilderRef builder = LLVMCreateBuilder();
|
|
LLVMPositionBuilderAtEnd(builder, entry);
|
|
|
|
/* FIXME: replace this with a scope abstraction */
|
|
AddStruct(wStructPointer);
|
|
|
|
for (i = 0; i < fieldDeclarationCount; i += 1)
|
|
{
|
|
AddStructFieldName(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
|
|
}
|
|
|
|
for (i = 0; i < functionBody->childCount; i += 1)
|
|
{
|
|
hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
|
|
}
|
|
|
|
RemoveStruct(builder, wStructPointer);
|
|
|
|
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
|
|
{
|
|
LLVMBuildRetVoid(builder);
|
|
}
|
|
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
|
|
{
|
|
fprintf(stderr, "Return statement not provided!");
|
|
}
|
|
}
|
|
|
|
/*
 * Compiles a struct declaration: its LLVM type first, then its functions.
 *
 * node->children[0] carries the struct name and node->children[1] the list
 * of member declarations. Members of kind Declaration become the fields of
 * a packed named struct type, in declaration order. Members of kind
 * FunctionDeclaration are then compiled with a pointer to that struct type
 * as their implicit first parameter. Other member kinds are ignored.
 */
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;
    uint32_t fieldCount = 0;
    uint32_t declarationCount = node->children[1]->childCount;
    /* Variable-length arrays need a nonzero size, even for an empty struct body. */
    uint32_t capacity = declarationCount > 0 ? declarationCount : 1;
    uint8_t packed = 1;
    LLVMTypeRef types[capacity];
    Node *currentDeclarationNode;
    Node *fieldDeclarations[capacity];

    LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string);
    LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */

    /* first, build the structure definition */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        if (currentDeclarationNode->syntaxKind == Declaration) /* this is badly named */
        {
            types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type);
            fieldDeclarations[fieldCount] = currentDeclarationNode;
            fieldCount += 1;
        }
    }

    LLVMStructSetBody(wStruct, types, fieldCount, packed);

    /* now we can wire up the functions */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        if (currentDeclarationNode->syntaxKind == FunctionDeclaration)
        {
            CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
        }
    }
}
|
|
|
|
/*
 * Walks the syntax tree depth-first and compiles every struct declaration.
 *
 * A StructDeclaration node is compiled before its children are visited;
 * the children of every node, struct declarations included, are then
 * walked recursively in order.
 */
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t childIndex;

    if (node->syntaxKind == StructDeclaration)
    {
        CompileStruct(module, context, node);
    }

    for (childIndex = 0; childIndex < node->childCount; childIndex += 1)
    {
        Compile(module, context, node->children[childIndex]);
    }
}
|
|
|
|
int main(int argc, char *argv[])
|
|
{
|
|
if (argc < 2)
|
|
{
|
|
printf("Please provide a file.\n");
|
|
return 1;
|
|
}
|
|
|
|
namedVariables = NULL;
|
|
namedVariableCount = 0;
|
|
|
|
structFieldMaps = NULL;
|
|
structFieldMapCount = 0;
|
|
|
|
stack = CreateStack();
|
|
|
|
FILE *fp = fopen(argv[1], "r");
|
|
yyin = fp;
|
|
yyparse(fp, stack);
|
|
fclose(fp);
|
|
|
|
PrintTree(rootNode, 0);
|
|
|
|
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
|
|
LLVMContextRef context = LLVMGetGlobalContext();
|
|
|
|
Compile(module, context, rootNode);
|
|
|
|
char *error = NULL;
|
|
LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
|
|
LLVMDisposeMessage(error);
|
|
|
|
if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) {
|
|
fprintf(stderr, "error writing bitcode to file\n");
|
|
}
|
|
|
|
return 0;
|
|
}
|