compile struct functions

generics
cosmonaut 2021-04-20 19:00:18 -07:00
parent b5d256251e
commit 2708dfbbed
6 changed files with 304 additions and 54 deletions

10
ast.c
View File

@ -165,6 +165,15 @@ Node* MakeReturnStatementNode(
return node;
}
Node* MakeReturnVoidStatementNode()
{
Node *node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ReturnVoid;
node->childCount = 0;
node->children = NULL;
return node;
}
Node *MakeFunctionSignatureArgumentsNode(
Node **pArgumentNodes,
uint32_t argumentCount
@ -278,6 +287,7 @@ static const char* PrimitiveTypeToString(PrimitiveType type)
case Int: return "Int";
case UInt: return "UInt";
case Bool: return "Bool";
case Void: return "Void";
}
return "Unknown";

3
ast.h
View File

@ -20,6 +20,7 @@ typedef enum
Identifier,
Number,
Return,
ReturnVoid,
StatementSequence,
StringLiteral,
StructDeclaration,
@ -41,6 +42,7 @@ typedef enum
typedef enum
{
Void,
Bool,
Int,
UInt,
@ -112,6 +114,7 @@ Node* MakeStatementSequenceNode(
Node* MakeReturnStatementNode(
Node *expressionNode
);
Node* MakeReturnVoidStatementNode();
Node* MakeFunctionSignatureArgumentsNode(
Node **pArgumentNodes,
uint32_t argumentCount

View File

@ -14,6 +14,100 @@ extern FILE *yyin;
Stack *stack;
Node *rootNode;
typedef struct StructFieldMapValue
{
char *name;
LLVMValueRef value;
LLVMValueRef valuePointer;
uint8_t needsWrite;
} StructFieldMapValue;
typedef struct StructFieldMap
{
LLVMValueRef structPointer;
StructFieldMapValue *fields;
uint32_t fieldCount;
} StructFieldMap;
StructFieldMap *structFieldMaps;
uint32_t structFieldMapCount;
static void AddStruct(LLVMValueRef wStructPointer)
{
structFieldMaps = realloc(structFieldMaps, sizeof(StructFieldMap) * (structFieldMapCount + 1));
structFieldMaps[structFieldMapCount].structPointer = wStructPointer;
structFieldMaps[structFieldMapCount].fields = NULL;
structFieldMaps[structFieldMapCount].fieldCount = 0;
structFieldMapCount += 1;
}
static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{
uint32_t i, fieldCount;
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructPointer)
{
fieldCount = structFieldMaps[i].fieldCount;
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
wStructPointer,
fieldCount,
"ptr"
);
structFieldMaps[i].fields = realloc(structFieldMaps[i].fields, sizeof(StructFieldMapValue) * (fieldCount + 1));
structFieldMaps[i].fields[fieldCount].name = strdup(name);
structFieldMaps[i].fields[fieldCount].value = LLVMBuildLoad(builder, elementPointer, name);
structFieldMaps[i].fields[fieldCount].valuePointer = elementPointer;
structFieldMaps[i].fields[fieldCount].needsWrite = 0;
structFieldMaps[i].fieldCount += 1;
break;
}
}
}
static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value)
{
uint32_t i, j;
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructPointer)
{
for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
{
if (structFieldMaps[i].fields[j].value == value)
{
return structFieldMaps[i].fields[j].valuePointer;
}
}
}
}
return NULL;
}
static void RemoveStruct(LLVMValueRef wStructPointer)
{
uint32_t i;
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructPointer)
{
free(structFieldMaps[i].fields);
structFieldMaps[i].fields = NULL;
structFieldMaps[i].fieldCount = 0;
break;
}
}
}
typedef struct IdentifierMapValue
{
char *name;
@ -24,7 +118,7 @@ IdentifierMapValue *namedVariables;
uint32_t namedVariableCount;
static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMValueRef wStructValue,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
@ -42,10 +136,11 @@ static void AddNamedVariable(char *name, LLVMValueRef variable)
namedVariableCount += 1;
}
static LLVMValueRef FindVariableByName(char *name)
static LLVMValueRef FindVariableByName(LLVMValueRef wStructValue, LLVMBuilderRef builder, char *name)
{
uint32_t i;
uint32_t i, j;
/* first, search scoped vars */
for (i = 0; i < namedVariableCount; i += 1)
{
if (strcmp(namedVariables[i].name, name) == 0)
@ -54,6 +149,22 @@ static LLVMValueRef FindVariableByName(char *name)
}
}
/* if none exist, search struct vars */
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructValue)
{
for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
{
if (strcmp(structFieldMaps[i].fields[j].name, name) == 0)
{
return structFieldMaps[i].fields[j].value;
}
}
}
}
fprintf(stderr, "Identifier not found!");
return NULL;
}
@ -66,28 +177,32 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
case UInt:
return LLVMInt64Type();
case Bool:
return LLVMInt1Type();
case Void:
return LLVMVoidType();
}
fprintf(stderr, "Unrecognized type!");
return NULL;
}
static LLVMValueRef CompileNumber(
LLVMModuleRef module,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *numberExpression
) {
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
}
static LLVMValueRef CompileBinaryExpression(
LLVMModuleRef module,
LLVMValueRef wStructValue,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression
) {
LLVMValueRef left = CompileExpression(module, builder, function, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(module, builder, function, binaryExpression->children[1]);
LLVMValueRef left = CompileExpression(wStructValue, builder, function, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(wStructValue, builder, function, binaryExpression->children[1]);
switch (binaryExpression->operator.binaryOperator)
{
@ -106,7 +221,7 @@ static LLVMValueRef CompileBinaryExpression(
}
static LLVMValueRef CompileFunctionCallExpression(
LLVMModuleRef module,
LLVMValueRef wStructValue,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression
@ -117,71 +232,113 @@ static LLVMValueRef CompileFunctionCallExpression(
for (i = 0; i < argumentCount; i += 1)
{
args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]);
args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]);
}
return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp");
return LLVMBuildCall(builder, FindVariableByName(wStructValue, builder, expression->children[0]->value.string), args, argumentCount, "tmp");
}
static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMValueRef wStructValue,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression
) {
LLVMValueRef var;
switch (expression->syntaxKind)
{
case BinaryExpression:
return CompileBinaryExpression(module, builder, function, expression);
return CompileBinaryExpression(wStructValue, builder, function, expression);
case FunctionCallExpression:
return CompileFunctionCallExpression(module, builder, function, expression);
return CompileFunctionCallExpression(wStructValue, builder, function, expression);
case Identifier:
return FindVariableByName(expression->value.string);
return FindVariableByName(wStructValue, builder, expression->value.string);
case Number:
return CompileNumber(module, builder, function, expression);
return CompileNumber(expression);
}
printf("Error: expected expression\n");
fprintf(stderr, "Unknown expression kind!\n");
return NULL;
}
static void CompileReturn(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
static void CompileReturn(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
LLVMBuildRet(builder, CompileExpression(module, builder, function, returnStatemement->children[0]));
LLVMBuildRet(builder, CompileExpression(wStructValue, builder, function, returnStatemement->children[0]));
}
static void CompileStatement(LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
static void CompileReturnVoid(LLVMBuilderRef builder)
{
LLVMBuildRetVoid(builder);
}
static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
LLVMValueRef fieldPointer;
LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]);
LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]);
fieldPointer = GetStructFieldPointer(wStructValue, identifier);
if (fieldPointer != NULL)
{
LLVMBuildStore(builder, result, fieldPointer);
}
}
static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
switch (statement->syntaxKind)
{
case Assignment:
CompileAssignment(wStructValue, builder, function, statement);
return 0;
case Return:
CompileReturn(module, builder, function, statement);
break;
}
CompileReturn(wStructValue, builder, function, statement);
return 1;
case ReturnVoid:
CompileReturnVoid(builder);
return 1;
}
static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
{
fprintf(stderr, "Unknown statement kind!\n");
return 0;
}
static void CompileFunction(
LLVMModuleRef module,
LLVMTypeRef wStructPointerType,
Node **fieldDeclarations,
uint32_t fieldDeclarationCount,
Node *functionDeclaration
) {
uint32_t i;
uint8_t hasReturn = 0;
Node *functionSignature = functionDeclaration->children[0];
Node *functionBody = functionDeclaration->children[1];
LLVMTypeRef paramTypes[functionSignature->children[2]->childCount];
uint32_t argumentCount = functionSignature->children[2]->childCount + 1; /* struct is implicit argument */
LLVMTypeRef paramTypes[argumentCount];
paramTypes[0] = wStructPointerType;
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
{
paramTypes[i] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type);
paramTypes[i + 1] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->type);
}
LLVMTypeRef functionType = LLVMFunctionType(WraithTypeToLLVMType(functionSignature->children[1]->type), paramTypes, functionSignature->children[2]->childCount, 0);
LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->type);
LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, argumentCount, 0);
LLVMValueRef function = LLVMAddFunction(module, functionSignature->children[0]->value.string, functionType);
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
{
LLVMValueRef argument = LLVMGetParam(function, i);
LLVMValueRef argument = LLVMGetParam(function, i + 1);
AddNamedVariable(functionSignature->children[2]->children[i]->children[1]->value.string, argument);
}
@ -190,28 +347,91 @@ static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
/* FIXME: replace this with a scope abstraction */
AddStruct(wStructPointer);
for (i = 0; i < fieldDeclarationCount; i += 1)
{
AddStructField(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
}
for (i = 0; i < functionBody->childCount; i += 1)
{
CompileStatement(module, builder, function, functionBody->children[i]);
hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]);
}
AddNamedVariable(functionSignature->children[0]->value.string, function);
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
}
static void Compile(LLVMModuleRef module, Node *node)
RemoveStruct(wStructPointer);
}
static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
uint32_t i;
uint32_t fieldCount = 0;
uint32_t declarationCount = node->children[1]->childCount;
uint8_t packed = 1;
LLVMTypeRef types[declarationCount];
Node *currentDeclarationNode;
Node *fieldDeclarations[declarationCount];
LLVMTypeRef wStruct = LLVMStructCreateNamed(context, node->children[0]->value.string);
LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
switch (currentDeclarationNode->syntaxKind)
{
case Declaration: /* this is badly named */
types[fieldCount] = WraithTypeToLLVMType(currentDeclarationNode->children[0]->type);
fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1;
break;
}
}
LLVMStructSetBody(wStruct, types, fieldCount, packed);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
switch (currentDeclarationNode->syntaxKind)
{
case FunctionDeclaration:
CompileFunction(module, wStructPointerType, fieldDeclarations, fieldCount, currentDeclarationNode);
break;
}
}
}
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{
uint32_t i;
switch (node->syntaxKind)
{
case FunctionDeclaration:
CompileFunction(module, node);
case StructDeclaration:
CompileStruct(module, context, node);
break;
}
for (i = 0; i < node->childCount; i += 1)
{
Compile(module, node->children[i]);
Compile(module, context, node->children[i]);
}
}
@ -226,6 +446,9 @@ int main(int argc, char *argv[])
namedVariables = NULL;
namedVariableCount = 0;
structFieldMaps = NULL;
structFieldMapCount = 0;
stack = CreateStack();
FILE *fp = fopen(argv[1], "r");
@ -236,8 +459,9 @@ int main(int argc, char *argv[])
PrintTree(rootNode, 0);
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMContextRef context = LLVMGetGlobalContext();
Compile(module, rootNode);
Compile(module, context, rootNode);
char *error = NULL;
LLVMVerifyModule(module, LLVMAbortProcessAction, &error);

