wraith-lang/compiler.c

223 lines
5.4 KiB
C

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <llvm-c/Core.h>
#include <llvm-c/Analysis.h>
#include <llvm-c/BitWriter.h>
#include "y.tab.h"
#include "ast.h"
#include "stack.h"
/* Parser input stream, defined outside this file; main() points it at the
 * opened source file before calling yyparse(). */
extern FILE *yyin;
/* Stack created by CreateStack() in main() and handed to yyparse(). */
Stack *stack;
/* Root of the syntax tree that main() prints and compiles after parsing.
 * It is never assigned in this file -- presumably the parser sets it;
 * TODO confirm against the grammar actions. */
Node *rootNode;
/* One entry of the name -> LLVM value table used to resolve identifiers. */
typedef struct VariableMapValue
{
/* Identifier text. AddNamedVariable stores the pointer without copying. */
char *name;
/* LLVM value the identifier resolves to. */
LLVMValueRef variable;
} VariableMapValue;
/* Table of known variables: appended to by AddNamedVariable and searched
 * linearly by FindVariableByName. main() initialises it to NULL. */
VariableMapValue *namedVariables;
/* Number of entries currently stored in namedVariables. */
uint32_t namedVariableCount;
/* Forward declaration: CompileBinaryExpression calls CompileExpression for
 * its operands before CompileExpression is defined further down the file.
 * The last parameter is any expression node, despite its name here. */
static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
);
/*
 * Appends a (name, value) binding to the global namedVariables table.
 *
 * name     - identifier text; the pointer is stored as-is (not copied), so
 *            it must stay valid for as long as the entry is looked up.
 * variable - LLVM value the identifier refers to.
 *
 * Exits the process with status 1 if the table cannot be grown.
 */
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
    /* Size the request in elements rather than bytes, and keep the old
     * pointer until realloc is known to have succeeded. */
    VariableMapValue *grown = realloc(
        namedVariables,
        ((size_t)namedVariableCount + 1) * sizeof *namedVariables
    );
    if (grown == NULL)
    {
        fprintf(stderr, "Error: out of memory while recording variable\n");
        exit(1);
    }

    namedVariables = grown;
    namedVariables[namedVariableCount].name = name;
    namedVariables[namedVariableCount].variable = variable;
    namedVariableCount += 1;
}
static LLVMValueRef FindVariableByName(char *name)
{
uint32_t i;
for (i = 0; i < namedVariableCount; i += 1)
{
if (strcmp(namedVariables[i].name, name) == 0)
{
return namedVariables[i].variable;
}
}
return NULL;
}
/*
 * Maps a Wraith primitive type to the LLVM type used to represent it.
 *
 * Both Int and UInt map to LLVM's 64-bit integer type. Any other value
 * yields NULL.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    if (type == Int || type == UInt)
    {
        return LLVMInt64Type();
    }

    return NULL;
}
/*
 * Emits a 64-bit integer constant for a Number node.
 *
 * module, builder and function are accepted for signature parity with the
 * other Compile* helpers but are not used here.
 * numberExpression - node whose value.number holds the literal.
 *
 * Returns the LLVM constant; the literal is passed with LLVMConstInt's
 * sign-extend flag set to 0.
 */
static LLVMValueRef CompileNumber(
    LLVMModuleRef module,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *numberExpression
)
{
    LLVMTypeRef int64Type = LLVMInt64Type();

    return LLVMConstInt(int64Type, numberExpression->value.number, 0);
}
/*
 * Emits code for a binary expression node.
 *
 * The left operand (children[0]) is compiled first, then the right operand
 * (children[1]); the two results are combined according to
 * operator.binaryOperator. Only Add is currently handled.
 *
 * Returns the resulting LLVM value, or NULL when either operand failed to
 * compile or the operator is not supported.
 */
static LLVMValueRef CompileBinaryExpression(
    LLVMModuleRef module,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *binaryExpression
) {
    LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]);
    LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]);

    /* CompileExpression returns NULL for unknown identifiers and unsupported
     * nodes; LLVM builder calls must never be handed a NULL operand. */
    if (left == NULL || right == NULL)
    {
        return NULL;
    }

    switch (binaryExpression->operator.binaryOperator)
    {
        case Add:
            return LLVMBuildAdd(builder, left, right, "tmp");
        default:
            break;
    }

    return NULL;
}
/*
 * Dispatches on the node's syntaxKind and emits code for an expression.
 *
 *   BinaryExpression -> CompileBinaryExpression
 *   Identifier       -> value previously registered under that name
 *                       (NULL if the name is unknown)
 *   Number           -> CompileNumber
 *
 * For any other kind, an error line is printed to stdout and NULL is
 * returned.
 */
