wraith-lang/compiler.c

530 lines
15 KiB
C
Raw Normal View History

2021-04-18 22:29:54 +00:00
#include <stdio.h>
2021-04-20 01:18:45 +00:00
#include <stdlib.h>
#include <string.h>
2021-04-18 22:29:54 +00:00
#include <llvm-c/Core.h>
2021-04-20 01:18:45 +00:00
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>
2021-04-18 22:45:06 +00:00
2021-04-18 22:29:54 +00:00
#include "y.tab.h"
2021-04-18 22:45:06 +00:00
#include "ast.h"
2021-04-18 22:29:54 +00:00
#include "stack.h"
/* Parser input stream; only declared here (extern). main() points it at the opened source file. */
extern FILE *yyin;
/* Stack created by CreateStack() in main() and handed to yyparse(). */
Stack *stack;
2021-04-18 22:45:06 +00:00
/* Root of the syntax tree that main() prints and compiles. It is not assigned in this file. */
Node *rootNode;
2021-04-21 02:00:18 +00:00
/* Bookkeeping for a single struct field while a function body is being compiled. */
typedef struct StructFieldMapValue
{
/* Heap copy of the field name (made with strdup in AddStructFieldName). */
char *name;
/* Value produced by loading the field; NULL until CheckStructFieldAndLoad loads it. */
LLVMValueRef value;
/* Pointer to the field built with LLVMBuildStructGEP; NULL until the first load. */
LLVMValueRef valuePointer;
2021-04-21 02:40:39 +00:00
/* Field index passed to LLVMBuildStructGEP. */
uint32_t index;
2021-04-21 02:00:18 +00:00
/* Non-zero when RemoveStruct must store `value` back through `valuePointer`. */
uint8_t needsWrite;
2021-04-21 02:40:39 +00:00
/* Non-zero while the field has not been loaded yet; cleared after the load. */
uint8_t needsRead;
2021-04-21 02:00:18 +00:00
} StructFieldMapValue;
/* Associates one struct pointer value with the bookkeeping of its fields. */
typedef struct StructFieldMap
{
/* Key: the struct pointer value the lookups compare against. */
LLVMValueRef structPointer;
/* Array of `fieldCount` entries, grown with realloc by AddStructFieldName. */
StructFieldMapValue *fields;
uint32_t fieldCount;
} StructFieldMap;
/* Table of registered structs, grown by AddStruct; `structFieldMapCount` is its length. */
StructFieldMap *structFieldMaps;
uint32_t structFieldMapCount;
/*
 * Registers a struct pointer in the global structFieldMaps table with an
 * empty field list.
 *
 * wStructPointer: LLVM value used as the lookup key for this struct.
 *
 * Terminates the process if the table cannot be grown.
 */
static void AddStruct(LLVMValueRef wStructPointer)
{
    /* Grow through a temporary so a failed realloc cannot clobber the table. */
    StructFieldMap *grown = realloc(structFieldMaps, sizeof *grown * (structFieldMapCount + 1));
    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory while registering struct!\n");
        exit(EXIT_FAILURE);
    }
    structFieldMaps = grown;

    structFieldMaps[structFieldMapCount].structPointer = wStructPointer;
    structFieldMaps[structFieldMapCount].fields = NULL;
    structFieldMaps[structFieldMapCount].fieldCount = 0;
    structFieldMapCount += 1;
}
2021-04-21 02:40:39 +00:00
/*
 * Appends a named field to the entry registered for wStructPointer.
 *
 * builder:        unused; kept so existing callers keep compiling.
 * wStructPointer: key of the struct entry to extend.
 * name:           field name; copied, so the caller keeps ownership.
 * index:          field index later passed to LLVMBuildStructGEP.
 *
 * The new field starts unloaded (needsRead = 1) and clean (needsWrite = 0).
 * Does nothing if wStructPointer was never registered. Terminates the
 * process on allocation failure.
 */
