compile struct functions

pull/1/head
cosmonaut 2021-04-20 19:00:18 -07:00
parent b5d256251e
commit 2708dfbbed
6 changed files with 304 additions and 54 deletions

10
ast.c
View File

@ -165,6 +165,15 @@ Node* MakeReturnStatementNode(
return node; return node;
} }
/*
 * Builds the AST node for a bare `return` statement (one without an
 * expression). The node gets syntax kind ReturnVoid and has no children.
 *
 * Returns the newly allocated node, or NULL if the allocation fails.
 */
Node* MakeReturnVoidStatementNode()
{
    Node *node = malloc(sizeof *node);

    /* do not write through a failed allocation */
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = ReturnVoid;
    node->childCount = 0;
    node->children = NULL;
    return node;
}
Node *MakeFunctionSignatureArgumentsNode( Node *MakeFunctionSignatureArgumentsNode(
Node **pArgumentNodes, Node **pArgumentNodes,
uint32_t argumentCount uint32_t argumentCount
@ -278,6 +287,7 @@ static const char* PrimitiveTypeToString(PrimitiveType type)
case Int: return "Int"; case Int: return "Int";
case UInt: return "UInt"; case UInt: return "UInt";
case Bool: return "Bool"; case Bool: return "Bool";
case Void: return "Void";
} }
return "Unknown"; return "Unknown";

3
ast.h
View File

