function calls

generics
cosmonaut 2021-04-22 17:19:35 -07:00
parent e4cf57ef74
commit acb6c61922
3 changed files with 141 additions and 37 deletions

View File

@ -23,7 +23,7 @@ find_package(LLVM)
include_directories(${CMAKE_SOURCE_DIR}) include_directories(${CMAKE_SOURCE_DIR})
BISON_TARGET(Parser wraith.y ${CMAKE_CURRENT_BINARY_DIR}/y.tab.c COMPILE_FLAGS -d) BISON_TARGET(Parser wraith.y ${CMAKE_CURRENT_BINARY_DIR}/y.tab.c COMPILE_FLAGS "-d -v -t")
FLEX_TARGET(Scanner wraith.lex ${CMAKE_CURRENT_BINARY_DIR}/lex.yy.c) FLEX_TARGET(Scanner wraith.lex ${CMAKE_CURRENT_BINARY_DIR}/lex.yy.c)
ADD_FLEX_BISON_DEPENDENCY(Scanner Parser) ADD_FLEX_BISON_DEPENDENCY(Scanner Parser)

View File

@ -53,12 +53,22 @@ typedef struct StructTypeField
uint32_t index; uint32_t index;
} StructTypeField; } StructTypeField;
typedef struct StructTypeFunction
{
char *name;
LLVMValueRef function;
LLVMTypeRef returnType;
uint8_t isStatic;
} StructTypeFunction;
typedef struct StructTypeFieldDeclaration typedef struct StructTypeFieldDeclaration
{ {
LLVMTypeRef structType; LLVMTypeRef structType;
StructTypeField *fields; StructTypeField *fields;
uint32_t fieldCount; uint32_t fieldCount;
StructTypeFunction *functions;
uint32_t functionCount;
} StructTypeFieldDeclaration; } StructTypeFieldDeclaration;
StructTypeFieldDeclaration *structTypeFieldDeclarations; StructTypeFieldDeclaration *structTypeFieldDeclarations;
@ -198,6 +208,8 @@ static void AddStructDeclaration(
structTypeFieldDeclarations[index].structType = wStructType; structTypeFieldDeclarations[index].structType = wStructType;
structTypeFieldDeclarations[index].fields = NULL; structTypeFieldDeclarations[index].fields = NULL;
structTypeFieldDeclarations[index].fieldCount = 0; structTypeFieldDeclarations[index].fieldCount = 0;
structTypeFieldDeclarations[index].functions = NULL;
structTypeFieldDeclarations[index].functionCount = 0;
for (i = 0; i < fieldDeclarationCount; i += 1) for (i = 0; i < fieldDeclarationCount; i += 1)
{ {
@ -210,7 +222,63 @@ static void AddStructDeclaration(
structTypeFieldDeclarationCount += 1; structTypeFieldDeclarationCount += 1;
} }
static void AddStructVariables( static void DeclareStructFunction(
LLVMTypeRef wStructType,
LLVMValueRef function,
LLVMTypeRef returnType,
uint8_t isStatic,
char *name
) {
uint32_t i, index;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == wStructType)
{
index = structTypeFieldDeclarations[i].functionCount;
structTypeFieldDeclarations[i].functions = realloc(structTypeFieldDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeFieldDeclarations[i].functionCount + 1));
structTypeFieldDeclarations[i].functions[index].name = strdup(name);
structTypeFieldDeclarations[i].functions[index].function = function;
structTypeFieldDeclarations[i].functions[index].returnType = returnType;
structTypeFieldDeclarations[i].functions[index].isStatic = isStatic;
structTypeFieldDeclarations[i].functionCount += 1;
return;
}
}
fprintf(stderr, "Could not find struct type for function!");
}
static LLVMValueRef LookupFunction(
LLVMValueRef structPointer,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
uint32_t i, j;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer))
{
for (j = 0; j < structTypeFieldDeclarations[i].functionCount; j += 1)
{
if (strcmp(structTypeFieldDeclarations[i].functions[j].name, name) == 0)
{
*pReturnType = structTypeFieldDeclarations[i].functions[j].returnType;
*pStatic = structTypeFieldDeclarations[i].functions[j].isStatic;
return structTypeFieldDeclarations[i].functions[j].function;
}
}
}
}
fprintf(stderr, "Could not find struct function!");
return NULL;
}
static void AddStructVariablesToScope(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef structPointer LLVMValueRef structPointer
) { ) {
@ -244,7 +312,6 @@ static void AddStructVariables(
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression Node *binaryExpression
); );
@ -309,11 +376,10 @@ static LLVMValueRef CompileNumber(
static LLVMValueRef CompileBinaryExpression( static LLVMValueRef CompileBinaryExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *binaryExpression Node *binaryExpression
) { ) {
LLVMValueRef left = CompileExpression(builder, function, binaryExpression->children[0]); LLVMValueRef left = CompileExpression(builder, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(builder, function, binaryExpression->children[1]); LLVMValueRef right = CompileExpression(builder, binaryExpression->children[1]);
switch (binaryExpression->operator.binaryOperator) switch (binaryExpression->operator.binaryOperator)
{ {
@ -334,25 +400,51 @@ static LLVMValueRef CompileBinaryExpression(
/* FIXME THIS IS ALL BROKEN */ /* FIXME THIS IS ALL BROKEN */
static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileFunctionCallExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression Node *expression
) { ) {
uint32_t i; uint32_t i;
uint32_t argumentCount = expression->children[1]->childCount; uint32_t argumentCount = 0;
LLVMValueRef args[argumentCount]; LLVMValueRef args[argumentCount];
LLVMValueRef function;
uint8_t isStatic;
LLVMValueRef structInstance;
LLVMTypeRef functionReturnType;
char *returnName = "";
for (i = 0; i < argumentCount; i += 1) /* FIXME: this needs to be recursive on access chains */
if (expression->children[0]->syntaxKind == AccessExpression)
{ {
args[i] = CompileExpression(builder, function, expression->children[1]->children[i]); structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string);
function = LookupFunction(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic);
}
else
{
fprintf(stderr, "Failed to find function!\n");
return NULL;
} }
//return LLVMBuildCall(builder, FindVariableValueByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp"); if (!isStatic)
return NULL; {
args[argumentCount] = structInstance;
argumentCount += 1;
}
for (i = 0; i < expression->children[1]->childCount; i += 1)
{
args[argumentCount] = CompileExpression(builder, expression->children[1]->children[i]);
argumentCount += 1;
}
if (LLVMGetTypeKind(functionReturnType) != LLVMVoidTypeKind)
{
returnName = "callReturn";
}
return LLVMBuildCall(builder, function, args, argumentCount, returnName);
} }
static LLVMValueRef CompileAccessExpressionForStore( static LLVMValueRef CompileAccessExpressionForStore(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression Node *expression
) { ) {
Node *accessee = expression->children[0]; Node *accessee = expression->children[0];
@ -363,7 +455,6 @@ static LLVMValueRef CompileAccessExpressionForStore(
static LLVMValueRef CompileAccessExpression( static LLVMValueRef CompileAccessExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression Node *expression
) { ) {
Node *accessee = expression->children[0]; Node *accessee = expression->children[0];
@ -375,19 +466,18 @@ static LLVMValueRef CompileAccessExpression(
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression Node *expression
) { ) {
switch (expression->syntaxKind) switch (expression->syntaxKind)
{ {
case AccessExpression: case AccessExpression:
return CompileAccessExpression(builder, function, expression); return CompileAccessExpression(builder, expression);
case BinaryExpression: case BinaryExpression:
return CompileBinaryExpression(builder, function, expression); return CompileBinaryExpression(builder, expression);
case FunctionCallExpression: case FunctionCallExpression:
return CompileFunctionCallExpression(builder, function, expression); return CompileFunctionCallExpression(builder, expression);
case Identifier: case Identifier:
return FindVariableValue(builder, expression->value.string); return FindVariableValue(builder, expression->value.string);
@ -400,9 +490,9 @@ static LLVMValueRef CompileExpression(
return NULL; return NULL;
} }
static void CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
{ {
LLVMValueRef expression = CompileExpression(builder, function, returnStatemement->children[0]); LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
LLVMBuildRet(builder, expression); LLVMBuildRet(builder, expression);
} }
@ -411,13 +501,13 @@ static void CompileReturnVoid(LLVMBuilderRef builder)
LLVMBuildRetVoid(builder); LLVMBuildRetVoid(builder);
} }
static void CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
{ {
LLVMValueRef result = CompileExpression(builder, function, assignmentStatement->children[1]); LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
LLVMValueRef identifier; LLVMValueRef identifier;
if (assignmentStatement->children[0]->syntaxKind == AccessExpression) if (assignmentStatement->children[0]->syntaxKind == AccessExpression)
{ {
identifier = CompileAccessExpressionForStore(builder, function, assignmentStatement->children[0]); identifier = CompileAccessExpressionForStore(builder, assignmentStatement->children[0]);
} }
else if (assignmentStatement->children[0]->syntaxKind == Identifier) else if (assignmentStatement->children[0]->syntaxKind == Identifier)
{ {
@ -459,7 +549,11 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N
switch (statement->syntaxKind) switch (statement->syntaxKind)
{ {
case Assignment: case Assignment:
CompileAssignment(builder, function, statement); CompileAssignment(builder, statement);
return 0;
case FunctionCallExpression:
CompileFunctionCallExpression(builder, statement);
return 0; return 0;
case Declaration: case Declaration:
@ -467,7 +561,7 @@ static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, N
return 0; return 0;
case Return: case Return:
CompileReturn(builder, function, statement); CompileReturn(builder, statement);
return 1; return 1;
case ReturnVoid: case ReturnVoid:
@ -531,6 +625,8 @@ static void CompileFunction(
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
free(functionName); free(functionName);
DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->children[0]->value.string);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry); LLVMPositionBuilderAtEnd(builder, entry);
@ -538,7 +634,7 @@ static void CompileFunction(
if (!isStatic) if (!isStatic)
{ {
LLVMValueRef wStructPointer = LLVMGetParam(function, 0); LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
AddStructVariables(builder, wStructPointer); AddStructVariablesToScope(builder, wStructPointer);
} }
for (i = 0; i < functionSignature->children[2]->childCount; i += 1) for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
@ -626,16 +722,16 @@ static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
{ {
uint32_t i; uint32_t i;
switch (node->syntaxKind)
{
case StructDeclaration:
CompileStruct(module, context, node);
break;
}
for (i = 0; i < node->childCount; i += 1) for (i = 0; i < node->childCount; i += 1)
{ {
Compile(module, context, node->children[i]); if (node->children[i]->syntaxKind == StructDeclaration)
{
CompileStruct(module, context, node->children[i]);
}
else
{
fprintf(stderr, "top level declarations that are not structs are forbidden!\n");
}
} }
} }
@ -647,6 +743,9 @@ int main(int argc, char *argv[])
return 1; return 1;
} }
extern int yydebug;
yydebug = 1;
scope = CreateScope(); scope = CreateScope();
structTypeFieldDeclarations = NULL; structTypeFieldDeclarations = NULL;

View File

@ -62,6 +62,8 @@ extern Node *rootNode;
%parse-param { FILE* fp } { Stack *stack } %parse-param { FILE* fp } { Stack *stack }
%define parse.error verbose
%left PLUS MINUS %left PLUS MINUS
%left BANG %left BANG
%left LEFT_PAREN RIGHT_PAREN %left LEFT_PAREN RIGHT_PAREN
@ -195,7 +197,7 @@ ReturnStatement : RETURN Expression
$$ = MakeReturnVoidStatementNode(); $$ = MakeReturnVoidStatementNode();
} }
FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN FunctionCallExpression : AccessExpression LEFT_PAREN Arguments RIGHT_PAREN
{ {
Node **arguments; Node **arguments;
uint32_t argumentCount; uint32_t argumentCount;
@ -206,9 +208,9 @@ FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN
PopStackFrame(stack); PopStackFrame(stack);
} }
PartialStatement : AssignmentStatement PartialStatement : FunctionCallExpression
| AssignmentStatement
| VariableDeclaration | VariableDeclaration
| FunctionCallExpression
| ReturnStatement | ReturnStatement
; ;
@ -233,6 +235,9 @@ Arguments : PrimaryExpression COMMA Arguments
AddNode(stack, $1); AddNode(stack, $1);
} }
| |
{
PushStackFrame(stack);
}
; ;
SignatureArguments : VariableDeclaration COMMA SignatureArguments SignatureArguments : VariableDeclaration COMMA SignatureArguments