static LLVMValueRef CompileExpression(
    LLVMModuleRef module,
    LLVMBuilderRef builder,
    LLVMValueRef function,
    Node *expression
) {
    LLVMValueRef result = NULL;

    switch (expression->syntaxKind)
    {
        case Number:
            result = CompileNumber(module, builder, function, expression);
            break;

        case Identifier:
            result = FindVariableByName(expression->value.string);
            break;

        case BinaryExpression:
            result = CompileBinaryExpression(module, builder, function, expression);
            break;

        default:
            printf("Error: expected expression\n");
            break;
    }

    return result;
}
/*
 * Emits a `ret` instruction for a Return statement node.
 * The returned value is the compiled form of the statement's first child.
 */
static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
    Node *valueNode = returnStatemement->children[0];
    LLVMValueRef returnValue = CompileExpression(module, builder, function, valueNode);

    LLVMBuildRet(builder, returnValue);
}
/*
 * Emits code for a single statement node.
 * Only Return statements are handled; every other kind is ignored.
 */
static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    if (statement->syntaxKind == Return)
    {
        CompileReturn(module, builder, function, statement);
    }
}
/*
 * Emits an LLVM function for a FunctionDeclaration node.
 *
 * Node layout read by this function:
 *   functionDeclaration->children[0]  signature
 *       ->children[0]  name node        (value.string)
 *       ->children[1]  return type node (type)
 *       ->children[2]  parameter list; each parameter has
 *                        children[0] = type node, children[1] = name node
 *   functionDeclaration->children[1]  body; each child is a statement
 *
 * Every parameter is registered in the global variable table so identifier
 * expressions in the body can resolve to it.
 */
static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
{
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    Node *parameterList = functionSignature->children[2];
    uint32_t parameterCount = parameterList->childCount;
    uint32_t i;

    /* Start from an empty scope: FindVariableByName returns the first match,
     * so entries left over from a previously compiled function would shadow
     * same-named parameters of this one. */
    namedVariableCount = 0;

    /* A zero-length VLA is undefined behaviour, so always reserve at least
     * one slot even for parameterless functions. */
    LLVMTypeRef paramTypes[parameterCount > 0 ? parameterCount : 1];
    for (i = 0; i < parameterCount; i += 1)
    {
        paramTypes[i] = WraithTypeToLLVMType(parameterList->children[i]->children[0]->type);
    }

    LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
    LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, parameterCount, 0);
    LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);

    for (i = 0; i < parameterCount; i += 1)
    {
        LLVMValueRef argument = LLVMGetParam(function, i);
        AddNamedVariable(parameterList->children[i]->children[1]->value.string, argument);
    }

    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    for (i = 0; i < functionBody->childCount; i += 1)
    {
        CompileStatement(module, builder, function, functionBody->children[i]);
    }

    /* The builder is only needed while emitting this function's body. */
    LLVMDisposeBuilder(builder);
}
/*
 * Walks the syntax tree depth-first and compiles every function it finds.
 *
 * A FunctionDeclaration node is handed to CompileFunction; afterwards the
 * walk continues into all children of the node, whatever its kind.
 */
static void Compile(LLVMModuleRef module, Node *node)
{
    uint32_t childIndex;

    if (node->syntaxKind == FunctionDeclaration)
    {
        CompileFunction(module, node);
    }

    for (childIndex = 0; childIndex < node->childCount; childIndex += 1)
    {
        Compile(module, node->children[childIndex]);
    }
}
/*
 * Compiler entry point.
 *
 * Usage: <program> <source-file>
 *
 * Parses the given file, prints the resulting syntax tree, generates LLVM IR
 * for it, verifies the module (LLVM aborts the process on a verification
 * failure) and writes the bitcode to "test.bc".
 *
 * Returns 0 on success, or 1 when no file was given, the file cannot be
 * opened, parsing fails, or the bitcode cannot be written.
 */
int main(int argc, char *argv[])
{
    if (argc < 2)
    {
        printf("Please provide a file.\n");
        return 1;
    }

    namedVariables = NULL;
    namedVariableCount = 0;
    stack = CreateStack();

    FILE *fp = fopen(argv[1], "r");
    if (fp == NULL)
    {
        /* Without this check a missing file reaches yyparse/fclose as NULL. */
        perror(argv[1]);
        return 1;
    }

    yyin = fp;
    int parseResult = yyparse(fp, stack);
    fclose(fp);

    /* Do not walk a tree that was never (or only partially) built. */
    if (parseResult != 0 || rootNode == NULL)
    {
        fprintf(stderr, "error: failed to parse %s\n", argv[1]);
        return 1;
    }

    PrintTree(rootNode, 0);

    LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
    Compile(module, rootNode);

    char *error = NULL;
    LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
    LLVMDisposeMessage(error);

    int exitCode = 0;
    if (LLVMWriteBitcodeToFile(module, "test.bc") != 0)
    {
        fprintf(stderr, "error writing bitcode to file\n");
        exitCode = 1;
    }

    LLVMDisposeModule(module);
    free(namedVariables);
    namedVariables = NULL;
    namedVariableCount = 0;

    return exitCode;
}