@ -20,6 +20,7 @@ typedef enum
Identifier, Identifier,
Number, Number,
Return, Return,
ReturnVoid,
StatementSequence, StatementSequence,
StringLiteral, StringLiteral,
StructDeclaration, StructDeclaration,
@ -41,6 +42,7 @@ typedef enum
typedef enum typedef enum
{ {
Void,
Bool, Bool,
Int, Int,
UInt, UInt,
@ -112,6 +114,7 @@ Node* MakeStatementSequenceNode(
Node* MakeReturnStatementNode( Node* MakeReturnStatementNode(
Node *expressionNode Node *expressionNode
); );
Node* MakeReturnVoidStatementNode();
Node* MakeFunctionSignatureArgumentsNode( Node* MakeFunctionSignatureArgumentsNode(
Node **pArgumentNodes, Node **pArgumentNodes,
uint32_t argumentCount uint32_t argumentCount

View File

@ -14,6 +14,100 @@ extern FILE *yyin;
Stack *stack; Stack *stack;
Node *rootNode; Node *rootNode;
/* One recorded field of a struct pointer: its name and the LLVM values built for it. */
typedef struct StructFieldMapValue
{
/* copy of the field name, made with strdup in AddStructField */
char *name;
/* result of the load emitted through valuePointer in AddStructField */
LLVMValueRef value;
/* struct GEP that addresses this field; used as the store target for assignments */
LLVMValueRef valuePointer;
/* set to 0 in AddStructField. NOTE(review): never read in the visible code -- presumably a pending-write flag, confirm */
uint8_t needsWrite;
} StructFieldMapValue;
/* Associates one struct pointer value with the list of fields recorded for it. */
typedef struct StructFieldMap
{
/* key: the struct pointer value passed to AddStruct */
LLVMValueRef structPointer;
/* array of fieldCount entries grown with realloc; NULL while the map is empty */
StructFieldMapValue *fields;
/* number of valid entries in fields */
uint32_t fieldCount;
} StructFieldMap;
/* Global table of struct field maps: one entry is appended per AddStruct call. */
StructFieldMap *structFieldMaps;
/* Number of entries in structFieldMaps. */
uint32_t structFieldMapCount;
/*
 * Appends an empty field map for wStructPointer to the global
 * structFieldMaps table so fields can be recorded for it afterwards.
 *
 * If the table cannot be grown, an error is printed to stderr and the
 * table is left exactly as it was.
 */
static void AddStruct(LLVMValueRef wStructPointer)
{
    StructFieldMap *entry;
    StructFieldMap *grown = realloc(structFieldMaps, sizeof(StructFieldMap) * (structFieldMapCount + 1));

    if (grown == NULL)
    {
        /* realloc failed: the old table is still valid, so keep it */
        fprintf(stderr, "Out of memory while registering struct!\n");
        return;
    }
    structFieldMaps = grown;

    entry = &structFieldMaps[structFieldMapCount];
    entry->structPointer = wStructPointer;
    entry->fields = NULL;
    entry->fieldCount = 0;

    structFieldMapCount += 1;
}
/*
 * Records one field for a struct pointer that is present in structFieldMaps.
 *
 * For the first map entry whose structPointer equals wStructPointer this
 * emits, through builder, a struct GEP for the next field slot followed by
 * a load of that element, and stores a copy of the name, the loaded value
 * and the element pointer in the entry. Nothing happens if no entry matches.
 *
 * builder        - builder used to emit the GEP and the load
 * wStructPointer - struct pointer value used as the lookup key
 * name           - field name; copied with strdup
 * index          - currently unused: the element index given to the GEP is
 *                  the number of fields already recorded for the struct
 *
 * On allocation failure an error is printed to stderr and no field is
 * recorded (and no instructions are emitted).
 */
static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{
    uint32_t i;

    (void) index; /* see header comment: position in the map is used instead */

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];
        StructFieldMapValue *grown;
        StructFieldMapValue *field;
        LLVMValueRef elementPointer;
        char *nameCopy;
        uint32_t fieldCount;

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        fieldCount = map->fieldCount;

        /* grow through a temporary so the existing array survives a failure */
        grown = realloc(map->fields, sizeof(StructFieldMapValue) * (fieldCount + 1));
        if (grown == NULL)
        {
            fprintf(stderr, "Out of memory while registering struct field!\n");
            return;
        }
        map->fields = grown;

        nameCopy = strdup(name);
        if (nameCopy == NULL)
        {
            /* the extra slot stays allocated but fieldCount is not bumped */
            fprintf(stderr, "Out of memory while registering struct field!\n");
            return;
        }

        /* all allocations succeeded: now emit the GEP and the load */
        elementPointer = LLVMBuildStructGEP(
            builder,
            wStructPointer,
            fieldCount,
            "ptr"
        );

        field = &map->fields[fieldCount];
        field->name = nameCopy;
        field->value = LLVMBuildLoad(builder, elementPointer, name);
        field->valuePointer = elementPointer;
        field->needsWrite = 0;

        map->fieldCount += 1;
        return; /* only the first matching entry is updated */
    }
}
/*
 * Finds the element pointer of a struct field, given the value that was
 * recorded for that field.
 *
 * Every map entry whose structPointer equals wStructPointer is scanned in
 * order; the valuePointer of the first field whose recorded value equals
 * `value` is returned. Returns NULL when no such field exists.
 */
static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value)
{
    uint32_t mapIndex;

    for (mapIndex = 0; mapIndex < structFieldMapCount; mapIndex += 1)
    {
        const StructFieldMap *map = &structFieldMaps[mapIndex];
        uint32_t fieldIndex;

        /* skip maps that belong to a different struct pointer */
        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        for (fieldIndex = 0; fieldIndex < map->fieldCount; fieldIndex += 1)
        {
            const StructFieldMapValue *field = &map->fields[fieldIndex];

            if (field->value == value)
            {
                return field->valuePointer;
            }
        }
    }

    return NULL;
}
/*
 * Discards all fields recorded for wStructPointer.
 *
 * The first map entry whose structPointer equals wStructPointer has its
 * field names and its field array freed and its field count reset to 0.
 * The entry itself stays in structFieldMaps (structFieldMapCount is not
 * changed). Nothing happens if no entry matches.
 */
static void RemoveStruct(LLVMValueRef wStructPointer)
{
    uint32_t i, j;

    for (i = 0; i < structFieldMapCount; i += 1)
    {
        StructFieldMap *map = &structFieldMaps[i];

        if (map->structPointer != wStructPointer)
        {
            continue;
        }

        /* names are strdup copies owned by the map; release them first */
        for (j = 0; j < map->fieldCount; j += 1)
        {
            free(map->fields[j].name);
        }

        free(map->fields);
        map->fields = NULL;
        map->fieldCount = 0;
        return;
    }
}
typedef struct IdentifierMapValue typedef struct IdentifierMapValue
{ {
char *name; char *name;
@ -24,7 +118,7 @@ IdentifierMapValue *namedVariables;
uint32_t namedVariableCount; uint32_t namedVariableCount;
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
LLVMModuleRef module, LLVMValueRef wStructValue,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *binaryExpression Node *binaryExpression
@ -42,10 +136,11 @@ static void AddNamedVariable(char *name, LLVMValueRef variable)
namedVariableCount += 1; namedVariableCount += 1;
} }
static LLVMValueRef FindVariableByName(char *name) static LLVMValueRef FindVariableByName(LLVMValueRef wStructValue, LLVMBuilderRef builder, char *name)
{ {
uint32_t i; uint32_t i, j;
/* first, search scoped vars */
for (i = 0; i < namedVariableCount; i += 1) for (i = 0; i < namedVariableCount; i += 1)
{ {
if (strcmp(namedVariables[i].name, name) == 0) if (strcmp(namedVariables[i].name, name) == 0)
@ -54,6 +149,22 @@ static LLVMValueRef FindVariableByName(char *name)
} }
} }
/* if none exist, search struct vars */
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructValue)
{
for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
{
if (strcmp(structFieldMaps[i].fields[j].name, name) == 0)
{
return structFieldMaps[i].fields[j].value;
}
}
}
}
fprintf(stderr, "Identifier not found!");
return NULL; return NULL;
} }
@ -66,28 +177,32 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
case UInt: case UInt:
return LLVMInt64Type(); return LLVMInt64Type();
case Bool:
return LLVMInt1Type();
case Void:
return LLVMVoidType();
} }
fprintf(stderr, "Unrecognized type!");
return NULL; return NULL;
} }
static LLVMValueRef CompileNumber( static LLVMValueRef CompileNumber(
LLVMModuleRef module,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *numberExpression Node *numberExpression
) { ) {
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
} }
static LLVMValueRef CompileBinaryExpression( static LLVMValueRef CompileBinaryExpression(
LLVMModuleRef module, LLVMValueRef wStructValue,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *binaryExpression Node *binaryExpression
) { ) {
LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]); LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]); LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]);
switch (binaryExpression->operator.binaryOperator) switch (binaryExpression->operator.binaryOperator)
{ {
@ -106,7 +221,7 @@ static LLVMValueRef CompileBinaryExpression(
} }
static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileFunctionCallExpression(
LLVMModuleRef module, LLVMValueRef wStructValue,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *expression Node *expression
@ -117,71 +232,113 @@ static LLVMValueRef CompileFunctionCallExpression(
for (i = 0; i < argumentCount; i += 1) for (i = 0; i < argumentCount; i += 1)
{ {
args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]); args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]);
} }
return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp"); return LLVMBuildCall(builder, FindVariableByName(wStructValue, builder, expression->children[0]->value.string), args, argumentCount, "tmp");
} }
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
LLVMModuleRef module, LLVMValueRef wStructValue,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *expression Node *expression
) { ) {
LLVMValueRef var;
switch (expression->syntaxKind) switch (expression->syntaxKind)
{ {
case BinaryExpression: case BinaryExpression:
return CompileBinaryExpression(module, builder, function, expression); return CompileBinaryExpression(wStructValue, builder, function, expression);
case FunctionCallExpression: case FunctionCallExpression:
return CompileFunctionCallExpression(module, builder, function, expression); return CompileFunctionCallExpression(wStructValue, builder, function, expression);
case Identifier: case Identifier:
return FindVariableByName(expression->value.string); return FindVariableByName(wStructValue, builder, expression->value.string);
case Number: case Number:
return CompileNumber(module, builder, function, expression); return CompileNumber(expression);
} }
printf("Error: expected expression\n"); fprintf(stderr, "Unknown expression kind!\n");
return NULL; return NULL;
} }
static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{ {
LLVMBuildRet(builder, CompileExpression(module, builder, function, returnStatemement->children[0])); LLVMBuildRet(builder, CompileExpression(wStructValue, builder, function, returnStatemement->children[0]));
} }
static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) static void CompileReturnVoid(LLVMBuilderRef builder)
{
LLVMBuildRetVoid(builder);
}
/*
 * Compiles an assignment statement.
 *
 * children[1] (the right-hand side) is compiled first, then children[0]
 * (the target). The target is resolved to storage by looking up the value
 * produced for it among the fields recorded for wStructValue (see
 * GetStructFieldPointer); when found, a store of the right-hand value to
 * that field pointer is emitted.
 *
 * If either side fails to compile, no store is emitted. If the target is
 * not a recorded struct field, an error is printed to stderr instead of
 * silently dropping the assignment.
 */
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
    LLVMValueRef fieldPointer;
    LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
    LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]);

    /* never hand a NULL value to the LLVM builder */
    if (result == NULL || identifier == NULL)
    {
        return;
    }

    fieldPointer = GetStructFieldPointer(wStructValue, identifier);
    if (fieldPointer == NULL)
    {
        fprintf(stderr, "Assignment target is not a struct field!\n");
        return;
    }

    LLVMBuildStore(builder, result, fieldPointer);
}
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{ {
switch (statement->syntaxKind) switch (statement->syntaxKind)
{ {
case Assignment:
CompileAssignment(wStructValue, builder, function, statement);
return 0;
case Return: case Return:
CompileReturn(module, builder, function, statement); CompileReturn(wStructValue, builder, function, statement);
break; return 1;
case ReturnVoid:
CompileReturnVoid(builder);
return 1;
} }
fprintf(stderr, "Unknown statement kind!\n");
return 0;
} }
static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration) static void CompileFunction(
{ LLVMModuleRef module,
LLVMTypeRef wStructPointerType,
Node **fieldDeclarations,
uint32_t fieldDeclarationCount,
Node *functionDeclaration
) {
uint32_t i; uint32_t i;
uint8_t hasReturn = 0;
Node *functionSignature = functionDeclaration->children[0]; Node *functionSignature = functionDeclaration->children[0];
Node *functionBody = functionDeclaration->children[1]; Node *functionBody = functionDeclaration->children[1];
LLVMTypeRef paramTypes[functionSignature->children[2]->childCount]; uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */
LLVMTypeRef paramTypes[argumentCount];
paramTypes[0] = wStructPointerType;
for (i = 0; i < functionSignature->children[2]->childCount; i += 1) for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
{ {
paramTypes[i] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type); paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type);
} }
LLVMTypeRef functionType = LLVMFunctionType(WraithTypeToLLVMType(functionSignature->children[1]->type), paramTypes, functionSignature->children[2]->childCount, 0); LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType); LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
for (i = 0; i < functionSignature->children[2]->childCount; i += 1) for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
{ {
LLVMValueRef argument = LLVMGetParam(function, i); LLVMValueRef argument = LLVMGetParam(function, i + 1);
AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument); AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument);
} }
@ -190,28 +347,91 @@ static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry); LLVMPositionBuilderAtEnd(builder, entry);
/* FIXME: replace this with a scope abstraction */
AddStruct(wStructPointer);
for (i = 0; i < fieldDeclarationCount; i += 1)
{
AddStructField(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
}
for (i = 0; i < functionBody->childCount; i += 1) for (i = 0; i < functionBody->childCount; i += 1)
{ {
CompileStatement(module, builder, function, functionBody->children[i]); hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
} }
AddNamedVariable(functionSignature->children[0]->value.string, function); AddNamedVariable(functionSignature->children[0]->value.string, function);
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
}
RemoveStruct(wStructPointer);
} }
static void Compile(LLVMModuleRef module, Node *node) static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
uint32_t i;
uint32_t fieldCount = 0;
uint32_t declarationCount = node->children[1]->childCount;
uint8_t packed = 1;
LLVMTypeRef types[declarationCount];
Node *currentDeclarationNode;
Node *fieldDeclarations[declarationCount];
LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string);
LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
switch (currentDeclarationNode->syntaxKind)
{
case Declaration: /* this is badly named */
types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type);
fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1;
break;
}
}
LLVMStructSetBody(wStruct, types, fieldCount, packed);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
switch (currentDeclarationNode->syntaxKind)
{
case FunctionDeclaration:
CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
break;
}
}
}
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{ {
uint32_t i; uint32_t i;
switch (node->syntaxKind) switch (node->syntaxKind)
{ {
case FunctionDeclaration: case StructDeclaration:
CompileFunction(module, node); CompileStruct(module, context, node);
break; break;
} }
for (i = 0; i < node->childCount; i += 1) for (i = 0; i < node->childCount; i += 1)
{ {
Compile(module, node->children[i]); Compile(module, context, node->children[i]);
} }
} }
@ -226,6 +446,9 @@ int main(int argc, char *argv[])
namedVariables = NULL; namedVariables = NULL;
namedVariableCount = 0; namedVariableCount = 0;
structFieldMaps = NULL;
structFieldMapCount = 0;
stack = CreateStack(); stack = CreateStack();
FILE *fp = fopen(argv[1], "r"); FILE *fp = fopen(argv[1], "r");
@ -236,8 +459,9 @@ int main(int argc, char *argv[])
PrintTree(rootNode, 0); PrintTree(rootNode, 0);
LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMContextRef context = LLVMGetGlobalContext();
Compile(module, rootNode); Compile(module, context, rootNode);
char *error = NULL; char *error = NULL;
LLVMVerifyModule(module, LLVMAbortProcessAction, &error); LLVMVerifyModule(module, LLVMAbortProcessAction, &error);