static void AddStructFieldName(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{
    uint32_t i;

    (void) builder;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];
        StructFieldMapValue *grown;
        StructFieldMapValue *field;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        /* Grow through a temporary so a failed realloc cannot lose the array. */
        grown = realloc(map->fields, sizeof *grown * (map->fieldCount + 1));
        if (grown == NULL)
        {
            fprintf(stderr, "Out of memory while registering struct field!\n");
            exit(EXIT_FAILURE);
        }
        map->fields = grown;

        field = &map->fields[map->fieldCount];
        field->name = strdup(name);
        if (field->name == NULL)
        {
            fprintf(stderr, "Out of memory while registering struct field!\n");
            exit(EXIT_FAILURE);
        }
        field->value = NULL;
        field->valuePointer = NULL;
        field->index = index;
        field->needsWrite = 0;
        field->needsRead = 1;

        map->fieldCount += 1;
        break;
    }
}
2021-04-21 02:40:39 +00:00
/*
 * Looks up a field by name on the struct registered as wStructPointer and
 * returns its value, emitting the GEP and load the first time it is needed.
 *
 * builder:        builder used to emit the GEP and load instructions.
 * wStructPointer: key of the struct entry to search.
 * name:           field name to look for.
 *
 * Returns the cached field value, or NULL if no such field is registered.
 * Terminates the process on allocation failure.
 */
static LLVMValueRef CheckStructFieldAndLoad(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name)
{
    static const char pointerSuffix[] = "_ptr";
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        if (structFieldMaps[i].structPointer != wStructPointer)
        {
            continue;
        }

        for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
        {
            StructFieldMapValue *field = &structFieldMaps[i].fields[j];

            if (strcmp(field->name, name) != 0)
            {
                continue;
            }

            if (field->needsRead)
            {
                /*
                 * The buffer must hold the name, the suffix and the
                 * terminator; sizeof(pointerSuffix) already counts the
                 * terminator. Appending to a strdup'd copy would overflow it.
                 */
                size_t ptrNameSize = strlen(name) + sizeof(pointerSuffix);
                char *ptrName = malloc(ptrNameSize);
                LLVMValueRef elementPointer;

                if (ptrName == NULL)
                {
                    fprintf(stderr, "Out of memory while loading struct field!\n");
                    exit(EXIT_FAILURE);
                }
                snprintf(ptrName, ptrNameSize, "%s%s", name, pointerSuffix);

                elementPointer = LLVMBuildStructGEP(
                    builder,
                    wStructPointer,
                    field->index,
                    ptrName
                );
                free(ptrName);

                field->value = LLVMBuildLoad(builder, elementPointer, name);
                field->valuePointer = elementPointer;
                field->needsRead = 0;
            }

            return field->value;
        }
    }

    return NULL;
}
2021-04-21 02:07:11 +00:00
/*
 * Flags the field whose cached value is `value` as needing a write-back.
 *
 * For every table entry keyed by wStructPointer, the first field whose
 * cached value equals `value` gets needsWrite set. Nothing happens when no
 * field matches.
 */
static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex;

    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        StructFieldMap *map = &structFieldMaps[mapIndex];
        uint32_t fieldIndex = 0;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        /* Advance to the first field caching `value`, if there is one. */
        while (fieldIndex < map->fieldCount && map->fields[fieldIndex].value != value)
        {
            fieldIndex += 1;
        }

        if (fieldIndex < map->fieldCount)
        {
            map->fields[fieldIndex].needsWrite = 1;
        }
    }
}
2021-04-21 02:00:18 +00:00
/*
 * Returns the field pointer recorded for the field whose cached value is
 * `value` on the struct registered as wStructPointer.
 *
 * Returns NULL when no entry or field matches. The stored pointer itself is
 * NULL for a field that has not been loaded yet.
 */
static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex;

    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        const StructFieldMap *map = &structFieldMaps[mapIndex];
        uint32_t fieldIndex;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        for (fieldIndex = 0; fieldIndex < map->fieldCount; fieldIndex += 1)
        {
            const StructFieldMapValue *field = &map->fields[fieldIndex];

            if (field->value == value)
            {
                return field->valuePointer;
            }
        }
    }

    return NULL;
}
2021-04-21 02:07:11 +00:00
/*
 * Flushes and clears the field bookkeeping of the struct registered as
 * wStructPointer.
 *
 * builder:        builder used to emit the write-back stores; they are
 *                 inserted at its current position.
 * wStructPointer: key of the struct entry to clear.
 *
 * Every field flagged needsWrite gets a store of its cached value through
 * its field pointer. The field names and the field array are then released.
 * The table entry itself stays in structFieldMaps with zero fields.
 */
