forked from cosmonaut/wraith-lang
252 lines
6.4 KiB
C
252 lines
6.4 KiB
C
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <string.h>
|
|
|
|
#include <llvm-c/Core.h>
|
|
#include <llvm-c/Analysis.h>
|
|
#include <llvm-c/BitWriter.h>
|
|
|
|
#include "y.tab.h"
|
|
#include "ast.h"
|
|
#include "stack.h"
|
|
|
|
extern FILE *yyin;
|
|
Stack *stack;
|
|
Node *rootNode;
|
|
|
|
typedef struct IdentifierMapValue
|
|
{
|
|
char *name;
|
|
LLVMValueRef value;
|
|
} IdentifierMapValue;
|
|
|
|
/* Number of bindings currently stored in namedVariables. */
uint32_t namedVariableCount;

/* Symbol table: heap array holding namedVariableCount bindings. It is
 * grown by AddNamedVariable() and searched by FindVariableByName(); main()
 * resets it to empty before compiling. */
IdentifierMapValue *namedVariables;
|
|
|
|
static LLVMValueRef CompileExpression(
|
|
LLVMModuleRef module,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *binaryExpression
|
|
);
|
|
|
|
/*
 * Appends a (name, value) binding to the global symbol table.
 *
 * The name pointer is stored as-is (not copied), so it must stay valid for
 * as long as the binding can be looked up. The array grows by exactly one
 * element per call.
 *
 * If the table cannot be grown, the process exits with a diagnostic. The
 * realloc result goes through a temporary so the existing table is never
 * overwritten with NULL; compilation cannot continue without the binding.
 */
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
    IdentifierMapValue *grown = realloc(namedVariables, sizeof *grown * ((size_t)namedVariableCount + 1));

    if (grown == NULL)
    {
        fprintf(stderr, "Error: out of memory\n");
        exit(1);
    }

    namedVariables = grown;
    namedVariables[namedVariableCount].name = name;
    namedVariables[namedVariableCount].value = variable;

    namedVariableCount += 1;
}
|
|
|
|
static LLVMValueRef FindVariableByName(char *name)
|
|
{
|
|
uint32_t i;
|
|
|
|
for (i = 0; i < namedVariableCount; i += 1)
|
|
{
|
|
if (strcmp(namedVariables[i].name, name) == 0)
|
|
{
|
|
return namedVariables[i].value;
|
|
}
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
/*
 * Maps a Wraith primitive type to the LLVM type used to represent it.
 *
 * Int and UInt are both lowered to a 64-bit integer, because LLVM integer
 * types carry no signedness. Any other type yields NULL.
 */
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
    LLVMTypeRef llvmType = NULL;

    switch (type)
    {
        case Int:
        case UInt:
            llvmType = LLVMInt64Type();
            break;

        default:
            break;
    }

    return llvmType;
}
|
|
|
|
static LLVMValueRef CompileNumber(
|
|
LLVMModuleRef module,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *numberExpression
|
|
) {
|
|
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
|
|
}
|
|
|
|
static LLVMValueRef CompileBinaryExpression(
|
|
LLVMModuleRef module,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *binaryExpression
|
|
) {
|
|
LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]);
|
|
LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]);
|
|
|
|
switch (binaryExpression->operator.binaryOperator)
|
|
{
|
|
case Add:
|
|
return LLVMBuildAdd(builder, left, right, "tmp");
|
|
|
|
case Subtract:
|
|
return LLVMBuildSub(builder, left, right, "tmp");
|
|
|
|
case Multiply:
|
|
return LLVMBuildMul(builder, left, right, "tmp");
|
|
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
static LLVMValueRef CompileFunctionCallExpression(
|
|
LLVMModuleRef module,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *expression
|
|
) {
|
|
uint32_t i;
|
|
uint32_t argumentCount = expression->children[1]->childCount;
|
|
LLVMValueRef args[argumentCount];
|
|
|
|
for (i = 0; i < argumentCount; i += 1)
|
|
{
|
|
args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]);
|
|
}
|
|
|
|
return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp");
|
|
}
|
|
|
|
static LLVMValueRef CompileExpression(
|
|
LLVMModuleRef module,
|
|
LLVMBuilderRef builder,
|
|
LLVMValueRef function,
|
|
Node *expression
|
|
) {
|
|
switch (expression->syntaxKind)
|
|
{
|
|
case BinaryExpression:
|
|
return CompileBinaryExpression(module, builder, function, expression);
|
|
|
|
case FunctionCallExpression:
|
|
return CompileFunctionCallExpression(module, builder, function, expression);
|
|
|
|
case Identifier:
|
|
return FindVariableByName(expression->value.string);
|
|
|
|
case Number:
|
|
return CompileNumber(module, builder, function, expression);
|
|
}
|
|
|
|
printf("Error: expected expression\n");
|
|
return NULL;
|
|
}
|
|
|
|
/*
 * Emits a `ret` instruction returning the value of the statement's
 * expression (children[0]), compiled at the builder's current position.
 */