View File

@ -44,7 +44,7 @@ void AddNode(Stack *stack, Node *statementNode)
if (stackFrame->nodeCount == stackFrame->nodeCapacity) if (stackFrame->nodeCount == stackFrame->nodeCapacity)
{ {
stackFrame->nodeCapacity += 1; stackFrame->nodeCapacity += 1;
stackFrame->nodes = (Node**) realloc(stackFrame->nodes, stackFrame->nodeCapacity); stackFrame->nodes = (Node**) realloc(stackFrame->nodes, sizeof(Node*) * stackFrame->nodeCapacity);
} }
stackFrame->nodes[stackFrame->nodeCount] = statementNode; stackFrame->nodes[stackFrame->nodeCount] = statementNode;
stackFrame->nodeCount += 1; stackFrame->nodeCount += 1;

View File

@ -5,7 +5,7 @@
%option noyywrap %option noyywrap
%% %%
[0-9]+ return NUMBER; "void" return VOID;
"int" return INT; "int" return INT;
"uint" return UINT; "uint" return UINT;
"float" return FLOAT; "float" return FLOAT;
@ -14,6 +14,7 @@
"bool" return BOOL; "bool" return BOOL;
"struct" return STRUCT; "struct" return STRUCT;
"return" return RETURN; "return" return RETURN;
[0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID; [a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;
"+" return PLUS; "+" return PLUS;

View File

@ -21,7 +21,7 @@ extern Node *rootNode;
%define api.value.type {struct Node*} %define api.value.type {struct Node*}
%token NUMBER %token VOID
%token INT %token INT
%token UINT %token UINT
%token FLOAT %token FLOAT
@ -30,6 +30,7 @@ extern Node *rootNode;
%token BOOL %token BOOL
%token STRUCT %token STRUCT
%token RETURN %token RETURN
%token NUMBER
%token ID %token ID
%token STRING_LITERAL %token STRING_LITERAL
%token PLUS %token PLUS
@ -65,7 +66,7 @@ extern Node *rootNode;
%left LEFT_PAREN RIGHT_PAREN %left LEFT_PAREN RIGHT_PAREN
%% %%
Program : Declarations Program : TopLevelDeclarations
{ {
Node **declarations; Node **declarations;
Node *declarationSequence; Node *declarationSequence;
@ -79,7 +80,11 @@ Program : Declarations
rootNode = declarationSequence; rootNode = declarationSequence;
} }
Type : INT Type : VOID
{
$$ = MakeTypeNode(Void);
}
| INT
{ {
$$ = MakeTypeNode(Int); $$ = MakeTypeNode(Int);
} }
@ -167,6 +172,10 @@ ReturnStatement : RETURN Expression
{ {
$$ = MakeReturnStatementNode($2); $$ = MakeReturnStatementNode($2);
} }
| RETURN
{
$$ = MakeReturnVoidStatementNode();
}
FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN
{ {
@ -208,7 +217,7 @@ Arguments : PrimaryExpression COMMA Arguments
| |
; ;
SignatureArguments : VariableDeclaration COMMA VariableDeclarations SignatureArguments : VariableDeclaration COMMA SignatureArguments
{ {
AddNode(stack, $1); AddNode(stack, $1);
} }
@ -249,16 +258,7 @@ FunctionDeclaration : FunctionSignature Body
$$ = MakeFunctionDeclarationNode($1, $2); $$ = MakeFunctionDeclarationNode($1, $2);
} }
VariableDeclarations : VariableDeclaration SEMICOLON VariableDeclarations StructDeclaration : STRUCT Identifier LEFT_BRACE Declarations RIGHT_BRACE
{
AddNode(stack, $1);
}
|
{
PushStackFrame(stack);
}
StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGHT_BRACE
{ {
Node **declarations; Node **declarations;
Node *declarationSequence; Node *declarationSequence;
@ -271,8 +271,9 @@ StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGH
PopStackFrame(stack); PopStackFrame(stack);
} }
Declaration : StructDeclaration
| FunctionDeclaration Declaration : FunctionDeclaration
| VariableDeclaration SEMICOLON
; ;
Declarations : Declaration Declarations Declarations : Declaration Declarations
@ -283,4 +284,15 @@ Declarations : Declaration Declarations
{ {
PushStackFrame(stack); PushStackFrame(stack);
} }
TopLevelDeclaration : StructDeclaration;
TopLevelDeclarations : TopLevelDeclaration TopLevelDeclarations
{
AddNode(stack, $1);
}
|
{
PushStackFrame(stack);
}
%% %%