View File

@ -44,7 +44,7 @@ void AddNode(Stack *stack, Node *statementNode)
if (stackFrame->nodeCount == stackFrame->nodeCapacity)
{
stackFrame->nodeCapacity += 1;
stackFrame->nodes = (Node**) realloc(stackFrame->nodes, stackFrame->nodeCapacity);
stackFrame->nodes = (Node**) realloc(stackFrame->nodes, sizeof(Node*) * stackFrame->nodeCapacity);
}
stackFrame->nodes[stackFrame->nodeCount] = statementNode;
stackFrame->nodeCount += 1;

View File

@ -5,7 +5,7 @@
%option noyywrap
%%
[0-9]+ return NUMBER;
"void" return VOID;
"int" return INT;
"uint" return UINT;
"float" return FLOAT;
@ -14,6 +14,7 @@
"bool" return BOOL;
"struct" return STRUCT;
"return" return RETURN;
[0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;
"+" return PLUS;

View File

@ -21,7 +21,7 @@ extern Node *rootNode;
%define api.value.type {struct Node*}
%token NUMBER
%token VOID
%token INT
%token UINT
%token FLOAT
@ -30,6 +30,7 @@ extern Node *rootNode;
%token BOOL
%token STRUCT
%token RETURN
%token NUMBER
%token ID
%token STRING_LITERAL
%token PLUS
@ -65,7 +66,7 @@ extern Node *rootNode;
%left LEFT_PAREN RIGHT_PAREN
%%
Program : Declarations
Program : TopLevelDeclarations
{
Node **declarations;
Node *declarationSequence;
@ -79,7 +80,11 @@ Program : Declarations
rootNode = declarationSequence;
}
Type : INT
Type : VOID
{
$$ = MakeTypeNode(Void);
}
| INT
{
$$ = MakeTypeNode(Int);
}
@ -167,6 +172,10 @@ ReturnStatement : RETURN Expression
{
$$ = MakeReturnStatementNode($2);
}
| RETURN
{
$$ = MakeReturnVoidStatementNode();
}
FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN
{
@ -208,7 +217,7 @@ Arguments : PrimaryExpression COMMA Arguments
|
;
SignatureArguments : VariableDeclaration COMMA VariableDeclarations
SignatureArguments : VariableDeclaration COMMA SignatureArguments
{
AddNode(stack, $1);
}
@ -249,16 +258,7 @@ FunctionDeclaration : FunctionSignature Body
$$ = MakeFunctionDeclarationNode($1, $2);
}
VariableDeclarations : VariableDeclaration SEMICOLON VariableDeclarations
{
AddNode(stack, $1);
}
|
{
PushStackFrame(stack);
}
StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGHT_BRACE
StructDeclaration : STRUCT Identifier LEFT_BRACE Declarations RIGHT_BRACE
{
Node **declarations;
Node *declarationSequence;
@ -271,8 +271,9 @@ StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGH
PopStackFrame(stack);
}
Declaration : StructDeclaration
| FunctionDeclaration
Declaration : FunctionDeclaration
| VariableDeclaration SEMICOLON
;
Declarations : Declaration Declarations
@ -283,4 +284,15 @@ Declarations : Declaration Declarations
{
PushStackFrame(stack);
}
TopLevelDeclaration : StructDeclaration;
TopLevelDeclarations : TopLevelDeclaration TopLevelDeclarations
{
AddNode(stack, $1);
}
|
{
PushStackFrame(stack);
}
%%