static void RemoveStruct(LLVMBuilderRef builder, LLVMValueRef wStructPointer)
{
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        for (j = 0; j < map->fieldCount; j += 1)
        {
            StructFieldMapValue *field = &map->fields[j];

            if (field->needsWrite)
            {
                LLVMBuildStore(
                    builder,
                    field->value,
                    field->valuePointer
                );
            }

            /* The name was strdup'd in AddStructFieldName and is owned by the map. */
            free(field->name);
        }

        free(map->fields);
        map->fields = NULL;
        map->fieldCount = 0;
        break;
    }
}
2021-04-20 17:47:40 +00:00
/* One named value visible to expressions (filled from function parameters in CompileFunction). */
typedef struct IdentifierMapValue
2021-04-20 01:18:45 +00:00
{
/* Identifier text; AddNamedVariable stores the pointer without copying it. */
char *name;
2021-04-20 17:47:40 +00:00
/* LLVM value the identifier resolves to. */
LLVMValueRef value;
} IdentifierMapValue;
2021-04-20 01:18:45 +00:00
2021-04-20 17:47:40 +00:00
/* Array of named values, grown by AddNamedVariable and searched front to back by FindVariableByName. */
IdentifierMapValue *namedVariables;
2021-04-20 01:18:45 +00:00
/* Number of entries in namedVariables. */
uint32_t namedVariableCount;
static LLVMValueRef CompileExpression(
2021-04-21 02:00:18 +00:00
LLVMValueRef wStructValue,
2021-04-20 01:18:45 +00:00
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
);
/*
 * Appends a name/value pair to the global namedVariables array.
 *
 * name:     identifier text; the pointer is stored as-is, so it must
 *           outlive the entry.
 * variable: LLVM value the identifier resolves to.
 *
 * Terminates the process if the array cannot be grown.
 */
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
    /* Grow through a temporary so a failed realloc cannot lose the array. */
    IdentifierMapValue *grown = realloc(namedVariables, sizeof *grown * (namedVariableCount + 1));
    if (grown == NULL)
    {
        fprintf(stderr, "Out of memory while registering variable!\n");
        exit(EXIT_FAILURE);
    }
    namedVariables = grown;

    namedVariables[namedVariableCount].name = name;
    namedVariables[namedVariableCount].value = variable;
    namedVariableCount += 1;
}
2021-04-21 02:40:39 +00:00
/*
 * Resolves an identifier to an LLVM value.
 *
 * builder:      builder used if a struct field has to be loaded.
 * wStructValue: struct pointer whose fields are searched as a fallback.
 * name:         identifier to resolve.
 *
 * Named variables are searched first, in insertion order; if none matches,
 * the fields of wStructValue are searched. Returns NULL and reports the
 * identifier on stderr when nothing matches.
 */
static LLVMValueRef FindVariableByName(LLVMBuilderRef builder, LLVMValueRef wStructValue, char *name)
{
    uint32_t i;
    LLVMValueRef searchResult;

    /* first, search scoped vars */
    for (i = 0; i < namedVariableCount; i += 1)
    {
        if (strcmp(namedVariables[i].name, name) == 0)
        {
            return namedVariables[i].value;
        }
    }

    /* if none exist, search struct vars */
    searchResult = CheckStructFieldAndLoad(builder, wStructValue, name);
    if (searchResult == NULL)
    {
        /* Name the identifier and end the line so diagnostics do not run together. */
        fprintf(stderr, "Identifier not found: %s\n", name);
    }

    return searchResult;
}
/*
 * Maps a Wraith primitive type to its LLVM type.
 *
 * Int and UInt both map to the 64-bit integer type, Bool to the 1-bit
 * integer type and Void to the void type. Returns NULL and prints a
 * diagnostic for any other value.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    switch (type)
    {
        case Int:
            return LLVMInt64Type();

        case UInt:
            return LLVMInt64Type();

        case Bool:
            return LLVMInt1Type();

        case Void:
            return LLVMVoidType();
    }

    /* Newline-terminated like the other diagnostics in this file. */
    fprintf(stderr, "Unrecognized type!\n");
    return NULL;
}
2021-04-20 01:18:45 +00:00
/*
 * Builds a 64-bit integer constant from a number node.
 *
 * The node's value.number is passed to LLVMConstInt with the sign-extend
 * flag set to 0.
 */