static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatement)
{
    Node *returnedExpression = returnStatement->children[0];
    LLVMValueRef returnValue = CompileExpression(module, builder, function, returnedExpression);

    LLVMBuildRet(builder, returnValue);
}
|
|
|
|
/*
 * Lowers one statement of a function body. Only return statements are
 * handled at the moment; statements of any other kind emit nothing.
 */
static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
    if (statement->syntaxKind != Return)
    {
        return;
    }

    CompileReturn(module, builder, function, statement);
}
|
|
|
|
/*
 * Emits an LLVM function definition for a FunctionDeclaration node.
 *
 * Node layout, as used by the indexing below:
 *   declaration: children[0] = signature, children[1] = body (statements)
 *   signature:   children[0] = name identifier, children[1] = return type,
 *                children[2] = parameter list
 *   parameter:   children[0] = type, children[1] = name identifier
 *
 * Scoping:
 * - The function name is registered before the body is compiled, so the
 *   body may call the function recursively. It stays registered, so later
 *   functions can call it too.
 * - Parameters are visible only while this body is compiled. They are
 *   removed from the symbol table afterwards, so no other function resolves
 *   a name to one of this function's arguments (the LLVM verifier rejects
 *   such references).
 */
static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
{
    uint32_t i;
    Node *functionSignature = functionDeclaration->children[0];
    Node *functionBody = functionDeclaration->children[1];
    char *functionName = functionSignature->children[0]->value.string;
    Node *parameters = functionSignature->children[2];
    uint32_t parameterCount = parameters->childCount;

    /* A VLA must have a positive length (C11 6.7.6.2p5), so parameterless
       functions still get one (unused) slot; only parameterCount entries are
       passed to LLVM. */
    LLVMTypeRef paramTypes[parameterCount > 0 ? parameterCount : 1];

    for (i = 0; i < parameterCount; i += 1)
    {
        paramTypes[i] = WraithTypeToLLVMType(parameters->children[i]->children[0]->type);
    }

    LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
    LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, parameterCount, 0);
    LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);

    AddNamedVariable(functionName, function);

    /* Bindings added from here on are local to this function's body. */
    uint32_t scopeStart = namedVariableCount;

    for (i = 0; i < parameterCount; i += 1)
    {
        AddNamedVariable(parameters->children[i]->children[1]->value.string, LLVMGetParam(function, i));
    }

    LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
    LLVMBuilderRef builder = LLVMCreateBuilder();
    LLVMPositionBuilderAtEnd(builder, entry);

    for (i = 0; i < functionBody->childCount; i += 1)
    {
        CompileStatement(module, builder, function, functionBody->children[i]);
    }

    LLVMDisposeBuilder(builder);

    /* Drop the parameter bindings; the function name registered above stays. */
    namedVariableCount = scopeStart;
}
|
|
|
|
/*
 * Walks the AST depth-first, emitting every FunctionDeclaration node it
 * meets. A node is handled before its children, and the children of a
 * function declaration are still visited after the function is compiled.
 */
static void Compile(LLVMModuleRef module, Node *node)
{
    if (node->syntaxKind == FunctionDeclaration)
    {
        CompileFunction(module, node);
    }

    for (uint32_t childIndex = 0; childIndex < node->childCount; childIndex += 1)
    {
        Compile(module, node->children[childIndex]);
    }
}
|
|
|
|
int main(int argc, char *argv[])
|
|
{
|
|
if (argc < 2)
|
|
{
|
|
printf("Please provide a file.\n");
|
|
return 1;
|
|
}
|
|
|
|
namedVariables = NULL;
|
|
namedVariableCount = 0;
|
|
|
|
stack = CreateStack();
|
|
|
|
FILE *fp = fopen(argv[1], "r");
|
|
yyin = fp;
|
|
yyparse(fp, stack);
|
|
fclose(fp);
|
|
|
|
PrintTree(rootNode, 0);
|
|
|
|
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
|
|
|
|
Compile(module, rootNode);
|
|
|
|
char *error = NULL;
|
|
LLVMVerifyModule(module, LLVMAbortProcessAction, &error);
|
|
LLVMDisposeMessage(error);
|
|
|
|
if (LLVMWriteBitcodeToFile(module, "test.bc") != 0) {
|
|
fprintf(stderr, "error writing bitcode to file\n");
|
|
}
|
|
|
|
return 0;
|
|
}
|