From d9b01515ebc3969fc3b1b533fe871cdeb63984bd Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Tue, 1 Jun 2021 19:58:46 +0000 Subject: [PATCH] Generics (#7) Co-authored-by: venko Reviewed-on: https://gitea.moonside.games/cosmonaut/wraith-lang/pulls/7 Co-authored-by: cosmonaut Co-committed-by: cosmonaut --- CMakeLists.txt | 6 +- access.w | 13 + generators/wraith.y | 32 +- generic.w | 18 + ordering.w | 15 + src/ast.c | 455 +++++++++++++++++++++- src/ast.h | 47 ++- src/codegen.c | 919 +++++++++++++++++++++++++++++++++++--------- src/identcheck.c | 485 ----------------------- src/identcheck.h | 49 --- src/main.c | 17 +- src/util.c | 14 +- src/util.h | 2 + src/validation.c | 452 ++++++++++++++++++++++ src/validation.h | 10 + 15 files changed, 1793 insertions(+), 741 deletions(-) create mode 100644 access.w create mode 100644 generic.w create mode 100644 ordering.w delete mode 100644 src/identcheck.c delete mode 100644 src/identcheck.h create mode 100644 src/validation.c create mode 100644 src/validation.h diff --git a/CMakeLists.txt b/CMakeLists.txt index a3cbf9e..1b7a9e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,12 +41,14 @@ add_executable( # Source src/ast.h src/codegen.h - src/identcheck.h src/parser.h + src/validation.h + src/util.h src/ast.c src/codegen.c - src/identcheck.c src/parser.c + src/validation.c + src/util.c src/main.c # Generated code ${BISON_Parser_OUTPUTS} diff --git a/access.w b/access.w new file mode 100644 index 0000000..2534dd2 --- /dev/null +++ b/access.w @@ -0,0 +1,13 @@ +struct G { + Foo(t: bool): bool { + return t; + } +} + +struct Program { + static main(): int { + g: G = alloc G; + g.Foo(true); + return 0; + } +} \ No newline at end of file diff --git a/generators/wraith.y b/generators/wraith.y index 7aaf70a..3176d60 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE $$ = $2; } -FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type +GenericArgument : Identifier { - $$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0)); + $$ = MakeGenericArgumentNode($1, NULL); } - | STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + +GenericArguments : GenericArgument + { + $$ = StartGenericArgumentsNode($1); + } + | GenericArguments COMMA GenericArgument + { + $$ = AddGenericArgument($1, $3); + } + +GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN + { + $$ = $2; + } + | + { + $$ = MakeEmptyGenericArgumentsNode(); + } + + +FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + { + $$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2); + } + | STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type { Node *modifier = MakeStaticNode(); - $$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1)); + $$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3); } FunctionDeclaration : FunctionSignature Body diff --git a/generic.w b/generic.w new file mode 100644 index 0000000..73fb981 --- /dev/null +++ b/generic.w @@ -0,0 +1,18 @@ +struct Foo { + static Func2(u: U) : U { + return u; + } + + static Func(t: T): T { + foo: T = t; + return Foo.Func2(foo); + } +} + +struct Program { + static Main(): int { + x: int = 4; + y: int = Foo.Func(x); + return x; + } +} diff --git a/ordering.w b/ordering.w new file mode 100644 index 0000000..f446a68 --- /dev/null +++ b/ordering.w @@ -0,0 +1,15 @@ +struct Foo { + static Func(): void { + Func2(); + } + + static Func2(): void { + Func(); + } +} + +struct Program { + static main(): int { + return 0; + } +} \ No newline at end of file diff --git a/src/ast.c b/src/ast.c index 74ee4d9..fddf74d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -39,6 +39,12 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "FunctionSignature"; case FunctionSignatureArguments: return "FunctionSignatureArguments"; + case GenericArgument: + return "GenericArgument"; + case GenericArguments: + return "GenericArguments"; + case GenericTypeNode: + return "GenericTypeNode"; case Identifier: return "Identifier"; case IfStatement: @@ -271,7 +277,8 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *arguments, - Node *modifiersNode) + Node *modifiersNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = FunctionSignature; @@ -279,6 +286,7 @@ Node *MakeFunctionSignatureNode( node->functionSignature.type = typeNode; node->functionSignature.arguments = arguments; node->functionSignature.modifiers = modifiersNode; + node->functionSignature.genericArguments = genericArgumentsNode; return node; } @@ -359,6 +367,54 @@ Node *MakeEmptyFunctionArgumentSequenceNode() return node; } +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArgument; + node->genericArgument.identifier = identifierNode; + node->genericArgument.constraint = constraintNode; + return node; +} + +Node *StartGenericArgumentsNode(Node *genericArgumentNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = (Node **)malloc(sizeof(Node *)); + node->genericArguments.arguments[0] = genericArgumentNode; + node->genericArguments.count = 1; + return node; +} + +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode) +{ + genericArgumentsNode->genericArguments.arguments = (Node **)realloc( + genericArgumentsNode->genericArguments.arguments, + sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1)); + genericArgumentsNode->genericArguments + .arguments[genericArgumentsNode->genericArguments.count] = + genericArgumentNode; + genericArgumentsNode->genericArguments.count += 1; + return genericArgumentsNode; +} + +Node *MakeEmptyGenericArgumentsNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = NULL; + node->genericArguments.count = 0; + return node; +} + +Node *MakeGenericTypeNode(char *name) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericTypeNode; + node->genericType.name = strdup(name); + return node; +} + Node *MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode) @@ -557,6 +613,7 @@ void PrintNode(Node *node, uint32_t tabCount) case FunctionSignature: printf("\n"); PrintNode(node->functionSignature.identifier, tabCount + 1); + PrintNode(node->functionSignature.genericArguments, tabCount + 1); PrintNode(node->functionSignature.arguments, tabCount + 1); PrintNode(node->functionSignature.type, tabCount + 1); PrintNode(node->functionSignature.modifiers, tabCount + 1); @@ -572,6 +629,25 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case GenericArgument: + printf("\n"); + PrintNode(node->genericArgument.identifier, tabCount + 1); + /* Constraint nodes are not implemented. */ + /* PrintNode(node->genericArgument.constraint, tabCount + 1); */ + return; + + case GenericArguments: + printf("\n"); + for (i = 0; i < node->genericArguments.count; i += 1) + { + PrintNode(node->genericArguments.arguments[i], tabCount + 1); + } + return; + + case GenericTypeNode: + printf("%s\n", node->genericType.name); + return; + case Identifier: if (node->typeTag == NULL) { @@ -630,7 +706,7 @@ void PrintNode(Node *node, uint32_t tabCount) return; case StringLiteral: - printf("%s", node->stringLiteral.string); + printf("%s\n", node->stringLiteral.string); return; case StructDeclaration: @@ -651,6 +727,173 @@ void PrintNode(Node *node, uint32_t tabCount) } } +void Recurse(Node *node, void (*func)(Node *)) +{ + uint32_t i; + switch (node->syntaxKind) + { + case AccessExpression: + func(node->accessExpression.accessee); + func(node->accessExpression.accessor); + return; + + case AllocExpression: + func(node->allocExpression.type); + return; + + case Assignment: + func(node->assignmentStatement.left); + func(node->assignmentStatement.right); + return; + + case BinaryExpression: + func(node->binaryExpression.left); + func(node->binaryExpression.right); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: + func(node->declaration.type); + func(node->declaration.identifier); + return; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) + { + func(node->declarationSequence.sequence[i]); + } + return; + + case ForLoop: + func(node->forLoop.declaration); + func(node->forLoop.startNumber); + func(node->forLoop.endNumber); + func(node->forLoop.statementSequence); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) + { + func(node->functionArgumentSequence.sequence[i]); + } + return; + + case FunctionCallExpression: + func(node->functionCallExpression.identifier); + func(node->functionCallExpression.argumentSequence); + return; + + case FunctionDeclaration: + func(node->functionDeclaration.functionSignature); + func(node->functionDeclaration.functionBody); + return; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) + { + func(node->functionModifiers.sequence[i]); + } + return; + + case FunctionSignature: + func(node->functionSignature.identifier); + func(node->functionSignature.type); + func(node->functionSignature.arguments); + func(node->functionSignature.modifiers); + func(node->functionSignature.genericArguments); + return; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) + { + func(node->functionSignatureArguments.sequence[i]); + } + return; + + case GenericArgument: + func(node->genericArgument.identifier); + func(node->genericArgument.constraint); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) + { + func(node->genericArguments.arguments[i]); + } + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + func(node->ifStatement.expression); + func(node->ifStatement.statementSequence); + return; + + case IfElseStatement: + func(node->ifElseStatement.ifStatement); + func(node->ifElseStatement.elseStatement); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + func(node->referenceType.type); + return; + + case Return: + func(node->returnStatement.expression); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) + { + func(node->statementSequence.sequence[i]); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + func(node->structDeclaration.identifier); + func(node->structDeclaration.declarationSequence); + return; + + case Type: + return; + + case UnaryExpression: + func(node->unaryExpression.child); + return; + + default: + fprintf( + stderr, + "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); + return; + } +} + TypeTag *MakeTypeTag(Node *node) { if (node == NULL) @@ -698,6 +941,21 @@ TypeTag *MakeTypeTag(Node *node) ->functionSignature.type); break; + case AllocExpression: + tag = MakeTypeTag(node->allocExpression.type); + break; + + case GenericArgument: + tag->type = Generic; + tag->value.genericType = + strdup(node->genericArgument.identifier->identifier.name); + break; + + case GenericTypeNode: + tag->type = Generic; + tag->value.genericType = strdup(node->genericType.name); + break; + default: fprintf( stderr, @@ -706,6 +964,7 @@ TypeTag *MakeTypeTag(Node *node) SyntaxKindString(node->syntaxKind)); return NULL; } + return tag; } @@ -729,11 +988,199 @@ char *TypeTagToString(TypeTag *tag) { char *inner = TypeTagToString(tag->value.referenceType); size_t innerStrLen = strlen(inner); - char *result = malloc(sizeof(char) * (innerStrLen + 5)); + char *result = malloc(sizeof(char) * (innerStrLen + 6)); sprintf(result, "Ref<%s>", inner); return result; } case Custom: - return tag->value.customType; + { + char *result = + malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); + sprintf(result, "Custom<%s>", tag->value.customType); + return result; + } + case Generic: + { + char *result = + malloc(sizeof(char) * (strlen(tag->value.genericType) + 10)); + sprintf(result, "Generic<%s>", tag->value.genericType); + return result; + } + } +} + +void LinkParentPointers(Node *node, Node *prev) +{ + if (node == NULL) + return; + + node->parent = prev; + + uint32_t i; + switch (node->syntaxKind) + { + case AccessExpression: + LinkParentPointers(node->accessExpression.accessee, node); + LinkParentPointers(node->accessExpression.accessor, node); + return; + + case AllocExpression: + LinkParentPointers(node->allocExpression.type, node); + return; + + case Assignment: + LinkParentPointers(node->assignmentStatement.left, node); + LinkParentPointers(node->assignmentStatement.right, node); + return; + + case BinaryExpression: + LinkParentPointers(node->binaryExpression.left, node); + LinkParentPointers(node->binaryExpression.right, node); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: + LinkParentPointers(node->declaration.type, node); + LinkParentPointers(node->declaration.identifier, node); + return; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) + { + LinkParentPointers(node->declarationSequence.sequence[i], node); + } + return; + + case ForLoop: + LinkParentPointers(node->forLoop.declaration, node); + LinkParentPointers(node->forLoop.startNumber, node); + LinkParentPointers(node->forLoop.endNumber, node); + LinkParentPointers(node->forLoop.statementSequence, node); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) + { + LinkParentPointers( + node->functionArgumentSequence.sequence[i], + node); + } + return; + + case FunctionCallExpression: + LinkParentPointers(node->functionCallExpression.identifier, node); + LinkParentPointers(node->functionCallExpression.argumentSequence, node); + return; + + case FunctionDeclaration: + LinkParentPointers(node->functionDeclaration.functionSignature, node); + LinkParentPointers(node->functionDeclaration.functionBody, node); + return; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) + { + LinkParentPointers(node->functionModifiers.sequence[i], node); + } + return; + + case FunctionSignature: + LinkParentPointers(node->functionSignature.identifier, node); + LinkParentPointers(node->functionSignature.type, node); + LinkParentPointers(node->functionSignature.arguments, node); + LinkParentPointers(node->functionSignature.modifiers, node); + LinkParentPointers(node->functionSignature.genericArguments, node); + return; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) + { + LinkParentPointers( + node->functionSignatureArguments.sequence[i], + node); + } + return; + + case GenericArgument: + LinkParentPointers(node->genericArgument.identifier, node); + LinkParentPointers(node->genericArgument.constraint, node); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) + { + LinkParentPointers(node->genericArguments.arguments[i], node); + } + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + LinkParentPointers(node->ifStatement.expression, node); + LinkParentPointers(node->ifStatement.statementSequence, node); + return; + + case IfElseStatement: + LinkParentPointers(node->ifElseStatement.ifStatement, node); + LinkParentPointers(node->ifElseStatement.elseStatement, node); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + LinkParentPointers(node->referenceType.type, node); + return; + + case Return: + LinkParentPointers(node->returnStatement.expression, node); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) + { + LinkParentPointers(node->statementSequence.sequence[i], node); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + LinkParentPointers(node->structDeclaration.identifier, node); + LinkParentPointers(node->structDeclaration.declarationSequence, node); + return; + + case Type: + return; + + case UnaryExpression: + LinkParentPointers(node->unaryExpression.child, node); + return; + + default: + fprintf( + stderr, + "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); + return; } } diff --git a/src/ast.h b/src/ast.h index 60e954d..3d36cf6 100644 --- a/src/ast.h +++ b/src/ast.h @@ -1,7 +1,6 @@ #ifndef WRAITH_AST_H #define WRAITH_AST_H -#include "identcheck.h" #include /* -Wpedantic nameless union/struct silencing */ @@ -30,6 +29,9 @@ typedef enum FunctionModifiers, FunctionSignature, FunctionSignatureArguments, + GenericArgument, + GenericArguments, + GenericTypeNode, Identifier, IfStatement, IfElseStatement, @@ -87,7 +89,8 @@ typedef struct TypeTag Unknown, Primitive, Reference, - Custom + Custom, + Generic } type; union { @@ -97,6 +100,8 @@ typedef struct TypeTag struct TypeTag *referenceType; /* Valid when type = Custom. */ char *customType; + /* Valid when type = Generic. */ + char *genericType; } value; } TypeTag; @@ -192,6 +197,7 @@ struct Node Node *type; Node *arguments; Node *modifiers; + Node *genericArguments; } functionSignature; struct @@ -200,6 +206,23 @@ struct Node uint32_t count; } functionSignatureArguments; + struct + { + Node *identifier; + Node *constraint; + } genericArgument; + + struct + { + Node **arguments; + uint32_t count; + } genericArguments; + + struct + { + char *name; + } genericType; + struct { char *name; @@ -276,7 +299,6 @@ struct Node } unaryExpression; }; TypeTag *typeTag; - IdNode *idLink; }; const char *SyntaxKindString(SyntaxKind syntaxKind); @@ -306,10 +328,16 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *argumentsNode, - Node *modifiersNode); + Node *modifiersNode, + Node *genericArgumentsNode); Node *MakeFunctionDeclarationNode( Node *functionSignatureNode, Node *functionBodyNode); +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode); +Node *MakeEmptyGenericArgumentsNode(); +Node *StartGenericArgumentsNode(Node *genericArgumentNode); +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode); +Node *MakeGenericTypeNode(char *name); Node *MakeStructDeclarationNode( Node *identifierNode, Node *declarationSequenceNode); @@ -337,7 +365,18 @@ Node *MakeForLoopNode( void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); +/* Helper function for applying a void function generically over the children of + * an AST node. Used for functions that need to traverse the entire tree but + * only perform operations on a subset of node types. Such functions can match + * the syntaxKinds relevant to their purpose and invoke this function in all + * other cases. */ +void Recurse(Node *node, void (*func)(Node *)); + +void LinkParentPointers(Node *node, Node *prev); + TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); +Node *LookupIdNode(Node *current, Node *prev, char *target); + #endif /* WRAITH_AST_H */ diff --git a/src/codegen.c b/src/codegen.c index e8a5fb9..dd3bf73 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -14,6 +14,7 @@ #include #include "ast.h" +#include "util.h" typedef struct LocalVariable { @@ -22,6 +23,13 @@ typedef struct LocalVariable LLVMValueRef value; } LocalVariable; +typedef struct LocalGenericType +{ + char *name; + TypeTag *concreteTypeTag; + LLVMTypeRef type; +} LocalGenericType; + typedef struct FunctionArgument { char *name; @@ -32,6 +40,9 @@ typedef struct ScopeFrame { LocalVariable *localVariables; uint32_t localVariableCount; + + LocalGenericType *genericTypes; + uint32_t genericTypeCount; } ScopeFrame; typedef struct Scope @@ -56,6 +67,33 @@ typedef struct StructTypeFunction uint8_t isStatic; } StructTypeFunction; +typedef struct MonomorphizedGenericFunctionHashEntry +{ + uint64_t key; + TypeTag **types; + uint32_t typeCount; + StructTypeFunction function; +} MonomorphizedGenericFunctionHashEntry; + +typedef struct MonomorphizedGenericFunctionHashArray +{ + MonomorphizedGenericFunctionHashEntry *elements; + uint32_t count; +} MonomorphizedGenericFunctionHashArray; + +#define NUM_MONOMORPHIZED_HASH_BUCKETS 1031 + +typedef struct StructTypeGenericFunction +{ + char *parentStructName; + LLVMTypeRef parentStructPointerType; + char *name; + Node *functionDeclarationNode; + uint8_t isStatic; + MonomorphizedGenericFunctionHashArray + monomorphizedFunctions[NUM_MONOMORPHIZED_HASH_BUCKETS]; +} StructTypeGenericFunction; + typedef struct StructTypeDeclaration { char *name; @@ -66,11 +104,26 @@ typedef struct StructTypeDeclaration StructTypeFunction *functions; uint32_t functionCount; + + StructTypeGenericFunction *genericFunctions; + uint32_t genericFunctionCount; } StructTypeDeclaration; StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; +/* FUNCTION FORWARD DECLARATIONS */ +static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, + LLVMBuilderRef builder, + LLVMValueRef function, + Node *statement); + +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -78,6 +131,8 @@ static Scope *CreateScope() scope->scopeStack = malloc(sizeof(ScopeFrame)); scope->scopeStack[0].localVariableCount = 0; scope->scopeStack[0].localVariables = NULL; + scope->scopeStack[0].genericTypeCount = 0; + scope->scopeStack[0].genericTypes = NULL; scope->scopeStackCount = 1; return scope; @@ -91,6 +146,8 @@ static void PushScopeFrame(Scope *scope) sizeof(ScopeFrame) * (scope->scopeStackCount + 1)); scope->scopeStack[index].localVariableCount = 0; scope->scopeStack[index].localVariables = NULL; + scope->scopeStack[index].genericTypeCount = 0; + scope->scopeStack[index].genericTypes = NULL; scope->scopeStackCount += 1; } @@ -109,31 +166,21 @@ static void PopScopeFrame(Scope *scope) free(scope->scopeStack[index].localVariables); } + if (scope->scopeStack[index].genericTypes != NULL) + { + for (i = 0; i < scope->scopeStack[index].genericTypeCount; i += 1) + { + free(scope->scopeStack[index].genericTypes[i].name); + } + free(scope->scopeStack[index].genericTypes); + } + scope->scopeStackCount -= 1; scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -155,6 +202,133 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) return NULL; } +static LocalGenericType *LookupGenericType(char *name) +{ + int32_t i, j; + + for (i = scope->scopeStackCount - 1; i >= 0; i -= 1) + { + for (j = 0; j < scope->scopeStack[i].genericTypeCount; j += 1) + { + if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) + { + return &scope->scopeStack[i].genericTypes[j]; + } + } + } + + fprintf(stderr, "Could not find resolved generic type!\n"); + return NULL; +} + +static LLVMTypeRef LookupCustomType(char *name) +{ + int32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (strcmp(structTypeDeclarations[i].name, name) == 0) + { + return structTypeDeclarations[i].structType; + } + } + + fprintf(stderr, "Could not find struct type!\n"); + return NULL; +} + +static LLVMTypeRef ResolveType(TypeTag *typeTag) +{ + if (typeTag->type == Primitive) + { + return WraithTypeToLLVMType(typeTag->value.primitiveType); + } + else if (typeTag->type == Custom) + { + return LookupCustomType(typeTag->value.customType); + } + else if (typeTag->type == Reference) + { + return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); + } + else if (typeTag->type == Generic) + { + return LookupGenericType(typeTag->value.genericType)->type; + } + else + { + fprintf(stderr, "Unknown type node!\n"); + return NULL; + } +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i, j; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == + LLVMTypeOf(structPointer)) + { + for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) + { + char *ptrName = + strdup(structTypeDeclarations[i].fields[j].name); + strcat(ptrName, "_ptr"); + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclarations[i].fields[j].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclarations[i].fields[j].name); + } + } + } +} + static LLVMTypeRef FindStructType(char *name) { uint32_t i; @@ -271,6 +445,8 @@ static void AddStructDeclaration( structTypeDeclarations[index].fieldCount = 0; structTypeDeclarations[index].functions = NULL; structTypeDeclarations[index].functionCount = 0; + structTypeDeclarations[index].genericFunctions = NULL; + structTypeDeclarations[index].genericFunctionCount = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { @@ -287,6 +463,7 @@ static void AddStructDeclaration( structTypeDeclarationCount += 1; } +/* FIXME: pass the declaration itself */ static void DeclareStructFunction( LLVMTypeRef wStructPointerType, LLVMValueRef function, @@ -318,49 +495,355 @@ static void DeclareStructFunction( fprintf(stderr, "Could not find struct type for function!\n"); } -static LLVMTypeRef LookupCustomType(char *name) +/* FIXME: pass the declaration itself */ +static void DeclareGenericStructFunction( + LLVMTypeRef wStructPointerType, + Node *functionDeclarationNode, + uint8_t isStatic, + char *parentStructName, + char *name) { - uint32_t i; + uint32_t i, j, index; for (i = 0; i < structTypeDeclarationCount; i += 1) { - if (strcmp(structTypeDeclarations[i].name, name) == 0) + if (structTypeDeclarations[i].structPointerType == wStructPointerType) { - return structTypeDeclarations[i].structType; + index = structTypeDeclarations[i].genericFunctionCount; + structTypeDeclarations[i].genericFunctions = realloc( + structTypeDeclarations[i].genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclarations[i].genericFunctionCount + 1)); + structTypeDeclarations[i].genericFunctions[index].name = + strdup(name); + structTypeDeclarations[i].genericFunctions[index].parentStructName = + parentStructName; + structTypeDeclarations[i].structPointerType = wStructPointerType; + structTypeDeclarations[i] + .genericFunctions[index] + .functionDeclarationNode = functionDeclarationNode; + structTypeDeclarations[i].genericFunctions[index].isStatic = + isStatic; + + for (j = 0; j < NUM_MONOMORPHIZED_HASH_BUCKETS; j += 1) + { + structTypeDeclarations[i] + .genericFunctions[index] + .monomorphizedFunctions[j] + .elements = NULL; + structTypeDeclarations[i] + .genericFunctions[index] + .monomorphizedFunctions[j] + .count = 0; + } + + structTypeDeclarations[i].genericFunctionCount += 1; + + return; + } + } +} + +static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) +{ + const uint64_t HASH_FACTOR = 97; + uint64_t result = 1; + uint32_t i; + + for (i = 0; i < count; i += 1) + { + result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + } + + return result; +} + +/* FIXME: lots of duplication with non-generic function compile */ +static StructTypeFunction CompileGenericFunction( + LLVMModuleRef module, + char *parentStructName, + LLVMTypeRef wStructPointerType, + TypeTag **resolvedGenericArgumentTypes, + uint32_t genericArgumentTypeCount, + Node *functionDeclaration) +{ + uint32_t i; + uint8_t hasReturn = 0; + uint8_t isStatic = 0; + Node *functionSignature = + functionDeclaration->functionDeclaration.functionSignature; + Node *functionBody = functionDeclaration->functionDeclaration.functionBody; + uint32_t argumentCount = functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + LLVMTypeRef paramTypes[argumentCount + 1]; + uint32_t paramIndex = 0; + LLVMTypeRef returnType; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + resolvedGenericArgumentTypes[i], + functionDeclaration->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name); + } + + if (functionSignature->functionSignature.modifiers->functionModifiers + .count > 0) + { + for (i = 0; i < functionSignature->functionSignature.modifiers + ->functionModifiers.count; + i += 1) + { + if (functionSignature->functionSignature.modifiers + ->functionModifiers.sequence[i] + ->syntaxKind == StaticModifier) + { + isStatic = 1; + break; + } } } - fprintf(stderr, "Could not find struct type!\n"); - return NULL; + char *functionName = strdup(parentStructName); + strcat(functionName, "_"); + strcat( + functionName, + functionSignature->functionSignature.identifier->identifier.name); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); + } + + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->typeTag); + + paramIndex += 1; + } + + returnType = + ResolveType(functionSignature->functionSignature.identifier->typeTag); + + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + module, + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = + LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + PopScopeFrame(scope); + free(functionName); + + StructTypeFunction structTypeFunction; + structTypeFunction.name = strdup( + functionSignature->functionSignature.identifier->identifier.name); + structTypeFunction.function = function; + structTypeFunction.returnType = returnType; + structTypeFunction.isStatic = isStatic; + + return structTypeFunction; } -static LLVMTypeRef ResolveType(Node *typeNode) +static LLVMValueRef LookupGenericFunction( + LLVMModuleRef module, + StructTypeGenericFunction *genericFunction, + TypeTag **argumentTypes, + uint32_t argumentCount, + LLVMTypeRef *pReturnType, + uint8_t *pStatic) { - if (IsPrimitiveType(typeNode)) + uint32_t i, j; + uint64_t typeHash; + uint8_t match = 0; + uint32_t genericArgumentTypeCount = + genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.genericArguments + ->genericArguments.count; + TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; + + for (i = 0; i < genericArgumentTypeCount; i += 1) { - return WraithTypeToLLVMType( - typeNode->type.typeNode->primitiveType.type); + for (j = 0; + j < genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + j += 1) + { + if (genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->type == Generic && + strcmp( + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->value.genericType, + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name) == 0) + { + resolvedGenericArgumentTypes[i] = argumentTypes[j]; + break; + } + } } - else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode) + + /* Concretize generics if we are compiling nested generic functions */ + for (i = 0; i < genericArgumentTypeCount; i += 1) { - return LookupCustomType(typeNode->type.typeNode->customType.name); + if (resolvedGenericArgumentTypes[i]->type == Generic) + { + resolvedGenericArgumentTypes[i] = + LookupGenericType( + resolvedGenericArgumentTypes[i]->value.genericType) + ->concreteTypeTag; + } } - else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode) + + typeHash = + HashTypeTags(resolvedGenericArgumentTypes, genericArgumentTypeCount); + + MonomorphizedGenericFunctionHashArray *hashArray = + &genericFunction->monomorphizedFunctions + [typeHash % NUM_MONOMORPHIZED_HASH_BUCKETS]; + + MonomorphizedGenericFunctionHashEntry *hashEntry = NULL; + for (i = 0; i < hashArray->count; i += 1) { - return LLVMPointerType( - ResolveType(typeNode->type.typeNode->referenceType.type), - 0); + match = 1; + + for (j = 0; j < hashArray->elements[i].typeCount; j += 1) + { + if (hashArray->elements[i].types[j] != + resolvedGenericArgumentTypes[j]) + { + match = 0; + break; + } + } + + if (match) + { + hashEntry = &hashArray->elements[i]; + break; + } } - else + + if (hashEntry == NULL) { - fprintf(stderr, "Unknown type node!\n"); - return NULL; + StructTypeFunction function = CompileGenericFunction( + module, + genericFunction->parentStructName, + genericFunction->parentStructPointerType, + resolvedGenericArgumentTypes, + genericArgumentTypeCount, + genericFunction->functionDeclarationNode); + + /* TODO: add to hash */ + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericFunctionHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * genericArgumentTypeCount); + hashArray->elements[hashArray->count].typeCount = + genericArgumentTypeCount; + hashArray->elements[hashArray->count].function = function; + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + hashArray->elements[hashArray->count].types[i] = + resolvedGenericArgumentTypes[i]; + } + hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; } + + *pReturnType = hashEntry->function.returnType; + *pStatic = genericFunction->isStatic; + + return hashEntry->function.function; } static LLVMValueRef LookupFunctionByType( + LLVMModuleRef module, LLVMTypeRef structType, char *name, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -381,6 +864,23 @@ static LLVMValueRef LookupFunctionByType( return structTypeDeclarations[i].functions[j].function; } } + + for (j = 0; j < structTypeDeclarations[i].genericFunctionCount; + j += 1) + { + if (strcmp( + structTypeDeclarations[i].genericFunctions[j].name, + name) == 0) + { + return LookupGenericFunction( + module, + &structTypeDeclarations[i].genericFunctions[j], + argumentTypes, + argumentCount, + pReturnType, + pStatic); + } + } } } @@ -389,8 +889,11 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( + LLVMModuleRef module, LLVMTypeRef structPointerType, char *name, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -411,6 +914,23 @@ static LLVMValueRef LookupFunctionByPointerType( return structTypeDeclarations[i].functions[j].function; } } + + for (j = 0; j < structTypeDeclarations[i].genericFunctionCount; + j += 1) + { + if (strcmp( + structTypeDeclarations[i].genericFunctions[j].name, + name) == 0) + { + return LookupGenericFunction( + module, + &structTypeDeclarations[i].genericFunctions[j], + argumentTypes, + argumentCount, + pReturnType, + pStatic); + } + } } } @@ -419,53 +939,24 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( + LLVMModuleRef module, LLVMValueRef structPointer, char *name, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( + module, LLVMTypeOf(structPointer), name, + argumentTypes, + argumentCount, pReturnType, pStatic); } -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression); - static LLVMValueRef CompileNumber(Node *numberExpression) { return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0); @@ -482,13 +973,19 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *binaryExpression) { - LLVMValueRef left = - CompileExpression(builder, binaryExpression->binaryExpression.left); - LLVMValueRef right = - CompileExpression(builder, binaryExpression->binaryExpression.right); + LLVMValueRef left = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.left); + + LLVMValueRef right = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.right); switch (binaryExpression->binaryExpression.operator) { @@ -533,6 +1030,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -542,14 +1040,30 @@ static LLVMValueRef CompileFunctionCallExpression( [functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.count + 1]; + TypeTag + *argumentTypes[functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; LLVMTypeRef functionReturnType; char *returnName = ""; + for (i = 0; i < functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count; + i += 1) + { + argumentTypes[i] = + functionCallExpression->functionCallExpression.argumentSequence + ->functionArgumentSequence.sequence[i] + ->typeTag; + + argumentCount += 1; + } + /* FIXME: this needs to be recursive on access chains */ - /* FIXME: this needs to be able to call same-struct functions implicitly */ + /* FIXME: this needs to be able to call same-struct functions implicitly + */ if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { @@ -560,9 +1074,12 @@ static LLVMValueRef CompileFunctionCallExpression( if (typeReference != NULL) { function = LookupFunctionByType( + module, typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -572,9 +1089,12 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); function = LookupFunctionByInstance( + module, structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -585,6 +1105,8 @@ static LLVMValueRef CompileFunctionCallExpression( return NULL; } + argumentCount = 0; + if (!isStatic) { args[argumentCount] = structInstance; @@ -596,6 +1118,7 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( + module, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -642,11 +1165,14 @@ static LLVMValueRef CompileAllocExpression( LLVMBuilderRef builder, Node *allocExpression) { - LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type); + LLVMTypeRef type = ResolveType(allocExpression->typeTag); return LLVMBuildMalloc(builder, type, "allocation"); } -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression) { switch (expression->syntaxKind) { @@ -657,10 +1183,10 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(builder, expression); + return CompileBinaryExpression(module, builder, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(builder, expression); + return CompileFunctionCallExpression(module, builder, expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -676,17 +1202,14 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return NULL; } -static LLVMBasicBlockRef CompileStatement( - LLVMBuilderRef builder, - LLVMValueRef function, - Node *statement); - static LLVMBasicBlockRef CompileReturn( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( + module, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -715,7 +1238,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( variable = LLVMBuildAlloca( builder, - ResolveType(variableDeclaration->declaration.type), + ResolveType(variableDeclaration->declaration.identifier->typeTag), ptrName); free(ptrName); @@ -726,11 +1249,13 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( + module, builder, assignmentStatement->assignmentStatement.right); LLVMValueRef identifier; @@ -768,13 +1293,14 @@ static LLVMBasicBlockRef CompileAssignment( } static LLVMBasicBlockRef CompileIfStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; LLVMValueRef conditional = - CompileExpression(builder, ifStatement->ifStatement.expression); + CompileExpression(module, builder, ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -789,6 +1315,7 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( + module, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -802,12 +1329,14 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( + module, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -824,6 +1353,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -842,6 +1372,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -851,6 +1382,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -863,6 +1395,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -875,8 +1408,8 @@ static LLVMBasicBlockRef CompileForLoopStatement( LLVMAppendBasicBlock(function, "afterLoop"); char *iteratorVariableName = forLoopStatement->forLoop.declaration ->declaration.identifier->identifier.name; - LLVMTypeRef iteratorVariableType = - ResolveType(forLoopStatement->forLoop.declaration->declaration.type); + LLVMTypeRef iteratorVariableType = ResolveType( + forLoopStatement->forLoop.declaration->declaration.identifier->typeTag); PushScopeFrame(scope); @@ -922,6 +1455,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( + module, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -950,6 +1484,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -957,27 +1492,27 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(builder, function, statement); + return CompileAssignment(module, builder, function, statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(builder, function, statement); + return CompileForLoopStatement(module, builder, function, statement); case FunctionCallExpression: - CompileFunctionCallExpression(builder, statement); + CompileFunctionCallExpression(module, builder, statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(builder, function, statement); + return CompileIfStatement(module, builder, function, statement); case IfElseStatement: - return CompileIfElseStatement(builder, function, statement); + return CompileIfElseStatement(module, builder, function, statement); case Return: - return CompileReturn(builder, function, statement); + return CompileReturn(module, builder, function, statement); case ReturnVoid: return CompileReturnVoid(builder, function); @@ -991,8 +1526,6 @@ static void CompileFunction( LLVMModuleRef module, char *parentStructName, LLVMTypeRef wStructPointerType, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount, Node *functionDeclaration) { uint32_t i; @@ -1023,101 +1556,118 @@ static void CompileFunction( } } - if (!isStatic) - { - paramTypes[paramIndex] = wStructPointerType; - paramIndex += 1; - } - - PushScopeFrame(scope); - - /* FIXME: should work for non-primitive types */ - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - paramTypes[paramIndex] = - ResolveType(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.type); - paramIndex += 1; - } - - LLVMTypeRef returnType = - ResolveType(functionSignature->functionSignature.type); - LLVMTypeRef functionType = - LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - char *functionName = strdup(parentStructName); strcat(functionName, "_"); strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); - free(functionName); - DeclareStructFunction( - wStructPointerType, - function, - returnType, - isStatic, - functionSignature->functionSignature.identifier->identifier.name); - - LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); - LLVMBuilderRef builder = LLVMCreateBuilder(); - LLVMPositionBuilderAtEnd(builder, entry); - - if (!isStatic) + if (functionSignature->functionSignature.genericArguments->genericArguments + .count == 0) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); - } + PushScopeFrame(scope); - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - char *ptrName = strdup(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); - LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); - LLVMValueRef argumentCopy = - LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); - LLVMBuildStore(builder, argument, argumentCopy); - free(ptrName); - AddLocalVariable( - scope, - argumentCopy, - NULL, - functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - } + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } - for (i = 0; i < functionBody->statementSequence.count; i += 1) - { - CompileStatement( - builder, + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->typeTag); + paramIndex += 1; + } + + LLVMTypeRef returnType = ResolveType( + functionSignature->functionSignature.identifier->typeTag); + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = + LLVMAddFunction(module, functionName, functionType); + + DeclareStructFunction( + wStructPointerType, function, - functionBody->statementSequence.sequence[i]); + returnType, + isStatic, + functionSignature->functionSignature.identifier->identifier.name); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = + strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + module, + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = LLVMGetBasicBlockTerminator( + LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + + PopScopeFrame(scope); } - - hasReturn = - LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; - - if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + else { - LLVMBuildRetVoid(builder); - } - else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) - { - fprintf(stderr, "Return statement not provided!"); + DeclareGenericStructFunction( + wStructPointerType, + functionDeclaration, + isStatic, + parentStructName, + functionSignature->functionSignature.identifier->identifier.name); } - PopScopeFrame(scope); - - LLVMDisposeBuilder(builder); + free(functionName); } static void CompileStruct( @@ -1151,8 +1701,8 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case Declaration: /* this is badly named */ - types[fieldCount] = - ResolveType(currentDeclarationNode->declaration.type); + types[fieldCount] = ResolveType( + currentDeclarationNode->declaration.identifier->typeTag); fieldDeclarations[fieldCount] = currentDeclarationNode; fieldCount += 1; break; @@ -1180,8 +1730,6 @@ static void CompileStruct( module, structName, wStructPointerType, - fieldDeclarations, - fieldCount, currentDeclarationNode); break; } @@ -1211,7 +1759,8 @@ static void Compile( { fprintf( stderr, - "top level declarations that are not structs are forbidden!\n"); + "top level declarations that are not structs are " + "forbidden!\n"); } } } diff --git a/src/identcheck.c b/src/identcheck.c deleted file mode 100644 index 571a29e..0000000 --- a/src/identcheck.c +++ /dev/null @@ -1,485 +0,0 @@ -#include -#include -#include -#include -#include - -#include "ast.h" -#include "identcheck.h" - -IdNode *MakeIdNode(NodeType type, char *name, IdNode *parent) -{ - IdNode *node = (IdNode *)malloc(sizeof(IdNode)); - node->type = type; - node->name = strdup(name); - node->parent = parent; - node->childCount = 0; - node->childCapacity = 0; - node->children = NULL; - node->typeTag = NULL; - return node; -} - -void AddChildToNode(IdNode *node, IdNode *child) -{ - if (child == NULL) - return; - - if (node->children == NULL) - { - node->childCapacity = 2; - node->children = - (IdNode **)malloc(sizeof(IdNode *) * node->childCapacity); - } - else if (node->childCount == node->childCapacity) - { - node->childCapacity *= 2; - node->children = (IdNode **)realloc( - node->children, - sizeof(IdNode *) * node->childCapacity); - } - - node->children[node->childCount] = child; - node->childCount += 1; -} - -IdNode *MakeIdTree(Node *astNode, IdNode *parent) -{ - uint32_t i; - IdNode *mainNode; - switch (astNode->syntaxKind) - { - case AccessExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->accessExpression.accessee, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->accessExpression.accessor, parent)); - return NULL; - - case AllocExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->allocExpression.type, parent)); - return NULL; - - case Assignment: - { - if (astNode->assignmentStatement.left->syntaxKind == Declaration) - { - return MakeIdTree(astNode->assignmentStatement.left, parent); - } - else - { - AddChildToNode( - parent, - MakeIdTree(astNode->assignmentStatement.left, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->assignmentStatement.right, parent)); - return NULL; - } - } - - case BinaryExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->binaryExpression.left, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->binaryExpression.right, parent)); - return NULL; - - case Declaration: - { - Node *idNode = astNode->declaration.identifier; - mainNode = MakeIdNode(Variable, idNode->identifier.name, parent); - mainNode->typeTag = MakeTypeTag(astNode); - idNode->typeTag = mainNode->typeTag; - break; - } - - case DeclarationSequence: - { - mainNode = MakeIdNode(UnorderedScope, "", parent); - for (i = 0; i < astNode->declarationSequence.count; i++) - { - AddChildToNode( - mainNode, - MakeIdTree(astNode->declarationSequence.sequence[i], mainNode)); - } - break; - } - - case ForLoop: - { - Node *loopDecl = astNode->forLoop.declaration; - Node *loopBody = astNode->forLoop.statementSequence; - mainNode = MakeIdNode(OrderedScope, "for-loop", parent); - AddChildToNode(mainNode, MakeIdTree(loopDecl, mainNode)); - AddChildToNode(mainNode, MakeIdTree(loopBody, mainNode)); - break; - } - - case FunctionArgumentSequence: - for (i = 0; i < astNode->functionArgumentSequence.count; i++) - { - AddChildToNode( - parent, - MakeIdTree( - astNode->functionArgumentSequence.sequence[i], - parent)); - } - return NULL; - - case FunctionCallExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->functionCallExpression.identifier, parent)); - AddChildToNode( - parent, - MakeIdTree( - astNode->functionCallExpression.argumentSequence, - parent)); - return NULL; - - case FunctionDeclaration: - { - Node *sigNode = astNode->functionDeclaration.functionSignature; - Node *idNode = sigNode->functionSignature.identifier; - char *funcName = idNode->identifier.name; - mainNode = MakeIdNode(Function, funcName, parent); - mainNode->typeTag = MakeTypeTag(astNode); - idNode->typeTag = mainNode->typeTag; - MakeIdTree(sigNode->functionSignature.arguments, mainNode); - MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); - break; - } - - case FunctionSignatureArguments: - { - for (i = 0; i < astNode->functionSignatureArguments.count; i++) - { - Node *argNode = astNode->functionSignatureArguments.sequence[i]; - AddChildToNode(parent, MakeIdTree(argNode, parent)); - } - return NULL; - } - - case Identifier: - { - char *name = astNode->identifier.name; - mainNode = MakeIdNode(Placeholder, name, parent); - IdNode *lookupNode = LookupId(mainNode, NULL, name); - if (lookupNode == NULL) - { - fprintf(stderr, "wraith: Could not find IdNode for id %s\n", name); - TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); - tag->type = Unknown; - astNode->typeTag = tag; - } - else - { - astNode->typeTag = lookupNode->typeTag; - } - break; - } - - case IfStatement: - { - Node *clause = astNode->ifStatement.expression; - Node *stmtSeq = astNode->ifStatement.statementSequence; - mainNode = MakeIdNode(OrderedScope, "if", parent); - MakeIdTree(clause, mainNode); - MakeIdTree(stmtSeq, mainNode); - break; - } - - case IfElseStatement: - { - Node *ifNode = astNode->ifElseStatement.ifStatement; - Node *elseStmts = astNode->ifElseStatement.elseStatement; - mainNode = MakeIdNode(OrderedScope, "if-else", parent); - IdNode *ifBranch = MakeIdTree(ifNode, mainNode); - AddChildToNode(mainNode, ifBranch); - IdNode *elseScope = MakeIdNode(OrderedScope, "else", mainNode); - MakeIdTree(elseStmts, elseScope); - AddChildToNode(mainNode, elseScope); - break; - } - - case ReferenceTypeNode: - AddChildToNode(parent, MakeIdTree(astNode->referenceType.type, parent)); - return NULL; - - case Return: - AddChildToNode( - parent, - MakeIdTree(astNode->returnStatement.expression, parent)); - return NULL; - - case StatementSequence: - { - for (i = 0; i < astNode->statementSequence.count; i++) - { - Node *argNode = astNode->statementSequence.sequence[i]; - AddChildToNode(parent, MakeIdTree(argNode, parent)); - } - return NULL; - } - - case StructDeclaration: - { - Node *idNode = astNode->structDeclaration.identifier; - Node *declsNode = astNode->structDeclaration.declarationSequence; - mainNode = MakeIdNode(Struct, idNode->identifier.name, parent); - mainNode->typeTag = MakeTypeTag(astNode); - for (i = 0; i < declsNode->declarationSequence.count; i++) - { - Node *decl = declsNode->declarationSequence.sequence[i]; - AddChildToNode(mainNode, MakeIdTree(decl, mainNode)); - } - break; - } - - case Type: - AddChildToNode(parent, MakeIdTree(astNode->type.typeNode, parent)); - return NULL; - - case UnaryExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->unaryExpression.child, parent)); - return NULL; - - case Comment: - case CustomTypeNode: - case FunctionModifiers: - case FunctionSignature: - case Number: - case PrimitiveTypeNode: - case ReturnVoid: - case StaticModifier: - case StringLiteral: - return NULL; - } - - astNode->idLink = mainNode; - return mainNode; -} - -void PrintIdNode(IdNode *node) -{ - if (node == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call PrintIdNode with null value.\n"); - return; - } - - switch (node->type) - { - case Placeholder: - printf("Placeholder (%s)\n", node->name); - break; - case OrderedScope: - printf("OrderedScope (%s)\n", node->name); - break; - case UnorderedScope: - printf("UnorderedScope (%s)\n", node->name); - break; - case Struct: - printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); - break; - case Function: - printf( - "%s : Function<%s>\n", - node->name, - TypeTagToString(node->typeTag)); - break; - case Variable: - printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); - break; - } -} - -void PrintIdTree(IdNode *tree, uint32_t tabCount) -{ - if (tree == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call PrintIdTree on a null value.\n"); - return; - } - - uint32_t i; - for (i = 0; i < tabCount; i++) - { - printf("| "); - } - - PrintIdNode(tree); - - for (i = 0; i < tree->childCount; i++) - { - PrintIdTree(tree->children[i], tabCount + 1); - } -} - -int PrintAncestors(IdNode *node) -{ - if (node == NULL) - return -1; - - int i; - int indent = 1; - indent += PrintAncestors(node->parent); - for (i = 0; i < indent; i++) - { - printf(" "); - } - PrintIdNode(node); - return indent; -} - -IdNode *LookdownId(IdNode *root, NodeType targetType, char *targetName) -{ - if (root == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call LookdownId on a null value.\n"); - return NULL; - } - - IdNode *result = NULL; - IdNode **frontier = (IdNode **)malloc(sizeof(IdNode *)); - frontier[0] = root; - uint32_t frontierCount = 1; - - while (frontierCount > 0) - { - IdNode *current = frontier[0]; - - if (current->type == targetType && - strcmp(current->name, targetName) == 0) - { - result = current; - break; - } - - uint32_t i; - for (i = 1; i < frontierCount; i++) - { - frontier[i - 1] = frontier[i]; - } - size_t newSize = frontierCount + current->childCount - 1; - if (frontierCount != newSize) - { - frontier = (IdNode **)realloc(frontier, sizeof(IdNode *) * newSize); - } - for (i = 0; i < current->childCount; i++) - { - frontier[frontierCount + i - 1] = current->children[i]; - } - frontierCount = newSize; - } - - free(frontier); - return result; -} - -bool ScopeHasOrdering(IdNode *node) -{ - switch (node->type) - { - case OrderedScope: - case Function: - case Variable: /* this is only technically true */ - return true; - default: - return false; - } -} - -IdNode *LookupId(IdNode *node, IdNode *prev, char *target) -{ - if (node == NULL) - { - return NULL; - } - - if (strcmp(node->name, target) == 0 && node->type != Placeholder) - { - return node; - } - - /* If this is the start of our search, we should not attempt to look at - * child nodes. Only looking up the scope tree is valid at this point. - * - * This has the notable side-effect that this function will return NULL if - * you attempt to look up a struct's internals starting from the node - * representing the struct itself. This is because an IdNode corresponds to - * the location *where an identifier is first declared.* Thus, an identifier - * has no knowledge of identifiers declared "inside" of it. - */ - if (prev == NULL) - { - return LookupId(node->parent, node, target); - } - - /* If the current node forms an ordered scope then we want to prevent - * ourselves from looking up identifiers declared after the scope we have - * just come from. - */ - uint32_t idxLimit; - if (ScopeHasOrdering(node)) - { - uint32_t i; - for (i = 0, idxLimit = 0; i < node->childCount; i++, idxLimit++) - { - if (node->children[i] == prev) - { - break; - } - } - } - else - { - idxLimit = node->childCount; - } - - uint32_t i; - for (i = 0; i < idxLimit; i++) - { - IdNode *child = node->children[i]; - if (child == prev || child->type == Placeholder) - { - /* Do not inspect the node we just came from or placeholders. */ - continue; - } - - if (strcmp(child->name, target) == 0) - { - return child; - } - - if (child->type == Struct) - { - uint32_t j; - for (j = 0; j < child->childCount; j++) - { - IdNode *grandchild = child->children[j]; - if (strcmp(grandchild->name, target) == 0) - { - return grandchild; - } - } - } - } - - return LookupId(node->parent, node, target); -} diff --git a/src/identcheck.h b/src/identcheck.h deleted file mode 100644 index c0ccca6..0000000 --- a/src/identcheck.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Validates identifier usage in an AST. */ - -#ifndef WRAITH_IDENTCHECK_H -#define WRAITH_IDENTCHECK_H - -#include - -#include "ast.h" - -struct TypeTag; -struct Node; - -typedef enum NodeType -{ - Placeholder, - UnorderedScope, - OrderedScope, - Struct, - Function, - Variable -} NodeType; - -typedef struct IdNode -{ - NodeType type; - char *name; - struct TypeTag *typeTag; - struct IdNode *parent; - struct IdNode **children; - uint32_t childCount; - uint32_t childCapacity; -} IdNode; - -typedef struct IdStatus -{ - enum StatusCode - { - Valid, - } StatusCode; -} IdStatus; - -IdNode *MakeIdTree(struct Node *astNode, IdNode *parent); -void PrintIdNode(IdNode *node); -void PrintIdTree(IdNode *tree, uint32_t tabCount); -int PrintAncestors(IdNode *node); -IdNode *LookdownId(IdNode *root, NodeType targetType, char *targetName); -IdNode *LookupId(IdNode *node, IdNode *prev, char *target); - -#endif /* WRAITH_IDENTCHECK_H */ diff --git a/src/main.c b/src/main.c index eda6895..f3f8d36 100644 --- a/src/main.c +++ b/src/main.c @@ -2,8 +2,8 @@ #include #include "codegen.h" -#include "identcheck.h" #include "parser.h" +#include "validation.h" int main(int argc, char *argv[]) { @@ -85,12 +85,15 @@ int main(int argc, char *argv[]) } else { - { - IdNode *idTree = MakeIdTree(rootNode, NULL); - PrintIdTree(idTree, /*tabCount=*/0); - printf("\n"); - PrintNode(rootNode, /*tabCount=*/0); - } + LinkParentPointers(rootNode, NULL); + /* FIXME: ValidateIdentifiers should return some sort of + error status object. */ + ValidateIdentifiers(rootNode); + TagIdentifierTypes(rootNode); + ConvertCustomsToGenerics(rootNode); + PrintNode(rootNode, 0); + + printf("Beginning codegen.\n"); exitCode = Codegen(rootNode, optimizationLevel); } } diff --git a/src/util.c b/src/util.c index 8001d03..42911e7 100644 --- a/src/util.c +++ b/src/util.c @@ -1,7 +1,6 @@ #include "util.h" #include -#include char *strdup(const char *s) { @@ -15,3 +14,16 @@ char *strdup(const char *s) memcpy(result, s, slen + 1); return result; } + +uint64_t str_hash(char *str) +{ + uint64_t hash = 5381; + size_t c; + + while ((c = *str++)) + { + hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ + } + + return hash; +} diff --git a/src/util.h b/src/util.h index 884fa24..108211b 100644 --- a/src/util.h +++ b/src/util.h @@ -1,8 +1,10 @@ #ifndef WRAITH_UTIL_H #define WRAITH_UTIL_H +#include #include char *strdup(const char *s); +uint64_t str_hash(char *str); #endif /* WRAITH_UTIL_H */ diff --git a/src/validation.c b/src/validation.c new file mode 100644 index 0000000..e62c4dd --- /dev/null +++ b/src/validation.c @@ -0,0 +1,452 @@ +#include "validation.h" + +#include +#include +#include +#include + +Node *GetIdFromStruct(Node *structDecl) +{ + if (structDecl->syntaxKind != StructDeclaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromStruct on node with kind: " + "%s.\n", + SyntaxKindString(structDecl->syntaxKind)); + return NULL; + } + + return structDecl->structDeclaration.identifier; +} + +Node *GetIdFromFunction(Node *funcDecl) +{ + if (funcDecl->syntaxKind != FunctionDeclaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromFunction on node with kind: " + "%s.\n", + SyntaxKindString(funcDecl->syntaxKind)); + return NULL; + } + + Node *sig = funcDecl->functionDeclaration.functionSignature; + return sig->functionSignature.identifier; +} + +Node *GetIdFromDeclaration(Node *decl) +{ + if (decl->syntaxKind != Declaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromDeclaration on node with kind: " + "%s.\n", + SyntaxKindString(decl->syntaxKind)); + } + + return decl->declaration.identifier; +} + +bool AssignmentHasDeclaration(Node *assign) +{ + return ( + assign->syntaxKind == Assignment && + assign->assignmentStatement.left->syntaxKind == Declaration); +} + +Node *GetIdFromAssignment(Node *assign) +{ + if (assign->syntaxKind != Assignment) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromAssignment on node with kind: " + "%s.\n", + SyntaxKindString(assign->syntaxKind)); + } + + if (AssignmentHasDeclaration(assign)) + { + return GetIdFromDeclaration(assign->assignmentStatement.left); + } + + return NULL; +} + +bool NodeMayHaveId(Node *node) +{ + switch (node->syntaxKind) + { + case StructDeclaration: + case FunctionDeclaration: + case Declaration: + case Assignment: + return true; + default: + return false; + } +} + +Node *TryGetId(Node *node) +{ + switch (node->syntaxKind) + { + case Assignment: + return GetIdFromAssignment(node); + case Declaration: + return GetIdFromDeclaration(node); + case FunctionDeclaration: + return GetIdFromFunction(node); + case StructDeclaration: + return GetIdFromStruct(node); + default: + return NULL; + } +} + +Node *LookupFunctionArgId(Node *funcDecl, char *target) +{ + Node *args = funcDecl->functionDeclaration.functionSignature + ->functionSignature.arguments; + + uint32_t i; + for (i = 0; i < args->functionArgumentSequence.count; i += 1) + { + Node *arg = args->functionArgumentSequence.sequence[i]; + if (arg->syntaxKind != Declaration) + { + fprintf( + stderr, + "wraith: Encountered %s node in function signature args " + "list.\n", + SyntaxKindString(arg->syntaxKind)); + continue; + } + + Node *argId = GetIdFromDeclaration(arg); + if (argId != NULL && strcmp(target, argId->identifier.name) == 0) + return argId; + } + + return NULL; +} + +Node *LookupStructInternalId(Node *structDecl, char *target) +{ + Node *decls = structDecl->structDeclaration.declarationSequence; + + uint32_t i; + for (i = 0; i < decls->declarationSequence.count; i += 1) + { + Node *match = TryGetId(decls->declarationSequence.sequence[i]); + if (match != NULL && strcmp(target, match->identifier.name) == 0) + return match; + } + + return NULL; +} + +Node *InspectNode(Node *node, char *target) +{ + /* If this node may have an identifier declaration inside it, attempt to + * look up the identifier + * node itself, returning it if it matches the given target name. */ + if (NodeMayHaveId(node)) + { + Node *candidateId = TryGetId(node); + if (candidateId != NULL && + strcmp(target, candidateId->identifier.name) == 0) + return candidateId; + } + + /* If the candidate node was not the one we wanted, but the node node is a + * function declaration, it's possible that the identifier we want is one of + * the function's parameters rather than the function's name itself. */ + if (node->syntaxKind == FunctionDeclaration) + { + Node *match = LookupFunctionArgId(node, target); + if (match != NULL) + return match; + } + + /* Likewise if the node node is a struct declaration, inspect the struct's + * internals + * to see if a top-level definition is the one we're looking for. */ + if (node->syntaxKind == StructDeclaration) + { + Node *match = LookupStructInternalId(node, target); + if (match != NULL) + return match; + } + + return NULL; +} + +/* FIXME: Handle staged lookups for AccessExpressions. */ +/* FIXME: Similar to above, disallow inspection of struct internals outside of + * AccessExpressions. */ +Node *LookupId(Node *current, Node *prev, char *target) +{ + if (current == NULL) + return NULL; + + Node *match; + + /* First inspect the current node to see if it contains the target + * identifier. */ + match = InspectNode(current, target); + if (match != NULL) + return match; + + /* If this is the start of our search, we should not attempt to look at + * child nodes. Only looking up the AST is valid at this point. + * + * This has the notable side-effect that this function will return NULL if + * you attempt to look up a struct's internals starting from the node + * representing the struct itself. The same is true for functions. */ + if (prev == NULL) + return LookupId(current->parent, current, target); + + uint32_t i; + uint32_t idxLimit; + switch (current->syntaxKind) + { + case DeclarationSequence: + for (i = 0; i < current->declarationSequence.count; i += 1) + { + Node *decl = current->declarationSequence.sequence[i]; + match = InspectNode(decl, target); + if (match != NULL) + return match; + } + break; + case StatementSequence: + idxLimit = current->statementSequence.count; + for (i = 0; i < current->statementSequence.count; i += 1) + { + if (current->statementSequence.sequence[i] == prev) + { + idxLimit = i; + break; + } + } + + for (i = 0; i < idxLimit; i += 1) + { + Node *stmt = current->statementSequence.sequence[i]; + if (stmt == prev) + break; + + match = InspectNode(stmt, target); + if (match != NULL) + return match; + } + break; + } + + return LookupId(current->parent, current, target); +} + +/* FIXME: This function should be extended to handle multi-stage ID lookups for + * AccessExpression nodes. */ +/* FIXME: Make this function return an error status object of some kind. + * A non-OK status should halt compilation. */ +void ValidateIdentifiers(Node *node) +{ + if (node == NULL) + return; + + /* Skip over generic arguments. They contain Identifiers but are not + * actually identifiers, they declare types. */ + if (node->syntaxKind == GenericArguments) + return; + + if (node->syntaxKind != Identifier) + { + Recurse(node, *ValidateIdentifiers); + return; + } + + char *name = node->identifier.name; + Node *decl = LookupId(node, NULL, name); + if (decl == NULL) + { + /* FIXME: Express this case as an error with AST information, see the + * FIXME comment above. */ + fprintf( + stderr, + "wraith: Could not find definition of identifier %s.\n", + name); + } +} + +/* FIXME: This function should be extended to handle multi-stage ID lookups for + * AccessExpression nodes. */ +void TagIdentifierTypes(Node *node) +{ + if (node == NULL) + return; + + switch (node->syntaxKind) + { + case AllocExpression: + node->typeTag = MakeTypeTag(node); + break; + + case Declaration: + node->declaration.identifier->typeTag = MakeTypeTag(node); + break; + + case FunctionDeclaration: + node->functionDeclaration.functionSignature->functionSignature + .identifier->typeTag = MakeTypeTag(node); + break; + + case StructDeclaration: + node->structDeclaration.identifier->typeTag = MakeTypeTag(node); + break; + + case GenericArgument: + node->genericArgument.identifier->typeTag = MakeTypeTag(node); + break; + + case Identifier: + { + if (node->typeTag != NULL) + return; + + char *name = node->identifier.name; + Node *declaration = LookupId(node, NULL, name); + /* FIXME: Remove this case once ValidateIdentifiers returns error status + * info and halts compilation. See ValidateIdentifiers FIXME. */ + if (declaration == NULL) + { + TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); + tag->type = Unknown; + node->typeTag = tag; + } + else + { + node->typeTag = declaration->typeTag; + } + break; + } + } + + Recurse(node, *TagIdentifierTypes); +} + +Node *LookupType(Node *current, char *target) +{ + if (current == NULL) + return NULL; + + switch (current->syntaxKind) + { + /* If we've encountered a function declaration, check to see if it's generic + * and, if so, if one of its type parameters is the target. */ + case FunctionDeclaration: + { + Node *typeArgs = current->functionDeclaration.functionSignature + ->functionSignature.genericArguments; + uint32_t i; + for (i = 0; i < typeArgs->genericArguments.count; i += 1) + { + Node *arg = typeArgs->genericArguments.arguments[i]; + Node *argId = arg->genericArgument.identifier; + char *argName = argId->identifier.name; + /* note: return the GenericArgument, not the Identifier, so that + * the caller can differentiate between generics and customs. */ + if (strcmp(target, argName) == 0) + return arg; + } + + return LookupType(current->parent, target); + } + + case StructDeclaration: + { + Node *structId = GetIdFromStruct(current); + if (strcmp(target, structId->identifier.name) == 0) + return structId; + + return LookupType(current->parent, target); + } + + /* If we encounter a declaration sequence, search each of its children for + * struct definitions in case one of them is the target. */ + case DeclarationSequence: + { + uint32_t i; + for (i = 0; i < current->declarationSequence.count; i += 1) + { + Node *decl = current->declarationSequence.sequence[i]; + if (decl->syntaxKind == StructDeclaration) + { + Node *structId = GetIdFromStruct(decl); + if (strcmp(target, structId->identifier.name) == 0) + return structId; + } + } + + return LookupType(current->parent, target); + } + + default: + return LookupType(current->parent, target); + } +} + +/* FIXME: This function should be modified to handle type parameters over + * structs. */ +void ConvertCustomsToGenerics(Node *node) +{ + if (node == NULL) + return; + + switch (node->syntaxKind) + { + case Declaration: + { + Node *id = node->declaration.identifier; + Node *type = node->declaration.type->type.typeNode; + if (type->syntaxKind == CustomTypeNode) + { + char *target = id->typeTag->value.customType; + Node *typeLookup = LookupType(node, target); + if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + { + id->typeTag->type = Generic; + free(node->declaration.type); + node->declaration.type = + MakeGenericTypeNode(id->typeTag->value.genericType); + } + } + break; + } + + case FunctionSignature: + { + Node *id = node->functionSignature.identifier; + Node *type = node->functionSignature.type->type.typeNode; + if (type->syntaxKind == CustomTypeNode) + { + char *target = id->typeTag->value.customType; + Node *typeLookup = LookupType(node, target); + if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + { + id->typeTag->type = Generic; + free(node->functionSignature.type); + node->functionSignature.type = + MakeGenericTypeNode(id->typeTag->value.genericType); + } + } + break; + } + } + + Recurse(node, *ConvertCustomsToGenerics); +} diff --git a/src/validation.h b/src/validation.h new file mode 100644 index 0000000..50c5432 --- /dev/null +++ b/src/validation.h @@ -0,0 +1,10 @@ +#ifndef WRAITH_VALIDATION_H +#define WRAITH_VALIDATION_H + +#include "ast.h" + +void ValidateIdentifiers(Node *node); +void TagIdentifierTypes(Node *node); +void ConvertCustomsToGenerics(Node *node); + +#endif /* WRAITH_VALIDATION_H */