static LLVMValueRef CompileNumber(
    Node *numberExpression
) {
    LLVMTypeRef integerType = LLVMInt64Type();

    return LLVMConstInt(integerType, numberExpression->value.number, 0);
}
static LLVMValueRef CompileBinaryExpression(
2021-04-21 02:00:18 +00:00
LLVMValueRef wStructValue,
2021-04-20 01:18:45 +00:00
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
) {
2021-04-21 02:00:18 +00:00
LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]);
2021-04-20 01:18:45 +00:00
switch (binaryExpression->operator.binaryOperator)
{
case Add:
return LLVMBuildAdd(builder, left, right, "tmp");
2021-04-20 17:47:40 +00:00
case Subtract:
return LLVMBuildSub(builder, left, right, "tmp");
case Multiply:
return LLVMBuildMul(builder, left, right, "tmp");
2021-04-20 01:18:45 +00:00
}
return NULL;
}
2021-04-20 17:47:40 +00:00
static LLVMValueRef CompileFunctionCallExpression(
2021-04-21 02:00:18 +00:00
LLVMValueRef wStructValue,
2021-04-20 17:47:40 +00:00
LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression
) {
uint32_t i;
uint32_t argumentCount = expression->children[1]->childCount;
LLVMValueRef args[argumentCount];
for (i = 0; i < argumentCount; i += 1)
{
2021-04-21 02:00:18 +00:00
args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]);
2021-04-20 17:47:40 +00:00
}
2021-04-21 02:40:39 +00:00
return LLVMBuildCall(builder, FindVariableByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp");
2021-04-20 17:47:40 +00:00
}
2021-04-20 01:18:45 +00:00
static LLVMValueRef CompileExpression(
2021-04-21 02:00:18 +00:00
LLVMValueRef wStructValue,
2021-04-20 01:18:45 +00:00
LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression
) {
2021-04-21 02:00:18 +00:00
LLVMValueRef var;
2021-04-20 01:18:45 +00:00
switch (expression->syntaxKind)
{
case BinaryExpression:
2021-04-21 02:00:18 +00:00
return CompileBinaryExpression(wStructValue, builder, function, expression);
2021-04-20 01:18:45 +00:00
2021-04-20 17:47:40 +00:00
case FunctionCallExpression:
2021-04-21 02:00:18 +00:00
return CompileFunctionCallExpression(wStructValue, builder, function, expression);
2021-04-20 17:47:40 +00:00
2021-04-20 01:18:45 +00:00
case Identifier:
2021-04-21 02:40:39 +00:00
return FindVariableByName(builder, wStructValue, expression->value.string);
2021-04-20 01:18:45 +00:00
case Number:
2021-04-21 02:00:18 +00:00
return CompileNumber(expression);
2021-04-20 01:18:45 +00:00
}
2021-04-21 02:00:18 +00:00
fprintf(stderr, "Unknown expression kind!\n");
2021-04-20 01:18:45 +00:00
return NULL;
}
2021-04-21 02:00:18 +00:00
/*
 * Emits a `ret` of the value produced by the statement's first child.
 *
 * The child expression is compiled first; its result is handed to
 * LLVMBuildRet unchanged.
 */
static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
    Node *valueExpression = returnStatemement->children[0];
    LLVMValueRef returnValue = CompileExpression(wStructValue, builder, function, valueExpression);

    LLVMBuildRet(builder, returnValue);
}
/* Emits a `ret void` at the builder's current position. */
static void CompileReturnVoid(LLVMBuilderRef builder)
{
    /* The created instruction is not needed by the caller. */
    (void) LLVMBuildRetVoid(builder);
}
/*
 * Forces the next read of the field whose cached value is `value` to reload
 * it from memory, and cancels any pending deferred write-back for it.
 */
static void InvalidateStructFieldCache(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        if (structFieldMaps[i].structPointer != wStructPointer)
        {
            continue;
        }

        for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
        {
            if (structFieldMaps[i].fields[j].value == value)
            {
                structFieldMaps[i].fields[j].needsRead = 1;
                structFieldMaps[i].fields[j].needsWrite = 0;
                break;
            }
        }
    }
}

/*
 * Compiles `target = expression` where the target is a struct field.
 *
 * wStructValue:        struct pointer whose field is assigned.
 * builder:             builder the store is emitted with.
 * function:            function being compiled (forwarded to expressions).
 * assignmentStatement: node whose children[0] is the target and children[1]
 *                      the value expression.
 *
 * The value expression is compiled first, then the target. The result is
 * stored through the target's field pointer immediately. Deferring the write
 * through MarkStructFieldForWrite would store the field's old loaded value
 * back and discard the new one. The cached load is invalidated afterwards so
 * later reads observe the stored value. Prints a diagnostic and emits
 * nothing if either side fails to compile or the target is not a field.
 */
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    LLVMValueRef fieldPointer;
    LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
    LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]);

    if (result == NULL || identifier == NULL)
    {
        fprintf(stderr, "Could not compile assignment!\n");
        return;
    }

    fieldPointer = GetStructFieldPointer(wStructValue, identifier);
    if (fieldPointer == NULL)
    {
        fprintf(stderr, "Assignment target is not a struct field!\n");
        return;
    }

    LLVMBuildStore(builder, result, fieldPointer);
    InvalidateStructFieldCache(wStructValue, identifier);
}
2021-04-21 02:00:18 +00:00
/*
 * Compiles one statement node.
 *
 * Returns 1 when the statement was a return (with or without a value) and
 * 0 otherwise. An unrecognised statement kind is reported on stderr and
 * also yields 0.
 */
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    uint8_t emittedReturn = 0;

    switch (statement->syntaxKind)
    {
        case Assignment:
            CompileAssignment(wStructValue, builder, function, statement);
            break;

        case Return:
            CompileReturn(wStructValue, builder, function, statement);
            emittedReturn = 1;
            break;

        case ReturnVoid:
            CompileReturnVoid(builder);
            emittedReturn = 1;
            break;

        default:
            fprintf(stderr, "Unknown statement kind!\n");
            break;
    }

    return emittedReturn;
}
2021-04-21 02:00:18 +00:00
/*
 * Compiles one function declared inside a struct.
 *
 * module:                module the function is added to.
 * wStructPointerType:    pointer type of the enclosing struct; it becomes
 *                        the implicit first parameter.
 * fieldDeclarations:     field declaration nodes of the enclosing struct.
 * fieldDeclarationCount: number of entries in fieldDeclarations.
 * functionDeclaration:   node whose children[0] is the signature (name,
 *                        return type, parameter list) and children[1] the
 *                        body.
 *
 * Declared parameters are registered as named variables for the duration of
 * the body and dropped again afterwards. Struct fields are registered for
 * lazy loading. A `ret void` is appended to void functions that have no
 * return statement; a non-void function without one is reported on stderr.
 */
static void CompileFunction(
    LLVMModuleRef module,
    LLVMTypeRef wStructPointerType,
    Node **fieldDeclarations,
    uint32_t fieldDeclarationCount,
    Node *functionDeclaration
) {
    uint32_t i;
    uint8_t hasReturn = 0;
    /* Remember the scope depth so this function's parameters can be dropped at the end. */
    uint32_t outerVariableCount = namedVariableCount;
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    Node *parameterList = functionSignature->children[2];
    uint32_t argumentCount = parameterList->childCount + 1; /* struct is implicit argument */
    LLVMTypeRef paramTypes[argumentCount];
    LLVMValueRef terminator;

    paramTypes[0] = wStructPointerType;
    for (i = 0; i < parameterList->childCount; i += 1)
    {
        paramTypes[i + 1] = WraithTypeToLLVMType(parameterList->children[i]->children[0]->type);
    }

    LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
    LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
    LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);
    LLVMValueRef wStructPointer = LLVMGetParam(function, 0);

    for (i = 0; i < parameterList->childCount; i += 1)
    {
        LLVMValueRef argument = LLVMGetParam(function, i + 1);
        AddNamedVariable(parameterList->children[i]->children[1]->value.string, argument);
    }

    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");

    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    /* FIXME: replace this with a scope abstraction */
    AddStruct(wStructPointer);
    for (i = 0; i < fieldDeclarationCount; i += 1)
    {
        AddStructFieldName(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
    }

    for (i = 0; i < functionBody->childCount; i += 1)
    {
        hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
    }

    /*
     * RemoveStruct emits the pending field write-backs. If the body already
     * ended the block with a return, they have to go in front of that
     * terminator; instructions after a terminator make the module invalid.
     */
    terminator = LLVMGetBasicBlockTerminator(entry);
    if (terminator != NULL)
    {
        LLVMPositionBuilderBefore(builder, terminator);
    }
    RemoveStruct(builder, wStructPointer);
    LLVMPositionBuilderAtEnd(builder, entry);

    if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
    {
        LLVMBuildRetVoid(builder);
    }
    else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
    {
        fprintf(stderr, "Return statement not provided!\n");
    }

    /*
     * Drop this function's parameters. FindVariableByName returns the first
     * match, so a same-named parameter of an earlier function would otherwise
     * shadow the current one and pull in a value from another function.
     */
    namedVariableCount = outerVariableCount;

    LLVMDisposeBuilder(builder);
}
/*
 * Compiles a struct declaration: defines the named LLVM struct type from its
 * field declarations, then compiles each function declared inside it.
 *
 * module:  module the functions are added to.
 * context: context the named struct type is created in.
 * node:    struct declaration; children[0] holds the name and children[1]
 *          the list of field and function declarations.
 *
 * The struct body is set as packed. Functions receive a pointer to the
 * struct (address space 0) as their implicit first parameter.
 */
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t i;
    uint32_t fieldCount = 0;
    uint32_t declarationCount = node->children[1]->childCount;
    /* A VLA must have at least one element, so reserve a slot even for an empty struct. */
    uint32_t capacity = declarationCount > 0 ? declarationCount : 1;
    uint8_t packed = 1;
    LLVMTypeRef types[capacity];
    Node *currentDeclarationNode;
    Node *fieldDeclarations[capacity];

    LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string);
    LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */

    /* first, build the structure definition */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        switch (currentDeclarationNode->syntaxKind)
        {
            case Declaration: /* this is badly named */
                types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type);
                fieldDeclarations[fieldCount] = currentDeclarationNode;
                fieldCount += 1;
                break;

            default:
                /* Non-field declarations are handled in the second pass. */
                break;
        }
    }

    LLVMStructSetBody(wStruct, types, fieldCount, packed);

    /* now we can wire up the functions */
    for (i = 0; i < declarationCount; i += 1)
    {
        currentDeclarationNode = node->children[1]->children[i];

        switch (currentDeclarationNode->syntaxKind)
        {
            case FunctionDeclaration:
                CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
                break;

            default:
                /* Fields were already consumed by the first pass. */
                break;
        }
    }
}
2021-04-21 02:00:18 +00:00
/*
 * Walks the syntax tree depth-first, compiling every struct declaration.
 *
 * The node itself is handled first (only StructDeclaration nodes produce
 * output), then each child is visited in order.
 */
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
    uint32_t childIndex = 0;

    if (node->syntaxKind == StructDeclaration)
    {
        CompileStruct(module, context, node);
    }

    while (childIndex < node->childCount)
    {
        Compile(module, context, node->children[childIndex]);
        childIndex += 1;
    }
}
2021-04-18 22:29:54 +00:00
int main(int argc, char *argv[])
{
if (argc < 2)
{
printf("Please provide a file.\n");
return 1;
}
2021-04-20 01:18:45 +00:00
namedVariables = NULL;
namedVariableCount = 0;
2021-04-21 02:00:18 +00:00
structFieldMaps = NULL;
structFieldMapCount = 0;
2021-04-18 22:29:54 +00:00
stack = CreateStack();
FILE *fp = fopen(argv[1], "r");
yyin = fp;
yyparse(fp, stack);
fclose(fp);
2021-04-18 22:45:06 +00:00
PrintTree(rootNode, 0);
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
2021-04-21 02:00:18 +00:00
LLVMContextRef context = LLVMGetGlobalContext();
2021-04-18 22:45:06 +00:00
2021-04-21 02:00:18 +00:00
Compile(module, context, rootNode);
2021-04-18 22:29:54 +00:00
2021-04-20 01:18:45 +00:00
char *error = NULL;
LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
LLVMDisposeMessage(error);
if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) {
fprintf(stderr, "error writing bitcode to file\n");
}
2021-04-18 22:29:54 +00:00
return 0;
}