From 24bcef6d87f6c4fe33c3c8e726e6a1a4509aeeaa Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 19 May 2021 15:45:07 -0700 Subject: [PATCH 01/17] initial generics stuff --- generators/wraith.y | 32 ++++++- src/ast.c | 44 ++++++++- src/ast.h | 22 ++++- src/codegen.c | 229 ++++++++++++++++++++++++++++---------------- 4 files changed, 240 insertions(+), 87 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index 7aaf70a..3176d60 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE $$ = $2; } -FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type +GenericArgument : Identifier { - $$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0)); + $$ = MakeGenericArgumentNode($1, NULL); } - | STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + +GenericArguments : GenericArgument + { + $$ = StartGenericArgumentsNode($1); + } + | GenericArguments COMMA GenericArgument + { + $$ = AddGenericArgument($1, $3); + } + +GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN + { + $$ = $2; + } + | + { + $$ = MakeEmptyGenericArgumentsNode(); + } + + +FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + { + $$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2); + } + | STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type { Node *modifier = MakeStaticNode(); - $$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1)); + $$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3); } FunctionDeclaration : FunctionSignature Body diff --git a/src/ast.c b/src/ast.c index 74ee4d9..4c578e9 100644 --- a/src/ast.c +++ b/src/ast.c @@ -271,7 +271,8 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *arguments, - Node *modifiersNode) + Node *modifiersNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = FunctionSignature; @@ -279,6 +280,7 @@ Node *MakeFunctionSignatureNode( node->functionSignature.type = typeNode; node->functionSignature.arguments = arguments; node->functionSignature.modifiers = modifiersNode; + node->functionSignature.genericArguments = genericArgumentsNode; return node; } @@ -359,6 +361,46 @@ Node *MakeEmptyFunctionArgumentSequenceNode() return node; } +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArgument; + node->genericArgument.identifier = identifierNode; + node->genericArgument.constraint = constraintNode; + return node; +} + +Node *StartGenericArgumentsNode(Node *genericArgumentNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = (Node **)malloc(sizeof(Node *)); + node->genericArguments.arguments[0] = genericArgumentNode; + node->genericArguments.count = 1; + return node; +} + +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode) +{ + genericArgumentsNode->genericArguments.arguments = (Node **)realloc( + genericArgumentsNode->genericArguments.arguments, + sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1)); + genericArgumentsNode->genericArguments + .arguments[genericArgumentsNode->genericArguments.count] = + genericArgumentNode; + genericArgumentsNode->genericArguments.count += 1; + return genericArgumentsNode; +} + +Node *MakeEmptyGenericArgumentsNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = NULL; + node->genericArguments.count = 0; + return node; +} + Node *MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode) diff --git a/src/ast.h b/src/ast.h index 60e954d..a4b367c 100644 --- a/src/ast.h +++ b/src/ast.h @@ -30,6 +30,8 @@ typedef enum FunctionModifiers, FunctionSignature, FunctionSignatureArguments, + GenericArgument, + GenericArguments, Identifier, IfStatement, IfElseStatement, @@ -192,6 +194,7 @@ struct Node Node *type; Node *arguments; Node *modifiers; + Node *genericArguments; } functionSignature; struct @@ -200,6 +203,18 @@ struct Node uint32_t count; } functionSignatureArguments; + struct + { + Node *identifier; + Node *constraint; + } genericArgument; + + struct + { + Node **arguments; + uint32_t count; + } genericArguments; + struct { char *name; @@ -306,10 +321,15 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *argumentsNode, - Node *modifiersNode); + Node *modifiersNode, + Node *genericArgumentsNode); Node *MakeFunctionDeclarationNode( Node *functionSignatureNode, Node *functionBodyNode); +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode); +Node *MakeEmptyGenericArgumentsNode(); +Node *StartGenericArgumentsNode(Node *genericArgumentNode); +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode); Node *MakeStructDeclarationNode( Node *identifierNode, Node *declarationSequenceNode); diff --git a/src/codegen.c b/src/codegen.c index e8a5fb9..c24eb2e 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -56,6 +56,24 @@ typedef struct StructTypeFunction uint8_t isStatic; } StructTypeFunction; +typedef struct StructTypeGenericFunction +{ + char *name; + Node *functionDeclarationNode; +} StructTypeGenericFunction; + +typedef struct MonomorphizedGenericFunctionHashEntry +{ + uint64_t key; + StructTypeFunction function; +} MonomorphizedGenericFunctionHashEntry; + +typedef struct MonomorphizedGenericFunctionHashArray +{ + MonomorphizedGenericFunctionHashEntry *elements; + uint32_t count; +} MonomorphizedGenericFunctionHashArray; + typedef struct StructTypeDeclaration { char *name; @@ -66,6 +84,11 @@ typedef struct StructTypeDeclaration StructTypeFunction *functions; uint32_t functionCount; + + StructTypeGenericFunction *genericFunctions; + uint32_t genericFunctionCount; + + MonomorphizedGenericFunctionHashArray monomorphizedGenericFunctions; } StructTypeDeclaration; StructTypeDeclaration *structTypeDeclarations; @@ -271,6 +294,10 @@ static void AddStructDeclaration( structTypeDeclarations[index].fieldCount = 0; structTypeDeclarations[index].functions = NULL; structTypeDeclarations[index].functionCount = 0; + structTypeDeclarations[index].genericFunctions = NULL; + structTypeDeclarations[index].genericFunctionCount = 0; + structTypeDeclarations[index].monomorphizedGenericFunctions.elements = NULL; + structTypeDeclarations[index].monomorphizedGenericFunctions.count = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { @@ -287,6 +314,7 @@ static void AddStructDeclaration( structTypeDeclarationCount += 1; } +/* FIXME: pass the declaration itself */ static void DeclareStructFunction( LLVMTypeRef wStructPointerType, LLVMValueRef function, @@ -318,6 +346,31 @@ static void DeclareStructFunction( fprintf(stderr, "Could not find struct type for function!\n"); } +/* FIXME: pass the declaration itself */ +static void DeclareGenericStructFunction( + LLVMTypeRef wStructPointerType, + Node *functionDeclarationNode, + char *name) +{ + uint32_t i, index; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == wStructPointerType) + { + index = structTypeDeclarations[i].genericFunctionCount; + structTypeDeclarations[i].genericFunctions[index].name = + strdup(name); + structTypeDeclarations[i] + .genericFunctions[index] + .functionDeclarationNode = functionDeclarationNode; + structTypeDeclarations[i].genericFunctionCount += 1; + + return; + } + } +} + static LLVMTypeRef LookupCustomType(char *name) { uint32_t i; @@ -1023,101 +1076,115 @@ static void CompileFunction( } } - if (!isStatic) - { - paramTypes[paramIndex] = wStructPointerType; - paramIndex += 1; - } - - PushScopeFrame(scope); - - /* FIXME: should work for non-primitive types */ - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - paramTypes[paramIndex] = - ResolveType(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.type); - paramIndex += 1; - } - - LLVMTypeRef returnType = - ResolveType(functionSignature->functionSignature.type); - LLVMTypeRef functionType = - LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - char *functionName = strdup(parentStructName); strcat(functionName, "_"); strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); - free(functionName); - DeclareStructFunction( - wStructPointerType, - function, - returnType, - isStatic, - functionSignature->functionSignature.identifier->identifier.name); - - LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); - LLVMBuilderRef builder = LLVMCreateBuilder(); - LLVMPositionBuilderAtEnd(builder, entry); - - if (!isStatic) + if (functionSignature->functionSignature.genericArguments->genericArguments + .count == 0) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); - } + PushScopeFrame(scope); - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - char *ptrName = strdup(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); - LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); - LLVMValueRef argumentCopy = - LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); - LLVMBuildStore(builder, argument, argumentCopy); - free(ptrName); - AddLocalVariable( - scope, - argumentCopy, - NULL, - functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - } + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } - for (i = 0; i < functionBody->statementSequence.count; i += 1) - { - CompileStatement( - builder, + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.type); + paramIndex += 1; + } + + LLVMTypeRef returnType = + ResolveType(functionSignature->functionSignature.type); + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = + LLVMAddFunction(module, functionName, functionType); + + DeclareStructFunction( + wStructPointerType, function, - functionBody->statementSequence.sequence[i]); + returnType, + isStatic, + functionSignature->functionSignature.identifier->identifier.name); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = + strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = LLVMGetBasicBlockTerminator( + LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + + PopScopeFrame(scope); } - - hasReturn = - LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; - - if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + else { - LLVMBuildRetVoid(builder); - } - else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) - { - fprintf(stderr, "Return statement not provided!"); + DeclareGenericStructFunction( + wStructPointerType, + functionDeclaration, + functionName); } - PopScopeFrame(scope); - - LLVMDisposeBuilder(builder); + free(functionName); } static void CompileStruct( -- 2.25.1 From 0d94e89045b62b9e3d35860ce2f0cff66d70ef78 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 19 May 2021 18:09:33 -0700 Subject: [PATCH 02/17] skeleton of generic function lookup --- src/codegen.c | 76 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/src/codegen.c b/src/codegen.c index c24eb2e..bcc3eb4 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -411,9 +411,22 @@ static LLVMTypeRef ResolveType(Node *typeNode) } } +static LLVMValueRef LookupGenericFunction( + StructTypeGenericFunction *genericFunction, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, + LLVMTypeRef *pReturnType, + uint8_t *pStatic) +{ + /* TODO: hash the argument types */ + /* TODO: compile the monomorphism if doesnt exist */ +} + static LLVMValueRef LookupFunctionByType( LLVMTypeRef structType, char *name, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -434,6 +447,22 @@ static LLVMValueRef LookupFunctionByType( return structTypeDeclarations[i].functions[j].function; } } + + for (j = 0; j < structTypeDeclarations[i].genericFunctionCount; + j += 1) + { + if (strcmp( + structTypeDeclarations[i].genericFunctions[j].name, + name) == 0) + { + return LookupGenericFunction( + &structTypeDeclarations[i].genericFunctions[j], + genericArgumentTypes, + genericArgumentTypeCount, + pReturnType, + pStatic); + } + } } } @@ -444,6 +473,8 @@ static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByPointerType( LLVMTypeRef structPointerType, char *name, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -464,6 +495,22 @@ static LLVMValueRef LookupFunctionByPointerType( return structTypeDeclarations[i].functions[j].function; } } + + for (j = 0; j < structTypeDeclarations[i].genericFunctionCount; + j += 1) + { + if (strcmp( + structTypeDeclarations[i].genericFunctions[j].name, + name) == 0) + { + return LookupGenericFunction( + &structTypeDeclarations[i].genericFunctions[j], + genericArgumentTypes, + genericArgumentTypeCount, + pReturnType, + pStatic); + } + } } } @@ -474,12 +521,16 @@ static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByInstance( LLVMValueRef structPointer, char *name, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( LLVMTypeOf(structPointer), name, + genericArgumentTypes, + genericArgumentTypeCount, pReturnType, pStatic); } @@ -591,16 +642,37 @@ static LLVMValueRef CompileFunctionCallExpression( { uint32_t i; uint32_t argumentCount = 0; + uint32_t genericArgumentCount = 0; LLVMValueRef args [functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.count + 1]; + TypeTag *genericArgumentTypes[functionCallExpression->functionCallExpression + .argumentSequence + ->functionArgumentSequence.count]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; LLVMTypeRef functionReturnType; char *returnName = ""; + for (i = 0; i < functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count; + i += 1) + { + if (functionCallExpression->functionCallExpression.argumentSequence + ->functionArgumentSequence.sequence[i] + ->syntaxKind == GenericArgument) + { + genericArgumentTypes[genericArgumentCount] = + functionCallExpression->functionCallExpression.argumentSequence + ->functionArgumentSequence.sequence[i] + ->typeTag; + + genericArgumentCount += 1; + } + } + /* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be able to call same-struct functions implicitly */ if (functionCallExpression->functionCallExpression.identifier->syntaxKind == @@ -616,6 +688,8 @@ static LLVMValueRef CompileFunctionCallExpression( typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, + genericArgumentTypes, + genericArgumentCount, &functionReturnType, &isStatic); } @@ -628,6 +702,8 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, + genericArgumentTypes, + genericArgumentCount, &functionReturnType, &isStatic); } -- 2.25.1 From 8a3920918c2315b428922b550363a6003f1fa057 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 20 May 2021 13:18:57 -0700 Subject: [PATCH 03/17] generic function lookup --- CMakeLists.txt | 2 + src/codegen.c | 99 +++++++++++++++++++++++++++++++++++++++++++------- src/util.c | 14 ++++++- src/util.h | 2 + 4 files changed, 103 insertions(+), 14 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3cbf9e..9898440 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,10 +43,12 @@ add_executable( src/codegen.h src/identcheck.h src/parser.h + src/util.h src/ast.c src/codegen.c src/identcheck.c src/parser.c + src/util.c src/main.c # Generated code ${BISON_Parser_OUTPUTS} diff --git a/src/codegen.c b/src/codegen.c index bcc3eb4..f94e192 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -14,6 +14,7 @@ #include #include "ast.h" +#include "util.h" typedef struct LocalVariable { @@ -56,15 +57,11 @@ typedef struct StructTypeFunction uint8_t isStatic; } StructTypeFunction; -typedef struct StructTypeGenericFunction -{ - char *name; - Node *functionDeclarationNode; -} StructTypeGenericFunction; - typedef struct MonomorphizedGenericFunctionHashEntry { uint64_t key; + TypeTag **types; + uint32_t typeCount; StructTypeFunction function; } MonomorphizedGenericFunctionHashEntry; @@ -74,6 +71,17 @@ typedef struct MonomorphizedGenericFunctionHashArray uint32_t count; } MonomorphizedGenericFunctionHashArray; +#define NUM_MONOMORPHIZED_HASH_BUCKETS 1031 + +typedef struct StructTypeGenericFunction +{ + char *name; + Node *functionDeclarationNode; + uint8_t isStatic; + MonomorphizedGenericFunctionHashArray + monomorphizedFunctions[NUM_MONOMORPHIZED_HASH_BUCKETS]; +} StructTypeGenericFunction; + typedef struct StructTypeDeclaration { char *name; @@ -87,8 +95,6 @@ typedef struct StructTypeDeclaration StructTypeGenericFunction *genericFunctions; uint32_t genericFunctionCount; - - MonomorphizedGenericFunctionHashArray monomorphizedGenericFunctions; } StructTypeDeclaration; StructTypeDeclaration *structTypeDeclarations; @@ -296,8 +302,6 @@ static void AddStructDeclaration( structTypeDeclarations[index].functionCount = 0; structTypeDeclarations[index].genericFunctions = NULL; structTypeDeclarations[index].genericFunctionCount = 0; - structTypeDeclarations[index].monomorphizedGenericFunctions.elements = NULL; - structTypeDeclarations[index].monomorphizedGenericFunctions.count = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { @@ -350,9 +354,10 @@ static void DeclareStructFunction( static void DeclareGenericStructFunction( LLVMTypeRef wStructPointerType, Node *functionDeclarationNode, + uint8_t isStatic, char *name) { - uint32_t i, index; + uint32_t i, j, index; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -364,6 +369,21 @@ static void DeclareGenericStructFunction( structTypeDeclarations[i] .genericFunctions[index] .functionDeclarationNode = functionDeclarationNode; + structTypeDeclarations[i].genericFunctions[index].isStatic = + isStatic; + + for (j = 0; j < NUM_MONOMORPHIZED_HASH_BUCKETS; j += 1) + { + structTypeDeclarations[i] + .genericFunctions[index] + .monomorphizedFunctions[j] + .elements = NULL; + structTypeDeclarations[i] + .genericFunctions[index] + .monomorphizedFunctions[j] + .count = 0; + } + structTypeDeclarations[i].genericFunctionCount += 1; return; @@ -411,6 +431,20 @@ static LLVMTypeRef ResolveType(Node *typeNode) } } +static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) +{ + const uint64_t HASH_FACTOR = 97; + uint64_t result = 1; + uint32_t i; + + for (i = 0; i < count; i += 1) + { + result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + } + + return result; +} + static LLVMValueRef LookupGenericFunction( StructTypeGenericFunction *genericFunction, TypeTag **genericArgumentTypes, @@ -418,8 +452,46 @@ static LLVMValueRef LookupGenericFunction( LLVMTypeRef *pReturnType, uint8_t *pStatic) { - /* TODO: hash the argument types */ - /* TODO: compile the monomorphism if doesnt exist */ + uint32_t i, j; + uint64_t typeHash = + HashTypeTags(genericArgumentTypes, genericArgumentTypeCount); + uint8_t match = 0; + + MonomorphizedGenericFunctionHashArray *hashArray = + &genericFunction->monomorphizedFunctions + [typeHash % NUM_MONOMORPHIZED_HASH_BUCKETS]; + + MonomorphizedGenericFunctionHashEntry *hashEntry = NULL; + for (i = 0; i < hashArray->count; i += 1) + { + match = 1; + + for (j = 0; j < hashArray->elements[i].typeCount; j += 1) + { + if (hashArray->elements[i].types[j] != genericArgumentTypes[j]) + { + match = 0; + break; + } + } + + if (match) + { + hashEntry = &hashArray->elements[i]; + break; + } + } + + if (hashEntry == NULL) + { + + /* TODO: compile */ + } + + *pReturnType = hashEntry->function.returnType; + *pStatic = hashEntry->function.isStatic; + + return hashEntry->function.function; } static LLVMValueRef LookupFunctionByType( @@ -1257,6 +1329,7 @@ static void CompileFunction( DeclareGenericStructFunction( wStructPointerType, functionDeclaration, + isStatic, functionName); } diff --git a/src/util.c b/src/util.c index 8001d03..aa5b4fb 100644 --- a/src/util.c +++ b/src/util.c @@ -1,7 +1,6 @@ #include "util.h" #include -#include char *strdup(const char *s) { @@ -15,3 +14,16 @@ char *strdup(const char *s) memcpy(result, s, slen + 1); return result; } + +uint64_t str_hash(char *str) +{ + uint64_t hash = 5381; + size_t c; + + while (c = *str++) + { + hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ + } + + return hash; +} diff --git a/src/util.h b/src/util.h index 884fa24..108211b 100644 --- a/src/util.h +++ b/src/util.h @@ -1,8 +1,10 @@ #ifndef WRAITH_UTIL_H #define WRAITH_UTIL_H +#include #include char *strdup(const char *s); +uint64_t str_hash(char *str); #endif /* WRAITH_UTIL_H */ -- 2.25.1 From d641f713de7a9980d67b99e5334360c0f25967cb Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Fri, 21 May 2021 19:52:13 -0700 Subject: [PATCH 04/17] progress on generics --- src/ast.c | 4 + src/codegen.c | 526 +++++++++++++++++++++++++++++++++++------------ src/identcheck.c | 28 ++- src/identcheck.h | 4 +- src/util.c | 2 +- 5 files changed, 429 insertions(+), 135 deletions(-) diff --git a/src/ast.c b/src/ast.c index 4c578e9..dde4693 100644 --- a/src/ast.c +++ b/src/ast.c @@ -740,6 +740,10 @@ TypeTag *MakeTypeTag(Node *node) ->functionSignature.type); break; + case AllocExpression: + tag = MakeTypeTag(node->allocExpression.type); + break; + default: fprintf( stderr, diff --git a/src/codegen.c b/src/codegen.c index f94e192..8f82dff 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -23,6 +23,12 @@ typedef struct LocalVariable LLVMValueRef value; } LocalVariable; +typedef struct LocalGenericType +{ + char *name; + LLVMTypeRef type; +} LocalGenericType; + typedef struct FunctionArgument { char *name; @@ -33,6 +39,9 @@ typedef struct ScopeFrame { LocalVariable *localVariables; uint32_t localVariableCount; + + LocalGenericType *genericTypes; + uint32_t genericTypeCount; } ScopeFrame; typedef struct Scope @@ -75,6 +84,8 @@ typedef struct MonomorphizedGenericFunctionHashArray typedef struct StructTypeGenericFunction { + char *parentStructName; + LLVMTypeRef parentStructPointerType; char *name; Node *functionDeclarationNode; uint8_t isStatic; @@ -100,6 +111,18 @@ typedef struct StructTypeDeclaration StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; +/* FUNCTION FORWARD DECLARATIONS */ +static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, + LLVMBuilderRef builder, + LLVMValueRef function, + Node *statement); + +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -107,6 +130,8 @@ static Scope *CreateScope() scope->scopeStack = malloc(sizeof(ScopeFrame)); scope->scopeStack[0].localVariableCount = 0; scope->scopeStack[0].localVariables = NULL; + scope->scopeStack[0].genericTypeCount = 0; + scope->scopeStack[0].genericTypes = NULL; scope->scopeStackCount = 1; return scope; @@ -120,6 +145,8 @@ static void PushScopeFrame(Scope *scope) sizeof(ScopeFrame) * (scope->scopeStackCount + 1)); scope->scopeStack[index].localVariableCount = 0; scope->scopeStack[index].localVariables = NULL; + scope->scopeStack[index].genericTypeCount = 0; + scope->scopeStack[index].genericTypes = NULL; scope->scopeStackCount += 1; } @@ -138,31 +165,21 @@ static void PopScopeFrame(Scope *scope) free(scope->scopeStack[index].localVariables); } + if (scope->scopeStack[index].genericTypes != NULL) + { + for (i = 0; i < scope->scopeStack[index].genericTypeCount; i += 1) + { + free(scope->scopeStack[index].genericTypes[i].name); + } + free(scope->scopeStack[index].localVariables); + } + scope->scopeStackCount -= 1; scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -184,6 +201,120 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) return NULL; } +static LLVMTypeRef LookupCustomType(char *name) +{ + int32_t i, j; + + for (i = scope->scopeStackCount - 1; i >= 0; i -= 1) + { + for (j = 0; j < scope->scopeStack[i].genericTypeCount; j += 1) + { + if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) + { + return scope->scopeStack[i].genericTypes[j].type; + } + } + } + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (strcmp(structTypeDeclarations[i].name, name) == 0) + { + return structTypeDeclarations[i].structType; + } + } + + fprintf(stderr, "Could not find struct type!\n"); + return NULL; +} + +static LLVMTypeRef ResolveType(TypeTag *typeTag) +{ + if (typeTag->type == Primitive) + { + return WraithTypeToLLVMType(typeTag->value.primitiveType); + } + else if (typeTag->type == Custom) + { + return LookupCustomType(typeTag->value.customType); + } + else if (typeTag->type == Reference) + { + return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); + } + else + { + fprintf(stderr, "Unknown type node!\n"); + return NULL; + } +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i, j; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == + LLVMTypeOf(structPointer)) + { + for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) + { + char *ptrName = + strdup(structTypeDeclarations[i].fields[j].name); + strcat(ptrName, "_ptr"); + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclarations[i].fields[j].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclarations[i].fields[j].name); + } + } + } +} + static LLVMTypeRef FindStructType(char *name) { uint32_t i; @@ -355,6 +486,7 @@ static void DeclareGenericStructFunction( LLVMTypeRef wStructPointerType, Node *functionDeclarationNode, uint8_t isStatic, + char *parentStructName, char *name) { uint32_t i, j, index; @@ -364,8 +496,15 @@ static void DeclareGenericStructFunction( if (structTypeDeclarations[i].structPointerType == wStructPointerType) { index = structTypeDeclarations[i].genericFunctionCount; + structTypeDeclarations[i].genericFunctions = realloc( + structTypeDeclarations[i].genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclarations[i].genericFunctionCount + 1)); structTypeDeclarations[i].genericFunctions[index].name = strdup(name); + structTypeDeclarations[i].genericFunctions[index].parentStructName = + parentStructName; + structTypeDeclarations[i].structPointerType = wStructPointerType; structTypeDeclarations[i] .genericFunctions[index] .functionDeclarationNode = functionDeclarationNode; @@ -391,46 +530,6 @@ static void DeclareGenericStructFunction( } } -static LLVMTypeRef LookupCustomType(char *name) -{ - uint32_t i; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (strcmp(structTypeDeclarations[i].name, name) == 0) - { - return structTypeDeclarations[i].structType; - } - } - - fprintf(stderr, "Could not find struct type!\n"); - return NULL; -} - -static LLVMTypeRef ResolveType(Node *typeNode) -{ - if (IsPrimitiveType(typeNode)) - { - return WraithTypeToLLVMType( - typeNode->type.typeNode->primitiveType.type); - } - else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode) - { - return LookupCustomType(typeNode->type.typeNode->customType.name); - } - else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode) - { - return LLVMPointerType( - ResolveType(typeNode->type.typeNode->referenceType.type), - 0); - } - else - { - fprintf(stderr, "Unknown type node!\n"); - return NULL; - } -} - static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) { const uint64_t HASH_FACTOR = 97; @@ -445,7 +544,159 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) return result; } +static StructTypeFunction CompileGenericFunction( + LLVMModuleRef module, + char *parentStructName, + LLVMTypeRef wStructPointerType, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, + Node *functionDeclaration) +{ + uint32_t i; + uint8_t hasReturn = 0; + uint8_t isStatic = 0; + Node *functionSignature = + functionDeclaration->functionDeclaration.functionSignature; + Node *functionBody = functionDeclaration->functionDeclaration.functionBody; + uint32_t argumentCount = functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + LLVMTypeRef paramTypes[argumentCount + 1]; + uint32_t paramIndex = 0; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + genericArgumentTypes[i], + functionDeclaration->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name); + } + + if (functionSignature->functionSignature.modifiers->functionModifiers + .count > 0) + { + for (i = 0; i < functionSignature->functionSignature.modifiers + ->functionModifiers.count; + i += 1) + { + if (functionSignature->functionSignature.modifiers + ->functionModifiers.sequence[i] + ->syntaxKind == StaticModifier) + { + isStatic = 1; + break; + } + } + } + + char *functionName = strdup(parentStructName); + strcat(functionName, "_"); + strcat( + functionName, + functionSignature->functionSignature.identifier->identifier.name); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + strcat(functionName, TypeTagToString(genericArgumentTypes[i])); + } + + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->typeTag); + paramIndex += 1; + } + + LLVMTypeRef returnType = + ResolveType(functionSignature->functionSignature.identifier->typeTag); + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + module, + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = + LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + PopScopeFrame(scope); + free(functionName); + + StructTypeFunction structTypeFunction; + structTypeFunction.name = strdup( + functionSignature->functionSignature.identifier->identifier.name); + structTypeFunction.function = function; + structTypeFunction.returnType = returnType; + structTypeFunction.isStatic = isStatic; + + return structTypeFunction; +} + static LLVMValueRef LookupGenericFunction( + LLVMModuleRef module, StructTypeGenericFunction *genericFunction, TypeTag **genericArgumentTypes, uint32_t genericArgumentTypeCount, @@ -484,17 +735,41 @@ static LLVMValueRef LookupGenericFunction( if (hashEntry == NULL) { + StructTypeFunction function = CompileGenericFunction( + module, + genericFunction->parentStructName, + genericFunction->parentStructPointerType, + genericArgumentTypes, + genericArgumentTypeCount, + genericFunction->functionDeclarationNode); - /* TODO: compile */ + /* TODO: add to hash */ + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericFunctionHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * genericArgumentTypeCount); + hashArray->elements[hashArray->count].typeCount = + genericArgumentTypeCount; + hashArray->elements[hashArray->count].function = function; + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + hashArray->elements[hashArray->count].types[i] = + genericArgumentTypes[i]; + } + hashArray->count += 1; } *pReturnType = hashEntry->function.returnType; - *pStatic = hashEntry->function.isStatic; + *pStatic = genericFunction->isStatic; return hashEntry->function.function; } static LLVMValueRef LookupFunctionByType( + LLVMModuleRef module, LLVMTypeRef structType, char *name, TypeTag **genericArgumentTypes, @@ -528,6 +803,7 @@ static LLVMValueRef LookupFunctionByType( name) == 0) { return LookupGenericFunction( + module, &structTypeDeclarations[i].genericFunctions[j], genericArgumentTypes, genericArgumentTypeCount, @@ -543,6 +819,7 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( + LLVMModuleRef module, LLVMTypeRef structPointerType, char *name, TypeTag **genericArgumentTypes, @@ -576,6 +853,7 @@ static LLVMValueRef LookupFunctionByPointerType( name) == 0) { return LookupGenericFunction( + module, &structTypeDeclarations[i].genericFunctions[j], genericArgumentTypes, genericArgumentTypeCount, @@ -591,6 +869,7 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( + LLVMModuleRef module, LLVMValueRef structPointer, char *name, TypeTag **genericArgumentTypes, @@ -599,6 +878,7 @@ static LLVMValueRef LookupFunctionByInstance( uint8_t *pStatic) { return LookupFunctionByPointerType( + module, LLVMTypeOf(structPointer), name, genericArgumentTypes, @@ -607,41 +887,6 @@ static LLVMValueRef LookupFunctionByInstance( pStatic); } -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression); - static LLVMValueRef CompileNumber(Node *numberExpression) { return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0); @@ -658,13 +903,19 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *binaryExpression) { - LLVMValueRef left = - CompileExpression(builder, binaryExpression->binaryExpression.left); - LLVMValueRef right = - CompileExpression(builder, binaryExpression->binaryExpression.right); + LLVMValueRef left = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.left); + + LLVMValueRef right = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.right); switch (binaryExpression->binaryExpression.operator) { @@ -709,6 +960,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -728,6 +980,7 @@ static LLVMValueRef CompileFunctionCallExpression( LLVMTypeRef functionReturnType; char *returnName = ""; + /* FIXME: this is completely wrong and not how we get generic args */ for (i = 0; i < functionCallExpression->functionCallExpression .argumentSequence->functionArgumentSequence.count; i += 1) @@ -739,7 +992,7 @@ static LLVMValueRef CompileFunctionCallExpression( genericArgumentTypes[genericArgumentCount] = functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i] - ->typeTag; + ->declaration.identifier->typeTag; genericArgumentCount += 1; } @@ -757,6 +1010,7 @@ static LLVMValueRef CompileFunctionCallExpression( if (typeReference != NULL) { function = LookupFunctionByType( + module, typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, @@ -771,6 +1025,7 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); function = LookupFunctionByInstance( + module, structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, @@ -797,6 +1052,7 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( + module, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -843,11 +1099,14 @@ static LLVMValueRef CompileAllocExpression( LLVMBuilderRef builder, Node *allocExpression) { - LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type); + LLVMTypeRef type = ResolveType(allocExpression->typeTag); return LLVMBuildMalloc(builder, type, "allocation"); } -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression) { switch (expression->syntaxKind) { @@ -858,10 +1117,10 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(builder, expression); + return CompileBinaryExpression(module, builder, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(builder, expression); + return CompileFunctionCallExpression(module, builder, expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -877,17 +1136,14 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return NULL; } -static LLVMBasicBlockRef CompileStatement( - LLVMBuilderRef builder, - LLVMValueRef function, - Node *statement); - static LLVMBasicBlockRef CompileReturn( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( + module, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -916,7 +1172,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( variable = LLVMBuildAlloca( builder, - ResolveType(variableDeclaration->declaration.type), + ResolveType(variableDeclaration->declaration.identifier->typeTag), ptrName); free(ptrName); @@ -927,11 +1183,13 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( + module, builder, assignmentStatement->assignmentStatement.right); LLVMValueRef identifier; @@ -969,13 +1227,14 @@ static LLVMBasicBlockRef CompileAssignment( } static LLVMBasicBlockRef CompileIfStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; LLVMValueRef conditional = - CompileExpression(builder, ifStatement->ifStatement.expression); + CompileExpression(module, builder, ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -990,6 +1249,7 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( + module, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1003,12 +1263,14 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( + module, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1025,6 +1287,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1043,6 +1306,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1052,6 +1316,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1064,6 +1329,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1076,8 +1342,8 @@ static LLVMBasicBlockRef CompileForLoopStatement( LLVMAppendBasicBlock(function, "afterLoop"); char *iteratorVariableName = forLoopStatement->forLoop.declaration ->declaration.identifier->identifier.name; - LLVMTypeRef iteratorVariableType = - ResolveType(forLoopStatement->forLoop.declaration->declaration.type); + LLVMTypeRef iteratorVariableType = ResolveType( + forLoopStatement->forLoop.declaration->declaration.identifier->typeTag); PushScopeFrame(scope); @@ -1123,6 +1389,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( + module, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1151,6 +1418,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1158,27 +1426,27 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(builder, function, statement); + return CompileAssignment(module, builder, function, statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(builder, function, statement); + return CompileForLoopStatement(module, builder, function, statement); case FunctionCallExpression: - CompileFunctionCallExpression(builder, statement); + CompileFunctionCallExpression(module, builder, statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(builder, function, statement); + return CompileIfStatement(module, builder, function, statement); case IfElseStatement: - return CompileIfElseStatement(builder, function, statement); + return CompileIfElseStatement(module, builder, function, statement); case Return: - return CompileReturn(builder, function, statement); + return CompileReturn(module, builder, function, statement); case ReturnVoid: return CompileReturnVoid(builder, function); @@ -1192,8 +1460,6 @@ static void CompileFunction( LLVMModuleRef module, char *parentStructName, LLVMTypeRef wStructPointerType, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount, Node *functionDeclaration) { uint32_t i; @@ -1248,12 +1514,12 @@ static void CompileFunction( paramTypes[paramIndex] = ResolveType(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] - ->declaration.type); + ->declaration.identifier->typeTag); paramIndex += 1; } - LLVMTypeRef returnType = - ResolveType(functionSignature->functionSignature.type); + LLVMTypeRef returnType = ResolveType( + functionSignature->functionSignature.identifier->typeTag); LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); @@ -1303,6 +1569,7 @@ static void CompileFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( + module, builder, function, functionBody->statementSequence.sequence[i]); @@ -1330,7 +1597,8 @@ static void CompileFunction( wStructPointerType, functionDeclaration, isStatic, - functionName); + parentStructName, + functionSignature->functionSignature.identifier->identifier.name); } free(functionName); @@ -1367,8 +1635,8 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case Declaration: /* this is badly named */ - types[fieldCount] = - ResolveType(currentDeclarationNode->declaration.type); + types[fieldCount] = ResolveType( + currentDeclarationNode->declaration.identifier->typeTag); fieldDeclarations[fieldCount] = currentDeclarationNode; fieldCount += 1; break; @@ -1396,8 +1664,6 @@ static void CompileStruct( module, structName, wStructPointerType, - fieldDeclarations, - fieldCount, currentDeclarationNode); break; } diff --git a/src/identcheck.c b/src/identcheck.c index 571a29e..2d040d0 100644 --- a/src/identcheck.c +++ b/src/identcheck.c @@ -59,9 +59,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) return NULL; case AllocExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->allocExpression.type, parent)); + astNode->typeTag = MakeTypeTag(astNode); return NULL; case Assignment: @@ -154,6 +152,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) idNode->typeTag = mainNode->typeTag; MakeIdTree(sigNode->functionSignature.arguments, mainNode); MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); + MakeIdTree(sigNode->functionSignature.genericArguments, mainNode); break; } @@ -167,6 +166,23 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) return NULL; } + case GenericArgument: + { + char *name = astNode->genericArgument.identifier->identifier.name; + mainNode = MakeIdNode(GenericType, name, parent); + break; + } + + case GenericArguments: + { + for (i = 0; i < astNode->genericArguments.count; i += 1) + { + Node *argNode = astNode->genericArguments.arguments[i]; + AddChildToNode(parent, MakeIdTree(argNode, parent)); + } + return NULL; + } + case Identifier: { char *name = astNode->identifier.name; @@ -302,6 +318,12 @@ void PrintIdNode(IdNode *node) case Variable: printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); break; + case GenericType: + printf("Generic type: %s\n", node->name); + break; + case Alloc: + printf("Alloc: %s\n", TypeTagToString(node->typeTag)); + break; } } diff --git a/src/identcheck.h b/src/identcheck.h index c0ccca6..8b287dd 100644 --- a/src/identcheck.h +++ b/src/identcheck.h @@ -17,7 +17,9 @@ typedef enum NodeType OrderedScope, Struct, Function, - Variable + Variable, + GenericType, + Alloc } NodeType; typedef struct IdNode diff --git a/src/util.c b/src/util.c index aa5b4fb..42911e7 100644 --- a/src/util.c +++ b/src/util.c @@ -20,7 +20,7 @@ uint64_t str_hash(char *str) uint64_t hash = 5381; size_t c; - while (c = *str++) + while ((c = *str++)) { hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ } -- 2.25.1 From d48995716eb6c33a12928af96fe74158d37cc217 Mon Sep 17 00:00:00 2001 From: venko Date: Sun, 23 May 2021 16:58:59 -0700 Subject: [PATCH 05/17] Adds handling for generic AST nodes in PrintNode and SyntaxKindString --- generic.w | 13 +++++++++++++ src/ast.c | 19 +++++++++++++++++++ 2 files changed, 32 insertions(+) create mode 100644 generic.w diff --git a/generic.w b/generic.w new file mode 100644 index 0000000..eb25bd1 --- /dev/null +++ b/generic.w @@ -0,0 +1,13 @@ +struct Foo { + static Func(t: T): T { + return t; + } +} + +struct Program { + static main(): int { + x: int = 4; + y: int = Foo.Func(x); + return x; + } +} \ No newline at end of file diff --git a/src/ast.c b/src/ast.c index dde4693..a6809a6 100644 --- a/src/ast.c +++ b/src/ast.c @@ -39,6 +39,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "FunctionSignature"; case FunctionSignatureArguments: return "FunctionSignatureArguments"; + case GenericArgument: + return "GenericArgument"; + case GenericArguments: + return "GenericArguments"; case Identifier: return "Identifier"; case IfStatement: @@ -599,6 +603,7 @@ void PrintNode(Node *node, uint32_t tabCount) case FunctionSignature: printf("\n"); PrintNode(node->functionSignature.identifier, tabCount + 1); + PrintNode(node->functionSignature.genericArguments, tabCount + 1); PrintNode(node->functionSignature.arguments, tabCount + 1); PrintNode(node->functionSignature.type, tabCount + 1); PrintNode(node->functionSignature.modifiers, tabCount + 1); @@ -614,6 +619,20 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case GenericArgument: + printf("\n"); + PrintNode(node->genericArgument.identifier, tabCount + 1); + /* Constraint nodes are not implemented. */ + /* PrintNode(node->genericArgument.constraint, tabCount + 1); */ + return; + + case GenericArguments: + printf("\n"); + for (i = 0; i < node->genericArguments.count; i += 1) { + PrintNode(node->genericArguments.arguments[i], tabCount + 1); + } + return; + case Identifier: if (node->typeTag == NULL) { -- 2.25.1 From 79d47157992cb8eb10f8956339cce9b82d9a3367 Mon Sep 17 00:00:00 2001 From: venko Date: Sun, 23 May 2021 17:04:50 -0700 Subject: [PATCH 06/17] Moves generic type identifiers to be the first children of a function in the id-tree --- src/identcheck.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/identcheck.c b/src/identcheck.c index 2d040d0..d2bd6c6 100644 --- a/src/identcheck.c +++ b/src/identcheck.c @@ -150,9 +150,9 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) mainNode = MakeIdNode(Function, funcName, parent); mainNode->typeTag = MakeTypeTag(astNode); idNode->typeTag = mainNode->typeTag; + MakeIdTree(sigNode->functionSignature.genericArguments, mainNode); MakeIdTree(sigNode->functionSignature.arguments, mainNode); MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); - MakeIdTree(sigNode->functionSignature.genericArguments, mainNode); break; } -- 2.25.1 From eb24206e1320182feced0011e535ee51b65df080 Mon Sep 17 00:00:00 2001 From: venko Date: Mon, 24 May 2021 19:20:23 -0700 Subject: [PATCH 07/17] Imlements custom to generic type conversion --- CMakeLists.txt | 2 + generic.w | 7 +- src/ast.c | 30 ++++++- src/ast.h | 12 ++- src/main.c | 11 +++ src/typeutils.c | 202 ++++++++++++++++++++++++++++++++++++++++++++++++ src/typeutils.h | 13 ++++ 7 files changed, 274 insertions(+), 3 deletions(-) create mode 100644 src/typeutils.c create mode 100644 src/typeutils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 9898440..4cb94d8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -43,11 +43,13 @@ add_executable( src/codegen.h src/identcheck.h src/parser.h + src/typeutils.h src/util.h src/ast.c src/codegen.c src/identcheck.c src/parser.c + src/typeutils.c src/util.c src/main.c # Generated code diff --git a/generic.w b/generic.w index eb25bd1..63247ee 100644 --- a/generic.w +++ b/generic.w @@ -1,6 +1,11 @@ struct Foo { + static Func2(u: U) : U { + return u; + } + static Func(t: T): T { - return t; + foo: T = t; + return Func2(foo); } } diff --git a/src/ast.c b/src/ast.c index a6809a6..e705b5b 100644 --- a/src/ast.c +++ b/src/ast.c @@ -43,6 +43,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "GenericArgument"; case GenericArguments: return "GenericArguments"; + case GenericTypeNode: + return "GenericTypeNode"; case Identifier: return "Identifier"; case IfStatement: @@ -405,6 +407,14 @@ Node *MakeEmptyGenericArgumentsNode() return node; } +Node *MakeGenericTypeNode(char *name) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericTypeNode; + node->genericType.name = strdup(name); + return node; +} + Node *MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode) @@ -633,6 +643,10 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case GenericTypeNode: + printf("%s\n", node->genericType.name); + return; + case Identifier: if (node->typeTag == NULL) { @@ -763,6 +777,10 @@ TypeTag *MakeTypeTag(Node *node) tag = MakeTypeTag(node->allocExpression.type); break; + case GenericTypeNode: + tag->type = Generic; + tag->value.genericType = strdup(node->genericType.name); + default: fprintf( stderr, @@ -799,6 +817,16 @@ char *TypeTagToString(TypeTag *tag) return result; } case Custom: - return tag->value.customType; + { + char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); + sprintf(result, "Custom<%s>", tag->value.customType); + return result; + } + case Generic: + { + char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); + sprintf(result, "Generic<%s>", tag->value.customType); + return result; + } } } diff --git a/src/ast.h b/src/ast.h index a4b367c..eb3cc10 100644 --- a/src/ast.h +++ b/src/ast.h @@ -32,6 +32,7 @@ typedef enum FunctionSignatureArguments, GenericArgument, GenericArguments, + GenericTypeNode, Identifier, IfStatement, IfElseStatement, @@ -89,7 +90,8 @@ typedef struct TypeTag Unknown, Primitive, Reference, - Custom + Custom, + Generic } type; union { @@ -99,6 +101,8 @@ typedef struct TypeTag struct TypeTag *referenceType; /* Valid when type = Custom. */ char *customType; + /* Valid when type = Generic. */ + char *genericType; } value; } TypeTag; @@ -215,6 +219,11 @@ struct Node uint32_t count; } genericArguments; + struct + { + char *name; + } genericType; + struct { char *name; @@ -330,6 +339,7 @@ Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode); Node *MakeEmptyGenericArgumentsNode(); Node *StartGenericArgumentsNode(Node *genericArgumentNode); Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode); +Node *MakeGenericTypeNode(char *name); Node *MakeStructDeclarationNode( Node *identifierNode, Node *declarationSequenceNode); diff --git a/src/main.c b/src/main.c index eda6895..cbb7e80 100644 --- a/src/main.c +++ b/src/main.c @@ -4,6 +4,7 @@ #include "codegen.h" #include "identcheck.h" #include "parser.h" +#include "typeutils.h" int main(int argc, char *argv[]) { @@ -87,9 +88,19 @@ int main(int argc, char *argv[]) { { IdNode *idTree = MakeIdTree(rootNode, NULL); + printf("\n"); PrintIdTree(idTree, /*tabCount=*/0); + + printf("\nConverting custom types in the ID-tree.\n"); + ConvertIdCustomsToGenerics(idTree); + printf("\n"); + PrintIdTree(idTree, /*tabCount=*/0); + + printf("\nConverting custom type nodes in the AST.\n"); + ConvertASTCustomsToGenerics(rootNode); printf("\n"); PrintNode(rootNode, /*tabCount=*/0); + } exitCode = Codegen(rootNode, optimizationLevel); } diff --git a/src/typeutils.c b/src/typeutils.c new file mode 100644 index 0000000..1f20be8 --- /dev/null +++ b/src/typeutils.c @@ -0,0 +1,202 @@ +#include "typeutils.h" + +#include +#include +#include + +void ConvertIdCustomsToGenerics(IdNode *node) { + uint32_t i; + switch(node->type) + { + case UnorderedScope: + case OrderedScope: + case Struct: + /* FIXME: This case will need to be modified to handle type parameters over structs. */ + for (i = 0; i < node->childCount; i += 1) { + ConvertIdCustomsToGenerics(node->children[i]); + } + return; + + case Variable: { + TypeTag *varType = node->typeTag; + if (varType->type == Custom) { + IdNode *x = LookupId(node->parent, node, varType->value.customType); + if (x != NULL && x->type == GenericType) { + varType->type = Generic; + } + } + return; + } + + case Function: { + TypeTag *funcType = node->typeTag; + if (funcType->type == Custom) { + /* For functions we have to handle the type lookup manually since the generic type + * identifiers are declared as children of the function's IdNode. */ + for (i = 0; i < node->childCount; i += 1) { + IdNode *child = node->children[i]; + if (child->type == GenericType && strcmp(child->name, funcType->value.customType) == 0) { + funcType->type = Generic; + } + } + } + + for (i = 0; i < node->childCount; i += 1) { + ConvertIdCustomsToGenerics(node->children[i]); + } + return; + } + } +} + +void ConvertASTCustomsToGenerics(Node *node) { + uint32_t i; + switch (node->syntaxKind) { + case AccessExpression: + ConvertASTCustomsToGenerics(node->accessExpression.accessee); + ConvertASTCustomsToGenerics(node->accessExpression.accessor); + return; + + case AllocExpression: + ConvertASTCustomsToGenerics(node->allocExpression.type); + return; + + case Assignment: + ConvertASTCustomsToGenerics(node->assignmentStatement.left); + ConvertASTCustomsToGenerics(node->assignmentStatement.right); + return; + + case BinaryExpression: + ConvertASTCustomsToGenerics(node->binaryExpression.left); + ConvertASTCustomsToGenerics(node->binaryExpression.right); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: { + Node *type = node->declaration.type->type.typeNode; + Node *id = node->declaration.identifier; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->declaration.type); + node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); + } + return; + } + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) { + ConvertASTCustomsToGenerics(node->declarationSequence.sequence[i]); + } + return; + + case ForLoop: + ConvertASTCustomsToGenerics(node->forLoop.declaration); + ConvertASTCustomsToGenerics(node->forLoop.startNumber); + ConvertASTCustomsToGenerics(node->forLoop.endNumber); + ConvertASTCustomsToGenerics(node->forLoop.statementSequence); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) { + ConvertASTCustomsToGenerics(node->functionArgumentSequence.sequence[i]); + } + return; + + case FunctionCallExpression: + ConvertASTCustomsToGenerics(node->functionCallExpression.identifier); + ConvertASTCustomsToGenerics(node->functionCallExpression.argumentSequence); + return; + + case FunctionDeclaration: + ConvertASTCustomsToGenerics(node->functionDeclaration.functionSignature); + ConvertASTCustomsToGenerics(node->functionDeclaration.functionBody); + return; + + case FunctionModifiers: + return; + + case FunctionSignature:{ + Node *id = node->functionSignature.identifier; + Node *type = node->functionSignature.type; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->functionSignature.type); + node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); + } + ConvertASTCustomsToGenerics(node->functionSignature.arguments); + return; + } + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) { + ConvertASTCustomsToGenerics(node->functionSignatureArguments.sequence[i]); + } + return; + + case GenericArgument: + return; + + case GenericArguments: + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + ConvertASTCustomsToGenerics(node->ifStatement.expression); + ConvertASTCustomsToGenerics(node->ifStatement.statementSequence); + return; + + case IfElseStatement: + ConvertASTCustomsToGenerics(node->ifElseStatement.ifStatement); + ConvertASTCustomsToGenerics(node->ifElseStatement.elseStatement); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + return; + + case Return: + ConvertASTCustomsToGenerics(node->returnStatement.expression); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) { + ConvertASTCustomsToGenerics(node->statementSequence.sequence[i]); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + /* FIXME: This case will need to be modified to handle type parameters over structs. */ + ConvertASTCustomsToGenerics(node->structDeclaration.identifier); + ConvertASTCustomsToGenerics(node->structDeclaration.declarationSequence); + return; + + case Type: + return; + + case UnaryExpression: + ConvertASTCustomsToGenerics(node->unaryExpression.child); + return; + } +} diff --git a/src/typeutils.h b/src/typeutils.h new file mode 100644 index 0000000..2e752e0 --- /dev/null +++ b/src/typeutils.h @@ -0,0 +1,13 @@ +/* Helper functions for working with types in the AST and ID-tree. */ + +#ifndef WRAITH_TYPEUTILS_H +#define WRAITH_TYPEUTILS_H + +#include "ast.h" +#include "identcheck.h" + +/* FIXME: These two functions will need to be modified to handle type parameters over structs. */ +void ConvertIdCustomsToGenerics(IdNode *node); +void ConvertASTCustomsToGenerics(Node *node); + +#endif /* WRAITH_TYPEUTILS_H */ -- 2.25.1 From ddd5b2f027c063abf6af85489a4c336e060a597f Mon Sep 17 00:00:00 2001 From: venko Date: Wed, 26 May 2021 14:43:51 -0700 Subject: [PATCH 08/17] Implements generic recursion over AST nodes --- src/ast.c | 159 ++++++++++++++++++++++++++++++++++++++++++++++ src/ast.h | 6 ++ src/typeutils.c | 166 +++++++----------------------------------------- 3 files changed, 188 insertions(+), 143 deletions(-) diff --git a/src/ast.c b/src/ast.c index e705b5b..9604dbf 100644 --- a/src/ast.c +++ b/src/ast.c @@ -726,6 +726,165 @@ void PrintNode(Node *node, uint32_t tabCount) } } +void Recurse(Node *node, void (*func)(Node*)) +{ + uint32_t i; + switch (node->syntaxKind) + { + case AccessExpression: + func(node->accessExpression.accessee); + func(node->accessExpression.accessor); + return; + + case AllocExpression: + func(node->allocExpression.type); + return; + + case Assignment: + func(node->assignmentStatement.left); + func(node->assignmentStatement.right); + return; + + case BinaryExpression: + func(node->binaryExpression.left); + func(node->binaryExpression.right); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: + func(node->declaration.type); + func(node->declaration.identifier); + return; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) { + func(node->declarationSequence.sequence[i]); + } + return; + + case ForLoop: + func(node->forLoop.declaration); + func(node->forLoop.startNumber); + func(node->forLoop.endNumber); + func(node->forLoop.statementSequence); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) { + func(node->functionArgumentSequence.sequence[i]); + } + return; + + case FunctionCallExpression: + func(node->functionCallExpression.identifier); + func(node->functionCallExpression.argumentSequence); + return; + + case FunctionDeclaration: + func(node->functionDeclaration.functionSignature); + func(node->functionDeclaration.functionBody); + return; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) { + func(node->functionModifiers.sequence[i]); + } + return; + + case FunctionSignature: + func(node->functionSignature.identifier); + func(node->functionSignature.type); + func(node->functionSignature.arguments); + func(node->functionSignature.modifiers); + func(node->functionSignature.genericArguments); + return; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) { + func(node->functionSignatureArguments.sequence[i]); + } + return; + + case GenericArgument: + func(node->genericArgument.identifier); + func(node->genericArgument.constraint); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) { + func(node->genericArguments.arguments[i]); + } + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + func(node->ifStatement.expression); + func(node->ifStatement.statementSequence); + return; + + case IfElseStatement: + func(node->ifElseStatement.ifStatement); + func(node->ifElseStatement.elseStatement); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + func(node->referenceType.type); + return; + + case Return: + func(node->returnStatement.expression); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) { + func(node->statementSequence.sequence[i]); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + func(node->structDeclaration.identifier); + func(node->structDeclaration.declarationSequence); + return; + + case Type: + return; + + case UnaryExpression: + func(node->unaryExpression.child); + return; + + default: + fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); + return; + } +} + TypeTag *MakeTypeTag(Node *node) { if (node == NULL) diff --git a/src/ast.h b/src/ast.h index eb3cc10..8ad54a0 100644 --- a/src/ast.h +++ b/src/ast.h @@ -367,6 +367,12 @@ Node *MakeForLoopNode( void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); +/* Helper function for applying a void function generically over the children of an AST node. + * Used for functions that need to traverse the entire tree but only perform operations on a subset + * of node types. Such functions can match the syntaxKinds relevant to their purpose and invoke this + * function in all other cases. */ +void Recurse(Node *node, void (*func)(Node*)); + TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); diff --git a/src/typeutils.c b/src/typeutils.c index 1f20be8..f86762c 100644 --- a/src/typeutils.c +++ b/src/typeutils.c @@ -51,152 +51,32 @@ void ConvertIdCustomsToGenerics(IdNode *node) { void ConvertASTCustomsToGenerics(Node *node) { uint32_t i; - switch (node->syntaxKind) { - case AccessExpression: - ConvertASTCustomsToGenerics(node->accessExpression.accessee); - ConvertASTCustomsToGenerics(node->accessExpression.accessor); - return; - - case AllocExpression: - ConvertASTCustomsToGenerics(node->allocExpression.type); - return; - - case Assignment: - ConvertASTCustomsToGenerics(node->assignmentStatement.left); - ConvertASTCustomsToGenerics(node->assignmentStatement.right); - return; - - case BinaryExpression: - ConvertASTCustomsToGenerics(node->binaryExpression.left); - ConvertASTCustomsToGenerics(node->binaryExpression.right); - return; - - case Comment: - return; - - case CustomTypeNode: - return; - - case Declaration: { - Node *type = node->declaration.type->type.typeNode; - Node *id = node->declaration.identifier; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->declaration.type); - node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - return; + switch (node->syntaxKind) + { + case Declaration: + { + Node *type = node->declaration.type->type.typeNode; + Node *id = node->declaration.identifier; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->declaration.type); + node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); } + return; + } - case DeclarationSequence: - for (i = 0; i < node->declarationSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->declarationSequence.sequence[i]); - } - return; - - case ForLoop: - ConvertASTCustomsToGenerics(node->forLoop.declaration); - ConvertASTCustomsToGenerics(node->forLoop.startNumber); - ConvertASTCustomsToGenerics(node->forLoop.endNumber); - ConvertASTCustomsToGenerics(node->forLoop.statementSequence); - return; - - case FunctionArgumentSequence: - for (i = 0; i < node->functionArgumentSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->functionArgumentSequence.sequence[i]); - } - return; - - case FunctionCallExpression: - ConvertASTCustomsToGenerics(node->functionCallExpression.identifier); - ConvertASTCustomsToGenerics(node->functionCallExpression.argumentSequence); - return; - - case FunctionDeclaration: - ConvertASTCustomsToGenerics(node->functionDeclaration.functionSignature); - ConvertASTCustomsToGenerics(node->functionDeclaration.functionBody); - return; - - case FunctionModifiers: - return; - - case FunctionSignature:{ - Node *id = node->functionSignature.identifier; - Node *type = node->functionSignature.type; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->functionSignature.type); - node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - ConvertASTCustomsToGenerics(node->functionSignature.arguments); - return; + case FunctionSignature: + { + Node *id = node->functionSignature.identifier; + Node *type = node->functionSignature.type; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->functionSignature.type); + node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); } + ConvertASTCustomsToGenerics(node->functionSignature.arguments); + return; + } - case FunctionSignatureArguments: - for (i = 0; i < node->functionSignatureArguments.count; i += 1) { - ConvertASTCustomsToGenerics(node->functionSignatureArguments.sequence[i]); - } - return; - - case GenericArgument: - return; - - case GenericArguments: - return; - - case GenericTypeNode: - return; - - case Identifier: - return; - - case IfStatement: - ConvertASTCustomsToGenerics(node->ifStatement.expression); - ConvertASTCustomsToGenerics(node->ifStatement.statementSequence); - return; - - case IfElseStatement: - ConvertASTCustomsToGenerics(node->ifElseStatement.ifStatement); - ConvertASTCustomsToGenerics(node->ifElseStatement.elseStatement); - return; - - case Number: - return; - - case PrimitiveTypeNode: - return; - - case ReferenceTypeNode: - return; - - case Return: - ConvertASTCustomsToGenerics(node->returnStatement.expression); - return; - - case ReturnVoid: - return; - - case StatementSequence: - for (i = 0; i < node->statementSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->statementSequence.sequence[i]); - } - return; - - case StaticModifier: - return; - - case StringLiteral: - return; - - case StructDeclaration: - /* FIXME: This case will need to be modified to handle type parameters over structs. */ - ConvertASTCustomsToGenerics(node->structDeclaration.identifier); - ConvertASTCustomsToGenerics(node->structDeclaration.declarationSequence); - return; - - case Type: - return; - - case UnaryExpression: - ConvertASTCustomsToGenerics(node->unaryExpression.child); - return; + default: + recurse(node, *ConvertASTCustomsToGenerics); } } -- 2.25.1 From 4f8f4fbe9e8e77606ef75e9e32eb1e81b3b546f5 Mon Sep 17 00:00:00 2001 From: venko Date: Wed, 26 May 2021 14:53:08 -0700 Subject: [PATCH 09/17] Adds back FIXME comment --- src/typeutils.c | 1 + 1 file changed, 1 insertion(+) diff --git a/src/typeutils.c b/src/typeutils.c index f86762c..a63420c 100644 --- a/src/typeutils.c +++ b/src/typeutils.c @@ -49,6 +49,7 @@ void ConvertIdCustomsToGenerics(IdNode *node) { } } +/* FIXME: This function will need to be modified to handle type parameters over structs. */ void ConvertASTCustomsToGenerics(Node *node) { uint32_t i; switch (node->syntaxKind) -- 2.25.1 From a69516b917f128197691a139f7ea39dfc25a12f1 Mon Sep 17 00:00:00 2001 From: venko Date: Thu, 27 May 2021 11:53:01 -0700 Subject: [PATCH 10/17] Fixes typo and removes unused variable. --- src/typeutils.c | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/typeutils.c b/src/typeutils.c index a63420c..4ebddea 100644 --- a/src/typeutils.c +++ b/src/typeutils.c @@ -51,7 +51,6 @@ void ConvertIdCustomsToGenerics(IdNode *node) { /* FIXME: This function will need to be modified to handle type parameters over structs. */ void ConvertASTCustomsToGenerics(Node *node) { - uint32_t i; switch (node->syntaxKind) { case Declaration: @@ -78,6 +77,6 @@ void ConvertASTCustomsToGenerics(Node *node) { } default: - recurse(node, *ConvertASTCustomsToGenerics); + Recurse(node, *ConvertASTCustomsToGenerics); } } -- 2.25.1 From ece20a99b55b3bd1dd20598a6640f79d60603805 Mon Sep 17 00:00:00 2001 From: venko Date: Thu, 27 May 2021 14:10:44 -0700 Subject: [PATCH 11/17] Fixes printing bug for string literals --- src/ast.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ast.c b/src/ast.c index 9604dbf..d87870b 100644 --- a/src/ast.c +++ b/src/ast.c @@ -705,7 +705,7 @@ void PrintNode(Node *node, uint32_t tabCount) return; case StringLiteral: - printf("%s", node->stringLiteral.string); + printf("%s\n", node->stringLiteral.string); return; case StructDeclaration: -- 2.25.1 From e3fc2826ea2a4e895d318f3ccd7370aa9685cd7f Mon Sep 17 00:00:00 2001 From: venko Date: Thu, 27 May 2021 16:17:25 -0700 Subject: [PATCH 12/17] Reimplements identifier lookup to work over the AST --- src/ast.c | 207 +++++++++++++++++++++++++++++++++++++++++++++++ src/ast.h | 3 + src/identcheck.c | 2 +- src/main.c | 1 + 4 files changed, 212 insertions(+), 1 deletion(-) diff --git a/src/ast.c b/src/ast.c index d87870b..f34fc61 100644 --- a/src/ast.c +++ b/src/ast.c @@ -1,5 +1,6 @@ #include "ast.h" +#include #include #include @@ -989,3 +990,209 @@ char *TypeTagToString(TypeTag *tag) } } } + +void LinkParentPointers(Node *node) +{ + static Node *parent = NULL; + + if (node == NULL) + { + fprintf(stderr, "wraith: Encountered NULL node while linking parent pointers.\n"); + return; + } + + node->parent = parent; + parent = node; + Recurse(node, *LinkParentPointers); +} + +Node *GetIdFromStruct(Node *structDecl) +{ + if (structDecl->syntaxKind != StructDeclaration) + { + fprintf(stderr, "wraith: Attempted to call GetIdFromStruct on node with kind: %s.\n", + SyntaxKindString(structDecl->syntaxKind)); + return NULL; + } + + return structDecl->structDeclaration.identifier; +} + +Node *GetIdFromFunction(Node *funcDecl) +{ + if (funcDecl->syntaxKind != FunctionDeclaration) + { + fprintf(stderr, "wraith: Attempted to call GetIdFromFunction on node with kind: %s.\n", + SyntaxKindString(funcDecl->syntaxKind)); + return NULL; + } + + Node *sig = funcDecl->functionDeclaration.functionSignature; + return sig->functionSignature.identifier; +} + +Node *GetIdFromDeclaration(Node *decl) +{ + if (decl->syntaxKind != Declaration) + { + fprintf(stderr, "wraith: Attempted to call GetIdFromDeclaration on node with kind: %s.\n", + SyntaxKindString(decl->syntaxKind)); + } + + return decl->declaration.identifier; +} + +bool AssignmentHasDeclaration(Node *assign) +{ + return (assign->syntaxKind == Assignment + && assign->assignmentStatement.left->syntaxKind == Declaration); +} + +Node *GetIdFromAssignment(Node *assign) +{ + if (assign->syntaxKind != Assignment) + { + fprintf(stderr, "wraith: Attempted to call GetIdFromAssignment on node with kind: %s.\n", + SyntaxKindString(assign->syntaxKind)); + } + + if (AssignmentHasDeclaration(assign)) + { + return GetIdFromDeclaration(assign->assignmentStatement.left); + } + + return NULL; +} + +bool NodeMayHaveId(Node *node) +{ + switch (node->syntaxKind) + { + case StructDeclaration: + case FunctionDeclaration: + case Declaration: + case Assignment: + return true; + default: + return false; + } +} + +Node *TryGetId(Node *node) +{ + switch (node->syntaxKind) + { + case StructDeclaration: + return GetIdFromStruct(node); + case FunctionDeclaration: + return GetIdFromFunction(node); + case Declaration: + return GetIdFromDeclaration(node); + default: + return NULL; + } +} + +Node *LookupFunctionArgId(Node *funcDecl, char *target) +{ + Node *args = + funcDecl->functionDeclaration.functionSignature->functionSignature.arguments; + + uint32_t i; + for (i = 0; i < args->functionArgumentSequence.count; i += 1) + { + Node *arg = args->functionArgumentSequence.sequence[i]; + if (arg->syntaxKind != Declaration) + { + fprintf(stderr, + "wraith: Encountered %s node in function signature args list.\n", + SyntaxKindString(arg->syntaxKind)); + continue; + } + + Node *argId = GetIdFromDeclaration(arg); + if (argId != NULL && strcmp(target, argId->identifier.name) == 0) + { + return argId; + } + } + + return NULL; +} + +Node *LookupIdNode(Node *current, Node *prev, char *target) +{ + if (current == NULL) return NULL; + + /* If this node may have an identifier declaration inside it, attempt to look up the identifier + * node itself, returning it if it matches the given target name. */ + if (NodeMayHaveId(current)) + { + Node *candidateId = TryGetId(current); + if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) + { + return candidateId; + } + + /* If the candidate node was not the one we wanted, but the current node is a function + * declaration, it's possible that the identifier we want is one of the function's + * parameters rather than the function's name itself. */ + if (current->syntaxKind == FunctionDeclaration) + { + Node *match = LookupFunctionArgId(current, target); + if (match != NULL) return match; + } + } + + /* If this is the start of our search, we should not attempt to look at + * child nodes. Only looking up the AST is valid at this point. + * + * This has the notable side-effect that this function will return NULL if + * you attempt to look up a struct's internals starting from the node + * representing the struct itself. */ + if (prev == NULL) + { + return LookupIdNode(current->parent, current, target); + } + + uint32_t i; + uint32_t idxLimit; + switch (current->syntaxKind) + { + case DeclarationSequence: + for (i = 0; i < current->declarationSequence.count; i += 1) + { + Node *decl = current->declarationSequence.sequence[i]; + Node *declId = TryGetId(decl); + if (declId != NULL) return declId; + } + break; + case StatementSequence: + idxLimit = current->statementSequence.count; + for (i = 0; i < current->statementSequence.count; i += 1) + { + if (current->statementSequence.sequence[i] == prev) + { + idxLimit = i; + break; + } + } + + for (i = 0; i < idxLimit; i += 1) + { + Node *stmt = current->statementSequence.sequence[i]; + if (stmt == prev) continue; + + if (NodeMayHaveId(stmt)) + { + Node *candidateId = TryGetId(current); + if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) + { + return candidateId; + } + } + } + break; + } + return LookupIdNode(current->parent, current, target); +} diff --git a/src/ast.h b/src/ast.h index 8ad54a0..967cbe5 100644 --- a/src/ast.h +++ b/src/ast.h @@ -373,7 +373,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind); * function in all other cases. */ void Recurse(Node *node, void (*func)(Node*)); +void LinkParentPointers(Node *node); + TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); +Node *LookupIdNode(Node *current, Node *prev, char *target); #endif /* WRAITH_AST_H */ diff --git a/src/identcheck.c b/src/identcheck.c index d2bd6c6..39498b0 100644 --- a/src/identcheck.c +++ b/src/identcheck.c @@ -187,7 +187,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) { char *name = astNode->identifier.name; mainNode = MakeIdNode(Placeholder, name, parent); - IdNode *lookupNode = LookupId(mainNode, NULL, name); + Node *lookupNode = LookupIdNode(astNode, NULL, name); if (lookupNode == NULL) { fprintf(stderr, "wraith: Could not find IdNode for id %s\n", name); diff --git a/src/main.c b/src/main.c index cbb7e80..044210c 100644 --- a/src/main.c +++ b/src/main.c @@ -86,6 +86,7 @@ int main(int argc, char *argv[]) } else { + LinkParentPointers(rootNode); { IdNode *idTree = MakeIdTree(rootNode, NULL); printf("\n"); -- 2.25.1 From 7d5f5997120c1b91d6e7b5306c8888bd475c041f Mon Sep 17 00:00:00 2001 From: venko Date: Sat, 29 May 2021 18:27:13 -0700 Subject: [PATCH 13/17] Lots of bug fixes for id lookup --- src/ast.c | 315 +++++++++++++++++++++++++++++++++++++++++++++++------ src/ast.h | 4 +- src/main.c | 21 +--- 3 files changed, 291 insertions(+), 49 deletions(-) diff --git a/src/ast.c b/src/ast.c index f34fc61..51f416c 100644 --- a/src/ast.c +++ b/src/ast.c @@ -937,9 +937,16 @@ TypeTag *MakeTypeTag(Node *node) tag = MakeTypeTag(node->allocExpression.type); break; + case GenericArgument: + tag->type = Generic; + tag->value.genericType = strdup(node->genericArgument.identifier + ->identifier.name); + break; + case GenericTypeNode: tag->type = Generic; tag->value.genericType = strdup(node->genericType.name); + break; default: fprintf( @@ -949,6 +956,7 @@ TypeTag *MakeTypeTag(Node *node) SyntaxKindString(node->syntaxKind)); return NULL; } + return tag; } @@ -991,19 +999,167 @@ char *TypeTagToString(TypeTag *tag) } } -void LinkParentPointers(Node *node) +void LinkParentPointers(Node *node, Node *prev) { - static Node *parent = NULL; + if (node == NULL) return; - if (node == NULL) + node->parent = prev; + + uint32_t i; + switch (node->syntaxKind) { - fprintf(stderr, "wraith: Encountered NULL node while linking parent pointers.\n"); + case AccessExpression: + LinkParentPointers(node->accessExpression.accessee, node); + LinkParentPointers(node->accessExpression.accessor, node); + return; + + case AllocExpression: + LinkParentPointers(node->allocExpression.type, node); + return; + + case Assignment: + LinkParentPointers(node->assignmentStatement.left, node); + LinkParentPointers(node->assignmentStatement.right, node); + return; + + case BinaryExpression: + LinkParentPointers(node->binaryExpression.left, node); + LinkParentPointers(node->binaryExpression.right, node); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: + LinkParentPointers(node->declaration.type, node); + LinkParentPointers(node->declaration.identifier, node); + return; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) { + LinkParentPointers(node->declarationSequence.sequence[i], node); + } + return; + + case ForLoop: + LinkParentPointers(node->forLoop.declaration, node); + LinkParentPointers(node->forLoop.startNumber, node); + LinkParentPointers(node->forLoop.endNumber, node); + LinkParentPointers(node->forLoop.statementSequence, node); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) { + LinkParentPointers(node->functionArgumentSequence.sequence[i], node); + } + return; + + case FunctionCallExpression: + LinkParentPointers(node->functionCallExpression.identifier, node); + LinkParentPointers(node->functionCallExpression.argumentSequence, node); + return; + + case FunctionDeclaration: + LinkParentPointers(node->functionDeclaration.functionSignature, node); + LinkParentPointers(node->functionDeclaration.functionBody, node); + return; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) { + LinkParentPointers(node->functionModifiers.sequence[i], node); + } + return; + + case FunctionSignature: + LinkParentPointers(node->functionSignature.identifier, node); + LinkParentPointers(node->functionSignature.type, node); + LinkParentPointers(node->functionSignature.arguments, node); + LinkParentPointers(node->functionSignature.modifiers, node); + LinkParentPointers(node->functionSignature.genericArguments, node); + return; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) { + LinkParentPointers(node->functionSignatureArguments.sequence[i], node); + } + return; + + case GenericArgument: + LinkParentPointers(node->genericArgument.identifier, node); + LinkParentPointers(node->genericArgument.constraint, node); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) { + LinkParentPointers(node->genericArguments.arguments[i], node); + } + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + LinkParentPointers(node->ifStatement.expression, node); + LinkParentPointers(node->ifStatement.statementSequence, node); + return; + + case IfElseStatement: + LinkParentPointers(node->ifElseStatement.ifStatement, node); + LinkParentPointers(node->ifElseStatement.elseStatement, node); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + LinkParentPointers(node->referenceType.type, node); + return; + + case Return: + LinkParentPointers(node->returnStatement.expression, node); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) { + LinkParentPointers(node->statementSequence.sequence[i], node); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + LinkParentPointers(node->structDeclaration.identifier, node); + LinkParentPointers(node->structDeclaration.declarationSequence, node); + return; + + case Type: + return; + + case UnaryExpression: + LinkParentPointers(node->unaryExpression.child, node); + return; + + default: + fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); return; } - - node->parent = parent; - parent = node; - Recurse(node, *LinkParentPointers); } Node *GetIdFromStruct(Node *structDecl) @@ -1082,12 +1238,14 @@ Node *TryGetId(Node *node) { switch (node->syntaxKind) { - case StructDeclaration: - return GetIdFromStruct(node); - case FunctionDeclaration: - return GetIdFromFunction(node); + case Assignment: + return GetIdFromAssignment(node); case Declaration: return GetIdFromDeclaration(node); + case FunctionDeclaration: + return GetIdFromFunction(node); + case StructDeclaration: + return GetIdFromStruct(node); default: return NULL; } @@ -1120,36 +1278,69 @@ Node *LookupFunctionArgId(Node *funcDecl, char *target) return NULL; } -Node *LookupIdNode(Node *current, Node *prev, char *target) +Node *LookupStructInternalId(Node *structDecl, char *target) { - if (current == NULL) return NULL; + Node *decls = structDecl->structDeclaration.declarationSequence; + uint32_t i; + for (i = 0; i < decls->declarationSequence.count; i += 1) + { + Node *match = TryGetId(decls->declarationSequence.sequence[i]); + if (match != NULL && strcmp(target, match->identifier.name) == 0) + return match; + } + + return NULL; +} + +Node *InspectNode(Node *node, char *target) +{ /* If this node may have an identifier declaration inside it, attempt to look up the identifier * node itself, returning it if it matches the given target name. */ - if (NodeMayHaveId(current)) + if (NodeMayHaveId(node)) { - Node *candidateId = TryGetId(current); + Node *candidateId = TryGetId(node); if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) { return candidateId; } - - /* If the candidate node was not the one we wanted, but the current node is a function - * declaration, it's possible that the identifier we want is one of the function's - * parameters rather than the function's name itself. */ - if (current->syntaxKind == FunctionDeclaration) - { - Node *match = LookupFunctionArgId(current, target); - if (match != NULL) return match; - } } + /* If the candidate node was not the one we wanted, but the node node is a function + * declaration, it's possible that the identifier we want is one of the function's + * parameters rather than the function's name itself. */ + if (node->syntaxKind == FunctionDeclaration) + { + Node *match = LookupFunctionArgId(node, target); + if (match != NULL) return match; + } + + /* Likewise if the node node is a struct declaration, inspect the struct's internals + * to see if a top-level definition is the one we're looking for. */ + if (node->syntaxKind == StructDeclaration) + { + Node *match = LookupStructInternalId(node, target); + if (match != NULL) return match; + } + + return NULL; +} + +Node *LookupIdNode(Node *current, Node *prev, char *target) +{ + if (current == NULL) return NULL; + Node *match; + + /* First inspect the current node to see if it contains the target identifier. */ + match = InspectNode(current, target); + if (match != NULL) return match; + /* If this is the start of our search, we should not attempt to look at * child nodes. Only looking up the AST is valid at this point. * * This has the notable side-effect that this function will return NULL if * you attempt to look up a struct's internals starting from the node - * representing the struct itself. */ + * representing the struct itself. The same is true for functions. */ if (prev == NULL) { return LookupIdNode(current->parent, current, target); @@ -1163,8 +1354,11 @@ Node *LookupIdNode(Node *current, Node *prev, char *target) for (i = 0; i < current->declarationSequence.count; i += 1) { Node *decl = current->declarationSequence.sequence[i]; - Node *declId = TryGetId(decl); - if (declId != NULL) return declId; + match = InspectNode(decl, target); + if (match != NULL) return match; + /*Node *declId = TryGetId(decl); + if (declId != NULL && strcmp(target, declId->identifier.name) == 0) + return declId;*/ } break; case StatementSequence: @@ -1183,16 +1377,73 @@ Node *LookupIdNode(Node *current, Node *prev, char *target) Node *stmt = current->statementSequence.sequence[i]; if (stmt == prev) continue; - if (NodeMayHaveId(stmt)) + if (strcmp(target, "g") == 0) { + printf("info: %s\n", SyntaxKindString(stmt->syntaxKind)); + } + + match = InspectNode(stmt, target); + if (match != NULL) return match; + /*if (NodeMayHaveId(stmt)) { Node *candidateId = TryGetId(current); if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) - { return candidateId; - } - } + }*/ } break; } + return LookupIdNode(current->parent, current, target); } + +void IdentifierPass(Node *node) +{ + if (node == NULL) return; + + switch (node->syntaxKind) + { + case AllocExpression: + node->typeTag = MakeTypeTag(node); + break; + + case Declaration: + node->declaration.identifier->typeTag = MakeTypeTag(node); + break; + + case FunctionDeclaration: + node->functionDeclaration.functionSignature + ->functionSignature.identifier->typeTag = MakeTypeTag(node); + break;; + + case StructDeclaration: + node->structDeclaration.identifier->typeTag = MakeTypeTag(node); + break; + + case GenericArgument: + node->genericArgument.identifier->typeTag = MakeTypeTag(node); + break; + + case Identifier: + { + if (node->typeTag != NULL) return; + + char *name = node->identifier.name; + Node *declaration = LookupIdNode(node, NULL, name); + if (declaration == NULL) + { + /* FIXME: Express this case as an error with AST information. */ + fprintf(stderr, "wraith: Could not find definition of identifier %s.\n", name); + TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); + tag->type = Unknown; + node->typeTag = tag; + } + else + { + node->typeTag = declaration->typeTag; + } + break; + } + } + + Recurse(node, *IdentifierPass); +} \ No newline at end of file diff --git a/src/ast.h b/src/ast.h index 967cbe5..251f0f6 100644 --- a/src/ast.h +++ b/src/ast.h @@ -373,10 +373,12 @@ const char *SyntaxKindString(SyntaxKind syntaxKind); * function in all other cases. */ void Recurse(Node *node, void (*func)(Node*)); -void LinkParentPointers(Node *node); +void LinkParentPointers(Node *node, Node *prev); TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); Node *LookupIdNode(Node *current, Node *prev, char *target); + +void IdentifierPass(Node *node); #endif /* WRAITH_AST_H */ diff --git a/src/main.c b/src/main.c index 044210c..1b7c70b 100644 --- a/src/main.c +++ b/src/main.c @@ -86,23 +86,12 @@ int main(int argc, char *argv[]) } else { - LinkParentPointers(rootNode); - { - IdNode *idTree = MakeIdTree(rootNode, NULL); - printf("\n"); - PrintIdTree(idTree, /*tabCount=*/0); + LinkParentPointers(rootNode, NULL); + IdentifierPass(rootNode); + /*ConvertASTCustomsToGenerics(rootNode);*/ + PrintNode(rootNode, 0); - printf("\nConverting custom types in the ID-tree.\n"); - ConvertIdCustomsToGenerics(idTree); - printf("\n"); - PrintIdTree(idTree, /*tabCount=*/0); - - printf("\nConverting custom type nodes in the AST.\n"); - ConvertASTCustomsToGenerics(rootNode); - printf("\n"); - PrintNode(rootNode, /*tabCount=*/0); - - } + printf("Beginning codegen.\n"); exitCode = Codegen(rootNode, optimizationLevel); } } -- 2.25.1 From 7f2ca56b731b3397dd3a139cda7dd7f6cc71cd95 Mon Sep 17 00:00:00 2001 From: venko Date: Sun, 30 May 2021 13:07:12 -0700 Subject: [PATCH 14/17] Applies clang-format --- src/ast.c | 187 +++++++++++++++++++++++++++++++++++------------------- src/ast.h | 11 ++-- 2 files changed, 128 insertions(+), 70 deletions(-) diff --git a/src/ast.c b/src/ast.c index 51f416c..4e244c3 100644 --- a/src/ast.c +++ b/src/ast.c @@ -639,7 +639,8 @@ void PrintNode(Node *node, uint32_t tabCount) case GenericArguments: printf("\n"); - for (i = 0; i < node->genericArguments.count; i += 1) { + for (i = 0; i < node->genericArguments.count; i += 1) + { PrintNode(node->genericArguments.arguments[i], tabCount + 1); } return; @@ -727,7 +728,7 @@ void PrintNode(Node *node, uint32_t tabCount) } } -void Recurse(Node *node, void (*func)(Node*)) +void Recurse(Node *node, void (*func)(Node *)) { uint32_t i; switch (node->syntaxKind) @@ -763,7 +764,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case DeclarationSequence: - for (i = 0; i < node->declarationSequence.count; i += 1) { + for (i = 0; i < node->declarationSequence.count; i += 1) + { func(node->declarationSequence.sequence[i]); } return; @@ -776,7 +778,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case FunctionArgumentSequence: - for (i = 0; i < node->functionArgumentSequence.count; i += 1) { + for (i = 0; i < node->functionArgumentSequence.count; i += 1) + { func(node->functionArgumentSequence.sequence[i]); } return; @@ -792,7 +795,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case FunctionModifiers: - for (i = 0; i < node->functionModifiers.count; i += 1) { + for (i = 0; i < node->functionModifiers.count; i += 1) + { func(node->functionModifiers.sequence[i]); } return; @@ -806,7 +810,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case FunctionSignatureArguments: - for (i = 0; i < node->functionSignatureArguments.count; i += 1) { + for (i = 0; i < node->functionSignatureArguments.count; i += 1) + { func(node->functionSignatureArguments.sequence[i]); } return; @@ -817,7 +822,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case GenericArguments: - for (i = 0; i < node->genericArguments.count; i += 1) { + for (i = 0; i < node->genericArguments.count; i += 1) + { func(node->genericArguments.arguments[i]); } return; @@ -856,7 +862,8 @@ void Recurse(Node *node, void (*func)(Node*)) return; case StatementSequence: - for (i = 0; i < node->statementSequence.count; i += 1) { + for (i = 0; i < node->statementSequence.count; i += 1) + { func(node->statementSequence.sequence[i]); } return; @@ -880,8 +887,10 @@ void Recurse(Node *node, void (*func)(Node*)) return; default: - fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n", - SyntaxKindString(node->syntaxKind)); + fprintf( + stderr, + "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); return; } } @@ -939,8 +948,8 @@ TypeTag *MakeTypeTag(Node *node) case GenericArgument: tag->type = Generic; - tag->value.genericType = strdup(node->genericArgument.identifier - ->identifier.name); + tag->value.genericType = + strdup(node->genericArgument.identifier->identifier.name); break; case GenericTypeNode: @@ -986,13 +995,15 @@ char *TypeTagToString(TypeTag *tag) } case Custom: { - char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); + char *result = + malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); sprintf(result, "Custom<%s>", tag->value.customType); return result; } case Generic: { - char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); + char *result = + malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); sprintf(result, "Generic<%s>", tag->value.customType); return result; } @@ -1001,7 +1012,8 @@ char *TypeTagToString(TypeTag *tag) void LinkParentPointers(Node *node, Node *prev) { - if (node == NULL) return; + if (node == NULL) + return; node->parent = prev; @@ -1039,7 +1051,8 @@ void LinkParentPointers(Node *node, Node *prev) return; case DeclarationSequence: - for (i = 0; i < node->declarationSequence.count; i += 1) { + for (i = 0; i < node->declarationSequence.count; i += 1) + { LinkParentPointers(node->declarationSequence.sequence[i], node); } return; @@ -1052,8 +1065,11 @@ void LinkParentPointers(Node *node, Node *prev) return; case FunctionArgumentSequence: - for (i = 0; i < node->functionArgumentSequence.count; i += 1) { - LinkParentPointers(node->functionArgumentSequence.sequence[i], node); + for (i = 0; i < node->functionArgumentSequence.count; i += 1) + { + LinkParentPointers( + node->functionArgumentSequence.sequence[i], + node); } return; @@ -1068,7 +1084,8 @@ void LinkParentPointers(Node *node, Node *prev) return; case FunctionModifiers: - for (i = 0; i < node->functionModifiers.count; i += 1) { + for (i = 0; i < node->functionModifiers.count; i += 1) + { LinkParentPointers(node->functionModifiers.sequence[i], node); } return; @@ -1082,8 +1099,11 @@ void LinkParentPointers(Node *node, Node *prev) return; case FunctionSignatureArguments: - for (i = 0; i < node->functionSignatureArguments.count; i += 1) { - LinkParentPointers(node->functionSignatureArguments.sequence[i], node); + for (i = 0; i < node->functionSignatureArguments.count; i += 1) + { + LinkParentPointers( + node->functionSignatureArguments.sequence[i], + node); } return; @@ -1093,7 +1113,8 @@ void LinkParentPointers(Node *node, Node *prev) return; case GenericArguments: - for (i = 0; i < node->genericArguments.count; i += 1) { + for (i = 0; i < node->genericArguments.count; i += 1) + { LinkParentPointers(node->genericArguments.arguments[i], node); } return; @@ -1132,7 +1153,8 @@ void LinkParentPointers(Node *node, Node *prev) return; case StatementSequence: - for (i = 0; i < node->statementSequence.count; i += 1) { + for (i = 0; i < node->statementSequence.count; i += 1) + { LinkParentPointers(node->statementSequence.sequence[i], node); } return; @@ -1156,8 +1178,10 @@ void LinkParentPointers(Node *node, Node *prev) return; default: - fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n", - SyntaxKindString(node->syntaxKind)); + fprintf( + stderr, + "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); return; } } @@ -1166,20 +1190,26 @@ Node *GetIdFromStruct(Node *structDecl) { if (structDecl->syntaxKind != StructDeclaration) { - fprintf(stderr, "wraith: Attempted to call GetIdFromStruct on node with kind: %s.\n", - SyntaxKindString(structDecl->syntaxKind)); + fprintf( + stderr, + "wraith: Attempted to call GetIdFromStruct on node with kind: " + "%s.\n", + SyntaxKindString(structDecl->syntaxKind)); return NULL; } return structDecl->structDeclaration.identifier; } -Node *GetIdFromFunction(Node *funcDecl) +Node *GetIdFromFunction(Node *funcDecl) { if (funcDecl->syntaxKind != FunctionDeclaration) { - fprintf(stderr, "wraith: Attempted to call GetIdFromFunction on node with kind: %s.\n", - SyntaxKindString(funcDecl->syntaxKind)); + fprintf( + stderr, + "wraith: Attempted to call GetIdFromFunction on node with kind: " + "%s.\n", + SyntaxKindString(funcDecl->syntaxKind)); return NULL; } @@ -1191,8 +1221,11 @@ Node *GetIdFromDeclaration(Node *decl) { if (decl->syntaxKind != Declaration) { - fprintf(stderr, "wraith: Attempted to call GetIdFromDeclaration on node with kind: %s.\n", - SyntaxKindString(decl->syntaxKind)); + fprintf( + stderr, + "wraith: Attempted to call GetIdFromDeclaration on node with kind: " + "%s.\n", + SyntaxKindString(decl->syntaxKind)); } return decl->declaration.identifier; @@ -1200,16 +1233,20 @@ Node *GetIdFromDeclaration(Node *decl) bool AssignmentHasDeclaration(Node *assign) { - return (assign->syntaxKind == Assignment - && assign->assignmentStatement.left->syntaxKind == Declaration); + return ( + assign->syntaxKind == Assignment && + assign->assignmentStatement.left->syntaxKind == Declaration); } Node *GetIdFromAssignment(Node *assign) { if (assign->syntaxKind != Assignment) { - fprintf(stderr, "wraith: Attempted to call GetIdFromAssignment on node with kind: %s.\n", - SyntaxKindString(assign->syntaxKind)); + fprintf( + stderr, + "wraith: Attempted to call GetIdFromAssignment on node with kind: " + "%s.\n", + SyntaxKindString(assign->syntaxKind)); } if (AssignmentHasDeclaration(assign)) @@ -1253,8 +1290,8 @@ Node *TryGetId(Node *node) Node *LookupFunctionArgId(Node *funcDecl, char *target) { - Node *args = - funcDecl->functionDeclaration.functionSignature->functionSignature.arguments; + Node *args = funcDecl->functionDeclaration.functionSignature + ->functionSignature.arguments; uint32_t i; for (i = 0; i < args->functionArgumentSequence.count; i += 1) @@ -1262,15 +1299,17 @@ Node *LookupFunctionArgId(Node *funcDecl, char *target) Node *arg = args->functionArgumentSequence.sequence[i]; if (arg->syntaxKind != Declaration) { - fprintf(stderr, - "wraith: Encountered %s node in function signature args list.\n", - SyntaxKindString(arg->syntaxKind)); + fprintf( + stderr, + "wraith: Encountered %s node in function signature args " + "list.\n", + SyntaxKindString(arg->syntaxKind)); continue; } Node *argId = GetIdFromDeclaration(arg); if (argId != NULL && strcmp(target, argId->identifier.name) == 0) - { + { return argId; } } @@ -1295,32 +1334,37 @@ Node *LookupStructInternalId(Node *structDecl, char *target) Node *InspectNode(Node *node, char *target) { - /* If this node may have an identifier declaration inside it, attempt to look up the identifier + /* If this node may have an identifier declaration inside it, attempt to + * look up the identifier * node itself, returning it if it matches the given target name. */ if (NodeMayHaveId(node)) { Node *candidateId = TryGetId(node); - if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) + if (candidateId != NULL && + strcmp(target, candidateId->identifier.name) == 0) { return candidateId; } } - /* If the candidate node was not the one we wanted, but the node node is a function - * declaration, it's possible that the identifier we want is one of the function's - * parameters rather than the function's name itself. */ + /* If the candidate node was not the one we wanted, but the node node is a + * function declaration, it's possible that the identifier we want is one of + * the function's parameters rather than the function's name itself. */ if (node->syntaxKind == FunctionDeclaration) { Node *match = LookupFunctionArgId(node, target); - if (match != NULL) return match; + if (match != NULL) + return match; } - /* Likewise if the node node is a struct declaration, inspect the struct's internals + /* Likewise if the node node is a struct declaration, inspect the struct's + * internals * to see if a top-level definition is the one we're looking for. */ if (node->syntaxKind == StructDeclaration) { Node *match = LookupStructInternalId(node, target); - if (match != NULL) return match; + if (match != NULL) + return match; } return NULL; @@ -1328,12 +1372,15 @@ Node *InspectNode(Node *node, char *target) Node *LookupIdNode(Node *current, Node *prev, char *target) { - if (current == NULL) return NULL; + if (current == NULL) + return NULL; Node *match; - /* First inspect the current node to see if it contains the target identifier. */ + /* First inspect the current node to see if it contains the target + * identifier. */ match = InspectNode(current, target); - if (match != NULL) return match; + if (match != NULL) + return match; /* If this is the start of our search, we should not attempt to look at * child nodes. Only looking up the AST is valid at this point. @@ -1355,7 +1402,8 @@ Node *LookupIdNode(Node *current, Node *prev, char *target) { Node *decl = current->declarationSequence.sequence[i]; match = InspectNode(decl, target); - if (match != NULL) return match; + if (match != NULL) + return match; /*Node *declId = TryGetId(decl); if (declId != NULL && strcmp(target, declId->identifier.name) == 0) return declId;*/ @@ -1375,19 +1423,22 @@ Node *LookupIdNode(Node *current, Node *prev, char *target) for (i = 0; i < idxLimit; i += 1) { Node *stmt = current->statementSequence.sequence[i]; - if (stmt == prev) continue; + if (stmt == prev) + continue; - if (strcmp(target, "g") == 0) { + if (strcmp(target, "g") == 0) + { printf("info: %s\n", SyntaxKindString(stmt->syntaxKind)); } match = InspectNode(stmt, target); - if (match != NULL) return match; + if (match != NULL) + return match; /*if (NodeMayHaveId(stmt)) { Node *candidateId = TryGetId(current); - if (candidateId != NULL && strcmp(target, candidateId->identifier.name) == 0) - return candidateId; + if (candidateId != NULL && strcmp(target, + candidateId->identifier.name) == 0) return candidateId; }*/ } break; @@ -1398,7 +1449,8 @@ Node *LookupIdNode(Node *current, Node *prev, char *target) void IdentifierPass(Node *node) { - if (node == NULL) return; + if (node == NULL) + return; switch (node->syntaxKind) { @@ -1411,9 +1463,10 @@ void IdentifierPass(Node *node) break; case FunctionDeclaration: - node->functionDeclaration.functionSignature - ->functionSignature.identifier->typeTag = MakeTypeTag(node); - break;; + node->functionDeclaration.functionSignature->functionSignature + .identifier->typeTag = MakeTypeTag(node); + break; + ; case StructDeclaration: node->structDeclaration.identifier->typeTag = MakeTypeTag(node); @@ -1425,14 +1478,18 @@ void IdentifierPass(Node *node) case Identifier: { - if (node->typeTag != NULL) return; + if (node->typeTag != NULL) + return; char *name = node->identifier.name; Node *declaration = LookupIdNode(node, NULL, name); if (declaration == NULL) { /* FIXME: Express this case as an error with AST information. */ - fprintf(stderr, "wraith: Could not find definition of identifier %s.\n", name); + fprintf( + stderr, + "wraith: Could not find definition of identifier %s.\n", + name); TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); tag->type = Unknown; node->typeTag = tag; diff --git a/src/ast.h b/src/ast.h index 251f0f6..179d0e6 100644 --- a/src/ast.h +++ b/src/ast.h @@ -367,11 +367,12 @@ Node *MakeForLoopNode( void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); -/* Helper function for applying a void function generically over the children of an AST node. - * Used for functions that need to traverse the entire tree but only perform operations on a subset - * of node types. Such functions can match the syntaxKinds relevant to their purpose and invoke this - * function in all other cases. */ -void Recurse(Node *node, void (*func)(Node*)); +/* Helper function for applying a void function generically over the children of + * an AST node. Used for functions that need to traverse the entire tree but + * only perform operations on a subset of node types. Such functions can match + * the syntaxKinds relevant to their purpose and invoke this function in all + * other cases. */ +void Recurse(Node *node, void (*func)(Node *)); void LinkParentPointers(Node *node, Node *prev); -- 2.25.1 From a65fea070ab76dfa6a64caeb19ce9e6c32371979 Mon Sep 17 00:00:00 2001 From: venko Date: Sun, 30 May 2021 13:20:17 -0700 Subject: [PATCH 15/17] Removes extraneous semicolon --- src/ast.c | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ast.c b/src/ast.c index 4e244c3..6252788 100644 --- a/src/ast.c +++ b/src/ast.c @@ -1466,7 +1466,6 @@ void IdentifierPass(Node *node) node->functionDeclaration.functionSignature->functionSignature .identifier->typeTag = MakeTypeTag(node); break; - ; case StructDeclaration: node->structDeclaration.identifier->typeTag = MakeTypeTag(node); -- 2.25.1 From 3553269fb0199234474bd7e46ec4331b68dd57b3 Mon Sep 17 00:00:00 2001 From: venko Date: Mon, 31 May 2021 17:03:18 -0700 Subject: [PATCH 16/17] Refactors out identcheck and typeutils code --- CMakeLists.txt | 6 +- access.w | 13 ++ ordering.w | 15 ++ src/ast.c | 319 ----------------------------- src/ast.h | 3 - src/identcheck.c | 507 ----------------------------------------------- src/identcheck.h | 51 ----- src/main.c | 10 +- src/typeutils.c | 82 -------- src/typeutils.h | 13 -- src/validation.c | 452 ++++++++++++++++++++++++++++++++++++++++++ src/validation.h | 10 + 12 files changed, 498 insertions(+), 983 deletions(-) create mode 100644 access.w create mode 100644 ordering.w delete mode 100644 src/identcheck.c delete mode 100644 src/identcheck.h delete mode 100644 src/typeutils.c delete mode 100644 src/typeutils.h create mode 100644 src/validation.c create mode 100644 src/validation.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cb94d8..1b7a9e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,15 +41,13 @@ add_executable( # Source src/ast.h src/codegen.h - src/identcheck.h src/parser.h - src/typeutils.h + src/validation.h src/util.h src/ast.c src/codegen.c - src/identcheck.c src/parser.c - src/typeutils.c + src/validation.c src/util.c src/main.c # Generated code diff --git a/access.w b/access.w new file mode 100644 index 0000000..2534dd2 --- /dev/null +++ b/access.w @@ -0,0 +1,13 @@ +struct G { + Foo(t: bool): bool { + return t; + } +} + +struct Program { + static main(): int { + g: G = alloc G; + g.Foo(true); + return 0; + } +} \ No newline at end of file diff --git a/ordering.w b/ordering.w new file mode 100644 index 0000000..f446a68 --- /dev/null +++ b/ordering.w @@ -0,0 +1,15 @@ +struct Foo { + static Func(): void { + Func2(); + } + + static Func2(): void { + Func(); + } +} + +struct Program { + static main(): int { + return 0; + } +} \ No newline at end of file diff --git a/src/ast.c b/src/ast.c index 6252788..d57ef1f 100644 --- a/src/ast.c +++ b/src/ast.c @@ -1,6 +1,5 @@ #include "ast.h" -#include #include #include @@ -1185,321 +1184,3 @@ void LinkParentPointers(Node *node, Node *prev) return; } } - -Node *GetIdFromStruct(Node *structDecl) -{ - if (structDecl->syntaxKind != StructDeclaration) - { - fprintf( - stderr, - "wraith: Attempted to call GetIdFromStruct on node with kind: " - "%s.\n", - SyntaxKindString(structDecl->syntaxKind)); - return NULL; - } - - return structDecl->structDeclaration.identifier; -} - -Node *GetIdFromFunction(Node *funcDecl) -{ - if (funcDecl->syntaxKind != FunctionDeclaration) - { - fprintf( - stderr, - "wraith: Attempted to call GetIdFromFunction on node with kind: " - "%s.\n", - SyntaxKindString(funcDecl->syntaxKind)); - return NULL; - } - - Node *sig = funcDecl->functionDeclaration.functionSignature; - return sig->functionSignature.identifier; -} - -Node *GetIdFromDeclaration(Node *decl) -{ - if (decl->syntaxKind != Declaration) - { - fprintf( - stderr, - "wraith: Attempted to call GetIdFromDeclaration on node with kind: " - "%s.\n", - SyntaxKindString(decl->syntaxKind)); - } - - return decl->declaration.identifier; -} - -bool AssignmentHasDeclaration(Node *assign) -{ - return ( - assign->syntaxKind == Assignment && - assign->assignmentStatement.left->syntaxKind == Declaration); -} - -Node *GetIdFromAssignment(Node *assign) -{ - if (assign->syntaxKind != Assignment) - { - fprintf( - stderr, - "wraith: Attempted to call GetIdFromAssignment on node with kind: " - "%s.\n", - SyntaxKindString(assign->syntaxKind)); - } - - if (AssignmentHasDeclaration(assign)) - { - return GetIdFromDeclaration(assign->assignmentStatement.left); - } - - return NULL; -} - -bool NodeMayHaveId(Node *node) -{ - switch (node->syntaxKind) - { - case StructDeclaration: - case FunctionDeclaration: - case Declaration: - case Assignment: - return true; - default: - return false; - } -} - -Node *TryGetId(Node *node) -{ - switch (node->syntaxKind) - { - case Assignment: - return GetIdFromAssignment(node); - case Declaration: - return GetIdFromDeclaration(node); - case FunctionDeclaration: - return GetIdFromFunction(node); - case StructDeclaration: - return GetIdFromStruct(node); - default: - return NULL; - } -} - -Node *LookupFunctionArgId(Node *funcDecl, char *target) -{ - Node *args = funcDecl->functionDeclaration.functionSignature - ->functionSignature.arguments; - - uint32_t i; - for (i = 0; i < args->functionArgumentSequence.count; i += 1) - { - Node *arg = args->functionArgumentSequence.sequence[i]; - if (arg->syntaxKind != Declaration) - { - fprintf( - stderr, - "wraith: Encountered %s node in function signature args " - "list.\n", - SyntaxKindString(arg->syntaxKind)); - continue; - } - - Node *argId = GetIdFromDeclaration(arg); - if (argId != NULL && strcmp(target, argId->identifier.name) == 0) - { - return argId; - } - } - - return NULL; -} - -Node *LookupStructInternalId(Node *structDecl, char *target) -{ - Node *decls = structDecl->structDeclaration.declarationSequence; - - uint32_t i; - for (i = 0; i < decls->declarationSequence.count; i += 1) - { - Node *match = TryGetId(decls->declarationSequence.sequence[i]); - if (match != NULL && strcmp(target, match->identifier.name) == 0) - return match; - } - - return NULL; -} - -Node *InspectNode(Node *node, char *target) -{ - /* If this node may have an identifier declaration inside it, attempt to - * look up the identifier - * node itself, returning it if it matches the given target name. */ - if (NodeMayHaveId(node)) - { - Node *candidateId = TryGetId(node); - if (candidateId != NULL && - strcmp(target, candidateId->identifier.name) == 0) - { - return candidateId; - } - } - - /* If the candidate node was not the one we wanted, but the node node is a - * function declaration, it's possible that the identifier we want is one of - * the function's parameters rather than the function's name itself. */ - if (node->syntaxKind == FunctionDeclaration) - { - Node *match = LookupFunctionArgId(node, target); - if (match != NULL) - return match; - } - - /* Likewise if the node node is a struct declaration, inspect the struct's - * internals - * to see if a top-level definition is the one we're looking for. */ - if (node->syntaxKind == StructDeclaration) - { - Node *match = LookupStructInternalId(node, target); - if (match != NULL) - return match; - } - - return NULL; -} - -Node *LookupIdNode(Node *current, Node *prev, char *target) -{ - if (current == NULL) - return NULL; - Node *match; - - /* First inspect the current node to see if it contains the target - * identifier. */ - match = InspectNode(current, target); - if (match != NULL) - return match; - - /* If this is the start of our search, we should not attempt to look at - * child nodes. Only looking up the AST is valid at this point. - * - * This has the notable side-effect that this function will return NULL if - * you attempt to look up a struct's internals starting from the node - * representing the struct itself. The same is true for functions. */ - if (prev == NULL) - { - return LookupIdNode(current->parent, current, target); - } - - uint32_t i; - uint32_t idxLimit; - switch (current->syntaxKind) - { - case DeclarationSequence: - for (i = 0; i < current->declarationSequence.count; i += 1) - { - Node *decl = current->declarationSequence.sequence[i]; - match = InspectNode(decl, target); - if (match != NULL) - return match; - /*Node *declId = TryGetId(decl); - if (declId != NULL && strcmp(target, declId->identifier.name) == 0) - return declId;*/ - } - break; - case StatementSequence: - idxLimit = current->statementSequence.count; - for (i = 0; i < current->statementSequence.count; i += 1) - { - if (current->statementSequence.sequence[i] == prev) - { - idxLimit = i; - break; - } - } - - for (i = 0; i < idxLimit; i += 1) - { - Node *stmt = current->statementSequence.sequence[i]; - if (stmt == prev) - continue; - - if (strcmp(target, "g") == 0) - { - printf("info: %s\n", SyntaxKindString(stmt->syntaxKind)); - } - - match = InspectNode(stmt, target); - if (match != NULL) - return match; - /*if (NodeMayHaveId(stmt)) - { - Node *candidateId = TryGetId(current); - if (candidateId != NULL && strcmp(target, - candidateId->identifier.name) == 0) return candidateId; - }*/ - } - break; - } - - return LookupIdNode(current->parent, current, target); -} - -void IdentifierPass(Node *node) -{ - if (node == NULL) - return; - - switch (node->syntaxKind) - { - case AllocExpression: - node->typeTag = MakeTypeTag(node); - break; - - case Declaration: - node->declaration.identifier->typeTag = MakeTypeTag(node); - break; - - case FunctionDeclaration: - node->functionDeclaration.functionSignature->functionSignature - .identifier->typeTag = MakeTypeTag(node); - break; - - case StructDeclaration: - node->structDeclaration.identifier->typeTag = MakeTypeTag(node); - break; - - case GenericArgument: - node->genericArgument.identifier->typeTag = MakeTypeTag(node); - break; - - case Identifier: - { - if (node->typeTag != NULL) - return; - - char *name = node->identifier.name; - Node *declaration = LookupIdNode(node, NULL, name); - if (declaration == NULL) - { - /* FIXME: Express this case as an error with AST information. */ - fprintf( - stderr, - "wraith: Could not find definition of identifier %s.\n", - name); - TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); - tag->type = Unknown; - node->typeTag = tag; - } - else - { - node->typeTag = declaration->typeTag; - } - break; - } - } - - Recurse(node, *IdentifierPass); -} \ No newline at end of file diff --git a/src/ast.h b/src/ast.h index 179d0e6..3d36cf6 100644 --- a/src/ast.h +++ b/src/ast.h @@ -1,7 +1,6 @@ #ifndef WRAITH_AST_H #define WRAITH_AST_H -#include "identcheck.h" #include /* -Wpedantic nameless union/struct silencing */ @@ -300,7 +299,6 @@ struct Node } unaryExpression; }; TypeTag *typeTag; - IdNode *idLink; }; const char *SyntaxKindString(SyntaxKind syntaxKind); @@ -381,5 +379,4 @@ char *TypeTagToString(TypeTag *tag); Node *LookupIdNode(Node *current, Node *prev, char *target); -void IdentifierPass(Node *node); #endif /* WRAITH_AST_H */ diff --git a/src/identcheck.c b/src/identcheck.c deleted file mode 100644 index 39498b0..0000000 --- a/src/identcheck.c +++ /dev/null @@ -1,507 +0,0 @@ -#include -#include -#include -#include -#include - -#include "ast.h" -#include "identcheck.h" - -IdNode *MakeIdNode(NodeType type, char *name, IdNode *parent) -{ - IdNode *node = (IdNode *)malloc(sizeof(IdNode)); - node->type = type; - node->name = strdup(name); - node->parent = parent; - node->childCount = 0; - node->childCapacity = 0; - node->children = NULL; - node->typeTag = NULL; - return node; -} - -void AddChildToNode(IdNode *node, IdNode *child) -{ - if (child == NULL) - return; - - if (node->children == NULL) - { - node->childCapacity = 2; - node->children = - (IdNode **)malloc(sizeof(IdNode *) * node->childCapacity); - } - else if (node->childCount == node->childCapacity) - { - node->childCapacity *= 2; - node->children = (IdNode **)realloc( - node->children, - sizeof(IdNode *) * node->childCapacity); - } - - node->children[node->childCount] = child; - node->childCount += 1; -} - -IdNode *MakeIdTree(Node *astNode, IdNode *parent) -{ - uint32_t i; - IdNode *mainNode; - switch (astNode->syntaxKind) - { - case AccessExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->accessExpression.accessee, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->accessExpression.accessor, parent)); - return NULL; - - case AllocExpression: - astNode->typeTag = MakeTypeTag(astNode); - return NULL; - - case Assignment: - { - if (astNode->assignmentStatement.left->syntaxKind == Declaration) - { - return MakeIdTree(astNode->assignmentStatement.left, parent); - } - else - { - AddChildToNode( - parent, - MakeIdTree(astNode->assignmentStatement.left, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->assignmentStatement.right, parent)); - return NULL; - } - } - - case BinaryExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->binaryExpression.left, parent)); - AddChildToNode( - parent, - MakeIdTree(astNode->binaryExpression.right, parent)); - return NULL; - - case Declaration: - { - Node *idNode = astNode->declaration.identifier; - mainNode = MakeIdNode(Variable, idNode->identifier.name, parent); - mainNode->typeTag = MakeTypeTag(astNode); - idNode->typeTag = mainNode->typeTag; - break; - } - - case DeclarationSequence: - { - mainNode = MakeIdNode(UnorderedScope, "", parent); - for (i = 0; i < astNode->declarationSequence.count; i++) - { - AddChildToNode( - mainNode, - MakeIdTree(astNode->declarationSequence.sequence[i], mainNode)); - } - break; - } - - case ForLoop: - { - Node *loopDecl = astNode->forLoop.declaration; - Node *loopBody = astNode->forLoop.statementSequence; - mainNode = MakeIdNode(OrderedScope, "for-loop", parent); - AddChildToNode(mainNode, MakeIdTree(loopDecl, mainNode)); - AddChildToNode(mainNode, MakeIdTree(loopBody, mainNode)); - break; - } - - case FunctionArgumentSequence: - for (i = 0; i < astNode->functionArgumentSequence.count; i++) - { - AddChildToNode( - parent, - MakeIdTree( - astNode->functionArgumentSequence.sequence[i], - parent)); - } - return NULL; - - case FunctionCallExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->functionCallExpression.identifier, parent)); - AddChildToNode( - parent, - MakeIdTree( - astNode->functionCallExpression.argumentSequence, - parent)); - return NULL; - - case FunctionDeclaration: - { - Node *sigNode = astNode->functionDeclaration.functionSignature; - Node *idNode = sigNode->functionSignature.identifier; - char *funcName = idNode->identifier.name; - mainNode = MakeIdNode(Function, funcName, parent); - mainNode->typeTag = MakeTypeTag(astNode); - idNode->typeTag = mainNode->typeTag; - MakeIdTree(sigNode->functionSignature.genericArguments, mainNode); - MakeIdTree(sigNode->functionSignature.arguments, mainNode); - MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); - break; - } - - case FunctionSignatureArguments: - { - for (i = 0; i < astNode->functionSignatureArguments.count; i++) - { - Node *argNode = astNode->functionSignatureArguments.sequence[i]; - AddChildToNode(parent, MakeIdTree(argNode, parent)); - } - return NULL; - } - - case GenericArgument: - { - char *name = astNode->genericArgument.identifier->identifier.name; - mainNode = MakeIdNode(GenericType, name, parent); - break; - } - - case GenericArguments: - { - for (i = 0; i < astNode->genericArguments.count; i += 1) - { - Node *argNode = astNode->genericArguments.arguments[i]; - AddChildToNode(parent, MakeIdTree(argNode, parent)); - } - return NULL; - } - - case Identifier: - { - char *name = astNode->identifier.name; - mainNode = MakeIdNode(Placeholder, name, parent); - Node *lookupNode = LookupIdNode(astNode, NULL, name); - if (lookupNode == NULL) - { - fprintf(stderr, "wraith: Could not find IdNode for id %s\n", name); - TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); - tag->type = Unknown; - astNode->typeTag = tag; - } - else - { - astNode->typeTag = lookupNode->typeTag; - } - break; - } - - case IfStatement: - { - Node *clause = astNode->ifStatement.expression; - Node *stmtSeq = astNode->ifStatement.statementSequence; - mainNode = MakeIdNode(OrderedScope, "if", parent); - MakeIdTree(clause, mainNode); - MakeIdTree(stmtSeq, mainNode); - break; - } - - case IfElseStatement: - { - Node *ifNode = astNode->ifElseStatement.ifStatement; - Node *elseStmts = astNode->ifElseStatement.elseStatement; - mainNode = MakeIdNode(OrderedScope, "if-else", parent); - IdNode *ifBranch = MakeIdTree(ifNode, mainNode); - AddChildToNode(mainNode, ifBranch); - IdNode *elseScope = MakeIdNode(OrderedScope, "else", mainNode); - MakeIdTree(elseStmts, elseScope); - AddChildToNode(mainNode, elseScope); - break; - } - - case ReferenceTypeNode: - AddChildToNode(parent, MakeIdTree(astNode->referenceType.type, parent)); - return NULL; - - case Return: - AddChildToNode( - parent, - MakeIdTree(astNode->returnStatement.expression, parent)); - return NULL; - - case StatementSequence: - { - for (i = 0; i < astNode->statementSequence.count; i++) - { - Node *argNode = astNode->statementSequence.sequence[i]; - AddChildToNode(parent, MakeIdTree(argNode, parent)); - } - return NULL; - } - - case StructDeclaration: - { - Node *idNode = astNode->structDeclaration.identifier; - Node *declsNode = astNode->structDeclaration.declarationSequence; - mainNode = MakeIdNode(Struct, idNode->identifier.name, parent); - mainNode->typeTag = MakeTypeTag(astNode); - for (i = 0; i < declsNode->declarationSequence.count; i++) - { - Node *decl = declsNode->declarationSequence.sequence[i]; - AddChildToNode(mainNode, MakeIdTree(decl, mainNode)); - } - break; - } - - case Type: - AddChildToNode(parent, MakeIdTree(astNode->type.typeNode, parent)); - return NULL; - - case UnaryExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->unaryExpression.child, parent)); - return NULL; - - case Comment: - case CustomTypeNode: - case FunctionModifiers: - case FunctionSignature: - case Number: - case PrimitiveTypeNode: - case ReturnVoid: - case StaticModifier: - case StringLiteral: - return NULL; - } - - astNode->idLink = mainNode; - return mainNode; -} - -void PrintIdNode(IdNode *node) -{ - if (node == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call PrintIdNode with null value.\n"); - return; - } - - switch (node->type) - { - case Placeholder: - printf("Placeholder (%s)\n", node->name); - break; - case OrderedScope: - printf("OrderedScope (%s)\n", node->name); - break; - case UnorderedScope: - printf("UnorderedScope (%s)\n", node->name); - break; - case Struct: - printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); - break; - case Function: - printf( - "%s : Function<%s>\n", - node->name, - TypeTagToString(node->typeTag)); - break; - case Variable: - printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); - break; - case GenericType: - printf("Generic type: %s\n", node->name); - break; - case Alloc: - printf("Alloc: %s\n", TypeTagToString(node->typeTag)); - break; - } -} - -void PrintIdTree(IdNode *tree, uint32_t tabCount) -{ - if (tree == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call PrintIdTree on a null value.\n"); - return; - } - - uint32_t i; - for (i = 0; i < tabCount; i++) - { - printf("| "); - } - - PrintIdNode(tree); - - for (i = 0; i < tree->childCount; i++) - { - PrintIdTree(tree->children[i], tabCount + 1); - } -} - -int PrintAncestors(IdNode *node) -{ - if (node == NULL) - return -1; - - int i; - int indent = 1; - indent += PrintAncestors(node->parent); - for (i = 0; i < indent; i++) - { - printf(" "); - } - PrintIdNode(node); - return indent; -} - -IdNode *LookdownId(IdNode *root, NodeType targetType, char *targetName) -{ - if (root == NULL) - { - fprintf( - stderr, - "wraith: Attempted to call LookdownId on a null value.\n"); - return NULL; - } - - IdNode *result = NULL; - IdNode **frontier = (IdNode **)malloc(sizeof(IdNode *)); - frontier[0] = root; - uint32_t frontierCount = 1; - - while (frontierCount > 0) - { - IdNode *current = frontier[0]; - - if (current->type == targetType && - strcmp(current->name, targetName) == 0) - { - result = current; - break; - } - - uint32_t i; - for (i = 1; i < frontierCount; i++) - { - frontier[i - 1] = frontier[i]; - } - size_t newSize = frontierCount + current->childCount - 1; - if (frontierCount != newSize) - { - frontier = (IdNode **)realloc(frontier, sizeof(IdNode *) * newSize); - } - for (i = 0; i < current->childCount; i++) - { - frontier[frontierCount + i - 1] = current->children[i]; - } - frontierCount = newSize; - } - - free(frontier); - return result; -} - -bool ScopeHasOrdering(IdNode *node) -{ - switch (node->type) - { - case OrderedScope: - case Function: - case Variable: /* this is only technically true */ - return true; - default: - return false; - } -} - -IdNode *LookupId(IdNode *node, IdNode *prev, char *target) -{ - if (node == NULL) - { - return NULL; - } - - if (strcmp(node->name, target) == 0 && node->type != Placeholder) - { - return node; - } - - /* If this is the start of our search, we should not attempt to look at - * child nodes. Only looking up the scope tree is valid at this point. - * - * This has the notable side-effect that this function will return NULL if - * you attempt to look up a struct's internals starting from the node - * representing the struct itself. This is because an IdNode corresponds to - * the location *where an identifier is first declared.* Thus, an identifier - * has no knowledge of identifiers declared "inside" of it. - */ - if (prev == NULL) - { - return LookupId(node->parent, node, target); - } - - /* If the current node forms an ordered scope then we want to prevent - * ourselves from looking up identifiers declared after the scope we have - * just come from. - */ - uint32_t idxLimit; - if (ScopeHasOrdering(node)) - { - uint32_t i; - for (i = 0, idxLimit = 0; i < node->childCount; i++, idxLimit++) - { - if (node->children[i] == prev) - { - break; - } - } - } - else - { - idxLimit = node->childCount; - } - - uint32_t i; - for (i = 0; i < idxLimit; i++) - { - IdNode *child = node->children[i]; - if (child == prev || child->type == Placeholder) - { - /* Do not inspect the node we just came from or placeholders. */ - continue; - } - - if (strcmp(child->name, target) == 0) - { - return child; - } - - if (child->type == Struct) - { - uint32_t j; - for (j = 0; j < child->childCount; j++) - { - IdNode *grandchild = child->children[j]; - if (strcmp(grandchild->name, target) == 0) - { - return grandchild; - } - } - } - } - - return LookupId(node->parent, node, target); -} diff --git a/src/identcheck.h b/src/identcheck.h deleted file mode 100644 index 8b287dd..0000000 --- a/src/identcheck.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Validates identifier usage in an AST. */ - -#ifndef WRAITH_IDENTCHECK_H -#define WRAITH_IDENTCHECK_H - -#include - -#include "ast.h" - -struct TypeTag; -struct Node; - -typedef enum NodeType -{ - Placeholder, - UnorderedScope, - OrderedScope, - Struct, - Function, - Variable, - GenericType, - Alloc -} NodeType; - -typedef struct IdNode -{ - NodeType type; - char *name; - struct TypeTag *typeTag; - struct IdNode *parent; - struct IdNode **children; - uint32_t childCount; - uint32_t childCapacity; -} IdNode; - -typedef struct IdStatus -{ - enum StatusCode - { - Valid, - } StatusCode; -} IdStatus; - -IdNode *MakeIdTree(struct Node *astNode, IdNode *parent); -void PrintIdNode(IdNode *node); -void PrintIdTree(IdNode *tree, uint32_t tabCount); -int PrintAncestors(IdNode *node); -IdNode *LookdownId(IdNode *root, NodeType targetType, char *targetName); -IdNode *LookupId(IdNode *node, IdNode *prev, char *target); - -#endif /* WRAITH_IDENTCHECK_H */ diff --git a/src/main.c b/src/main.c index 1b7c70b..f3f8d36 100644 --- a/src/main.c +++ b/src/main.c @@ -2,9 +2,8 @@ #include #include "codegen.h" -#include "identcheck.h" #include "parser.h" -#include "typeutils.h" +#include "validation.h" int main(int argc, char *argv[]) { @@ -87,8 +86,11 @@ int main(int argc, char *argv[]) else { LinkParentPointers(rootNode, NULL); - IdentifierPass(rootNode); - /*ConvertASTCustomsToGenerics(rootNode);*/ + /* FIXME: ValidateIdentifiers should return some sort of + error status object. */ + ValidateIdentifiers(rootNode); + TagIdentifierTypes(rootNode); + ConvertCustomsToGenerics(rootNode); PrintNode(rootNode, 0); printf("Beginning codegen.\n"); diff --git a/src/typeutils.c b/src/typeutils.c deleted file mode 100644 index 4ebddea..0000000 --- a/src/typeutils.c +++ /dev/null @@ -1,82 +0,0 @@ -#include "typeutils.h" - -#include -#include -#include - -void ConvertIdCustomsToGenerics(IdNode *node) { - uint32_t i; - switch(node->type) - { - case UnorderedScope: - case OrderedScope: - case Struct: - /* FIXME: This case will need to be modified to handle type parameters over structs. */ - for (i = 0; i < node->childCount; i += 1) { - ConvertIdCustomsToGenerics(node->children[i]); - } - return; - - case Variable: { - TypeTag *varType = node->typeTag; - if (varType->type == Custom) { - IdNode *x = LookupId(node->parent, node, varType->value.customType); - if (x != NULL && x->type == GenericType) { - varType->type = Generic; - } - } - return; - } - - case Function: { - TypeTag *funcType = node->typeTag; - if (funcType->type == Custom) { - /* For functions we have to handle the type lookup manually since the generic type - * identifiers are declared as children of the function's IdNode. */ - for (i = 0; i < node->childCount; i += 1) { - IdNode *child = node->children[i]; - if (child->type == GenericType && strcmp(child->name, funcType->value.customType) == 0) { - funcType->type = Generic; - } - } - } - - for (i = 0; i < node->childCount; i += 1) { - ConvertIdCustomsToGenerics(node->children[i]); - } - return; - } - } -} - -/* FIXME: This function will need to be modified to handle type parameters over structs. */ -void ConvertASTCustomsToGenerics(Node *node) { - switch (node->syntaxKind) - { - case Declaration: - { - Node *type = node->declaration.type->type.typeNode; - Node *id = node->declaration.identifier; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->declaration.type); - node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - return; - } - - case FunctionSignature: - { - Node *id = node->functionSignature.identifier; - Node *type = node->functionSignature.type; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->functionSignature.type); - node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - ConvertASTCustomsToGenerics(node->functionSignature.arguments); - return; - } - - default: - Recurse(node, *ConvertASTCustomsToGenerics); - } -} diff --git a/src/typeutils.h b/src/typeutils.h deleted file mode 100644 index 2e752e0..0000000 --- a/src/typeutils.h +++ /dev/null @@ -1,13 +0,0 @@ -/* Helper functions for working with types in the AST and ID-tree. */ - -#ifndef WRAITH_TYPEUTILS_H -#define WRAITH_TYPEUTILS_H - -#include "ast.h" -#include "identcheck.h" - -/* FIXME: These two functions will need to be modified to handle type parameters over structs. */ -void ConvertIdCustomsToGenerics(IdNode *node); -void ConvertASTCustomsToGenerics(Node *node); - -#endif /* WRAITH_TYPEUTILS_H */ diff --git a/src/validation.c b/src/validation.c new file mode 100644 index 0000000..e62c4dd --- /dev/null +++ b/src/validation.c @@ -0,0 +1,452 @@ +#include "validation.h" + +#include +#include +#include +#include + +Node *GetIdFromStruct(Node *structDecl) +{ + if (structDecl->syntaxKind != StructDeclaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromStruct on node with kind: " + "%s.\n", + SyntaxKindString(structDecl->syntaxKind)); + return NULL; + } + + return structDecl->structDeclaration.identifier; +} + +Node *GetIdFromFunction(Node *funcDecl) +{ + if (funcDecl->syntaxKind != FunctionDeclaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromFunction on node with kind: " + "%s.\n", + SyntaxKindString(funcDecl->syntaxKind)); + return NULL; + } + + Node *sig = funcDecl->functionDeclaration.functionSignature; + return sig->functionSignature.identifier; +} + +Node *GetIdFromDeclaration(Node *decl) +{ + if (decl->syntaxKind != Declaration) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromDeclaration on node with kind: " + "%s.\n", + SyntaxKindString(decl->syntaxKind)); + } + + return decl->declaration.identifier; +} + +bool AssignmentHasDeclaration(Node *assign) +{ + return ( + assign->syntaxKind == Assignment && + assign->assignmentStatement.left->syntaxKind == Declaration); +} + +Node *GetIdFromAssignment(Node *assign) +{ + if (assign->syntaxKind != Assignment) + { + fprintf( + stderr, + "wraith: Attempted to call GetIdFromAssignment on node with kind: " + "%s.\n", + SyntaxKindString(assign->syntaxKind)); + } + + if (AssignmentHasDeclaration(assign)) + { + return GetIdFromDeclaration(assign->assignmentStatement.left); + } + + return NULL; +} + +bool NodeMayHaveId(Node *node) +{ + switch (node->syntaxKind) + { + case StructDeclaration: + case FunctionDeclaration: + case Declaration: + case Assignment: + return true; + default: + return false; + } +} + +Node *TryGetId(Node *node) +{ + switch (node->syntaxKind) + { + case Assignment: + return GetIdFromAssignment(node); + case Declaration: + return GetIdFromDeclaration(node); + case FunctionDeclaration: + return GetIdFromFunction(node); + case StructDeclaration: + return GetIdFromStruct(node); + default: + return NULL; + } +} + +Node *LookupFunctionArgId(Node *funcDecl, char *target) +{ + Node *args = funcDecl->functionDeclaration.functionSignature + ->functionSignature.arguments; + + uint32_t i; + for (i = 0; i < args->functionArgumentSequence.count; i += 1) + { + Node *arg = args->functionArgumentSequence.sequence[i]; + if (arg->syntaxKind != Declaration) + { + fprintf( + stderr, + "wraith: Encountered %s node in function signature args " + "list.\n", + SyntaxKindString(arg->syntaxKind)); + continue; + } + + Node *argId = GetIdFromDeclaration(arg); + if (argId != NULL && strcmp(target, argId->identifier.name) == 0) + return argId; + } + + return NULL; +} + +Node *LookupStructInternalId(Node *structDecl, char *target) +{ + Node *decls = structDecl->structDeclaration.declarationSequence; + + uint32_t i; + for (i = 0; i < decls->declarationSequence.count; i += 1) + { + Node *match = TryGetId(decls->declarationSequence.sequence[i]); + if (match != NULL && strcmp(target, match->identifier.name) == 0) + return match; + } + + return NULL; +} + +Node *InspectNode(Node *node, char *target) +{ + /* If this node may have an identifier declaration inside it, attempt to + * look up the identifier + * node itself, returning it if it matches the given target name. */ + if (NodeMayHaveId(node)) + { + Node *candidateId = TryGetId(node); + if (candidateId != NULL && + strcmp(target, candidateId->identifier.name) == 0) + return candidateId; + } + + /* If the candidate node was not the one we wanted, but the node node is a + * function declaration, it's possible that the identifier we want is one of + * the function's parameters rather than the function's name itself. */ + if (node->syntaxKind == FunctionDeclaration) + { + Node *match = LookupFunctionArgId(node, target); + if (match != NULL) + return match; + } + + /* Likewise if the node node is a struct declaration, inspect the struct's + * internals + * to see if a top-level definition is the one we're looking for. */ + if (node->syntaxKind == StructDeclaration) + { + Node *match = LookupStructInternalId(node, target); + if (match != NULL) + return match; + } + + return NULL; +} + +/* FIXME: Handle staged lookups for AccessExpressions. */ +/* FIXME: Similar to above, disallow inspection of struct internals outside of + * AccessExpressions. */ +Node *LookupId(Node *current, Node *prev, char *target) +{ + if (current == NULL) + return NULL; + + Node *match; + + /* First inspect the current node to see if it contains the target + * identifier. */ + match = InspectNode(current, target); + if (match != NULL) + return match; + + /* If this is the start of our search, we should not attempt to look at + * child nodes. Only looking up the AST is valid at this point. + * + * This has the notable side-effect that this function will return NULL if + * you attempt to look up a struct's internals starting from the node + * representing the struct itself. The same is true for functions. */ + if (prev == NULL) + return LookupId(current->parent, current, target); + + uint32_t i; + uint32_t idxLimit; + switch (current->syntaxKind) + { + case DeclarationSequence: + for (i = 0; i < current->declarationSequence.count; i += 1) + { + Node *decl = current->declarationSequence.sequence[i]; + match = InspectNode(decl, target); + if (match != NULL) + return match; + } + break; + case StatementSequence: + idxLimit = current->statementSequence.count; + for (i = 0; i < current->statementSequence.count; i += 1) + { + if (current->statementSequence.sequence[i] == prev) + { + idxLimit = i; + break; + } + } + + for (i = 0; i < idxLimit; i += 1) + { + Node *stmt = current->statementSequence.sequence[i]; + if (stmt == prev) + break; + + match = InspectNode(stmt, target); + if (match != NULL) + return match; + } + break; + } + + return LookupId(current->parent, current, target); +} + +/* FIXME: This function should be extended to handle multi-stage ID lookups for + * AccessExpression nodes. */ +/* FIXME: Make this function return an error status object of some kind. + * A non-OK status should halt compilation. */ +void ValidateIdentifiers(Node *node) +{ + if (node == NULL) + return; + + /* Skip over generic arguments. They contain Identifiers but are not + * actually identifiers, they declare types. */ + if (node->syntaxKind == GenericArguments) + return; + + if (node->syntaxKind != Identifier) + { + Recurse(node, *ValidateIdentifiers); + return; + } + + char *name = node->identifier.name; + Node *decl = LookupId(node, NULL, name); + if (decl == NULL) + { + /* FIXME: Express this case as an error with AST information, see the + * FIXME comment above. */ + fprintf( + stderr, + "wraith: Could not find definition of identifier %s.\n", + name); + } +} + +/* FIXME: This function should be extended to handle multi-stage ID lookups for + * AccessExpression nodes. */ +void TagIdentifierTypes(Node *node) +{ + if (node == NULL) + return; + + switch (node->syntaxKind) + { + case AllocExpression: + node->typeTag = MakeTypeTag(node); + break; + + case Declaration: + node->declaration.identifier->typeTag = MakeTypeTag(node); + break; + + case FunctionDeclaration: + node->functionDeclaration.functionSignature->functionSignature + .identifier->typeTag = MakeTypeTag(node); + break; + + case StructDeclaration: + node->structDeclaration.identifier->typeTag = MakeTypeTag(node); + break; + + case GenericArgument: + node->genericArgument.identifier->typeTag = MakeTypeTag(node); + break; + + case Identifier: + { + if (node->typeTag != NULL) + return; + + char *name = node->identifier.name; + Node *declaration = LookupId(node, NULL, name); + /* FIXME: Remove this case once ValidateIdentifiers returns error status + * info and halts compilation. See ValidateIdentifiers FIXME. */ + if (declaration == NULL) + { + TypeTag *tag = (TypeTag *)malloc(sizeof(TypeTag)); + tag->type = Unknown; + node->typeTag = tag; + } + else + { + node->typeTag = declaration->typeTag; + } + break; + } + } + + Recurse(node, *TagIdentifierTypes); +} + +Node *LookupType(Node *current, char *target) +{ + if (current == NULL) + return NULL; + + switch (current->syntaxKind) + { + /* If we've encountered a function declaration, check to see if it's generic + * and, if so, if one of its type parameters is the target. */ + case FunctionDeclaration: + { + Node *typeArgs = current->functionDeclaration.functionSignature + ->functionSignature.genericArguments; + uint32_t i; + for (i = 0; i < typeArgs->genericArguments.count; i += 1) + { + Node *arg = typeArgs->genericArguments.arguments[i]; + Node *argId = arg->genericArgument.identifier; + char *argName = argId->identifier.name; + /* note: return the GenericArgument, not the Identifier, so that + * the caller can differentiate between generics and customs. */ + if (strcmp(target, argName) == 0) + return arg; + } + + return LookupType(current->parent, target); + } + + case StructDeclaration: + { + Node *structId = GetIdFromStruct(current); + if (strcmp(target, structId->identifier.name) == 0) + return structId; + + return LookupType(current->parent, target); + } + + /* If we encounter a declaration sequence, search each of its children for + * struct definitions in case one of them is the target. */ + case DeclarationSequence: + { + uint32_t i; + for (i = 0; i < current->declarationSequence.count; i += 1) + { + Node *decl = current->declarationSequence.sequence[i]; + if (decl->syntaxKind == StructDeclaration) + { + Node *structId = GetIdFromStruct(decl); + if (strcmp(target, structId->identifier.name) == 0) + return structId; + } + } + + return LookupType(current->parent, target); + } + + default: + return LookupType(current->parent, target); + } +} + +/* FIXME: This function should be modified to handle type parameters over + * structs. */ +void ConvertCustomsToGenerics(Node *node) +{ + if (node == NULL) + return; + + switch (node->syntaxKind) + { + case Declaration: + { + Node *id = node->declaration.identifier; + Node *type = node->declaration.type->type.typeNode; + if (type->syntaxKind == CustomTypeNode) + { + char *target = id->typeTag->value.customType; + Node *typeLookup = LookupType(node, target); + if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + { + id->typeTag->type = Generic; + free(node->declaration.type); + node->declaration.type = + MakeGenericTypeNode(id->typeTag->value.genericType); + } + } + break; + } + + case FunctionSignature: + { + Node *id = node->functionSignature.identifier; + Node *type = node->functionSignature.type->type.typeNode; + if (type->syntaxKind == CustomTypeNode) + { + char *target = id->typeTag->value.customType; + Node *typeLookup = LookupType(node, target); + if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + { + id->typeTag->type = Generic; + free(node->functionSignature.type); + node->functionSignature.type = + MakeGenericTypeNode(id->typeTag->value.genericType); + } + } + break; + } + } + + Recurse(node, *ConvertCustomsToGenerics); +} diff --git a/src/validation.h b/src/validation.h new file mode 100644 index 0000000..50c5432 --- /dev/null +++ b/src/validation.h @@ -0,0 +1,10 @@ +#ifndef WRAITH_VALIDATION_H +#define WRAITH_VALIDATION_H + +#include "ast.h" + +void ValidateIdentifiers(Node *node); +void TagIdentifierTypes(Node *node); +void ConvertCustomsToGenerics(Node *node); + +#endif /* WRAITH_VALIDATION_H */ -- 2.25.1 From 9f52a19a584ffb3a47d74c1e438418e4bcc76940 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Tue, 1 Jun 2021 12:56:56 -0700 Subject: [PATCH 17/17] monomorphization --- generic.w | 6 +- src/ast.c | 8 +-- src/codegen.c | 159 +++++++++++++++++++++++++++++++++++--------------- 3 files changed, 120 insertions(+), 53 deletions(-) diff --git a/generic.w b/generic.w index 63247ee..73fb981 100644 --- a/generic.w +++ b/generic.w @@ -5,14 +5,14 @@ struct Foo { static Func(t: T): T { foo: T = t; - return Func2(foo); + return Foo.Func2(foo); } } struct Program { - static main(): int { + static Main(): int { x: int = 4; y: int = Foo.Func(x); return x; } -} \ No newline at end of file +} diff --git a/src/ast.c b/src/ast.c index d57ef1f..fddf74d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -988,22 +988,22 @@ char *TypeTagToString(TypeTag *tag) { char *inner = TypeTagToString(tag->value.referenceType); size_t innerStrLen = strlen(inner); - char *result = malloc(sizeof(char) * (innerStrLen + 5)); + char *result = malloc(sizeof(char) * (innerStrLen + 6)); sprintf(result, "Ref<%s>", inner); return result; } case Custom: { char *result = - malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); + malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); sprintf(result, "Custom<%s>", tag->value.customType); return result; } case Generic: { char *result = - malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); - sprintf(result, "Generic<%s>", tag->value.customType); + malloc(sizeof(char) * (strlen(tag->value.genericType) + 10)); + sprintf(result, "Generic<%s>", tag->value.genericType); return result; } } diff --git a/src/codegen.c b/src/codegen.c index 8f82dff..dd3bf73 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -26,6 +26,7 @@ typedef struct LocalVariable typedef struct LocalGenericType { char *name; + TypeTag *concreteTypeTag; LLVMTypeRef type; } LocalGenericType; @@ -171,7 +172,7 @@ static void PopScopeFrame(Scope *scope) { free(scope->scopeStack[index].genericTypes[i].name); } - free(scope->scopeStack[index].localVariables); + free(scope->scopeStack[index].genericTypes); } scope->scopeStackCount -= 1; @@ -201,7 +202,7 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) return NULL; } -static LLVMTypeRef LookupCustomType(char *name) +static LocalGenericType *LookupGenericType(char *name) { int32_t i, j; @@ -211,11 +212,19 @@ static LLVMTypeRef LookupCustomType(char *name) { if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) { - return scope->scopeStack[i].genericTypes[j].type; + return &scope->scopeStack[i].genericTypes[j]; } } } + fprintf(stderr, "Could not find resolved generic type!\n"); + return NULL; +} + +static LLVMTypeRef LookupCustomType(char *name) +{ + int32_t i; + for (i = 0; i < structTypeDeclarationCount; i += 1) { if (strcmp(structTypeDeclarations[i].name, name) == 0) @@ -242,6 +251,10 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) { return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); } + else if (typeTag->type == Generic) + { + return LookupGenericType(typeTag->value.genericType)->type; + } else { fprintf(stderr, "Unknown type node!\n"); @@ -277,6 +290,7 @@ static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) scopeFrame->genericTypes, sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; scopeFrame->genericTypes[index].type = ResolveType(typeTag); scopeFrame->genericTypeCount += 1; @@ -544,11 +558,12 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) return result; } +/* FIXME: lots of duplication with non-generic function compile */ static StructTypeFunction CompileGenericFunction( LLVMModuleRef module, char *parentStructName, LLVMTypeRef wStructPointerType, - TypeTag **genericArgumentTypes, + TypeTag **resolvedGenericArgumentTypes, uint32_t genericArgumentTypeCount, Node *functionDeclaration) { @@ -562,6 +577,7 @@ static StructTypeFunction CompileGenericFunction( ->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; + LLVMTypeRef returnType; PushScopeFrame(scope); @@ -569,7 +585,7 @@ static StructTypeFunction CompileGenericFunction( { AddGenericVariable( scope, - genericArgumentTypes[i], + resolvedGenericArgumentTypes[i], functionDeclaration->functionDeclaration.functionSignature ->functionSignature.genericArguments->genericArguments .arguments[i] @@ -601,7 +617,7 @@ static StructTypeFunction CompileGenericFunction( for (i = 0; i < genericArgumentTypeCount; i += 1) { - strcat(functionName, TypeTagToString(genericArgumentTypes[i])); + strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); } if (!isStatic) @@ -618,11 +634,13 @@ static StructTypeFunction CompileGenericFunction( ResolveType(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->typeTag); + paramIndex += 1; } - LLVMTypeRef returnType = + returnType = ResolveType(functionSignature->functionSignature.identifier->typeTag); + LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); @@ -698,15 +716,64 @@ static StructTypeFunction CompileGenericFunction( static LLVMValueRef LookupGenericFunction( LLVMModuleRef module, StructTypeGenericFunction *genericFunction, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - uint64_t typeHash = - HashTypeTags(genericArgumentTypes, genericArgumentTypeCount); + uint64_t typeHash; uint8_t match = 0; + uint32_t genericArgumentTypeCount = + genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.genericArguments + ->genericArguments.count; + TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + for (j = 0; + j < genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + j += 1) + { + if (genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->type == Generic && + strcmp( + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->value.genericType, + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name) == 0) + { + resolvedGenericArgumentTypes[i] = argumentTypes[j]; + break; + } + } + } + + /* Concretize generics if we are compiling nested generic functions */ + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + if (resolvedGenericArgumentTypes[i]->type == Generic) + { + resolvedGenericArgumentTypes[i] = + LookupGenericType( + resolvedGenericArgumentTypes[i]->value.genericType) + ->concreteTypeTag; + } + } + + typeHash = + HashTypeTags(resolvedGenericArgumentTypes, genericArgumentTypeCount); MonomorphizedGenericFunctionHashArray *hashArray = &genericFunction->monomorphizedFunctions @@ -719,7 +786,8 @@ static LLVMValueRef LookupGenericFunction( for (j = 0; j < hashArray->elements[i].typeCount; j += 1) { - if (hashArray->elements[i].types[j] != genericArgumentTypes[j]) + if (hashArray->elements[i].types[j] != + resolvedGenericArgumentTypes[j]) { match = 0; break; @@ -739,7 +807,7 @@ static LLVMValueRef LookupGenericFunction( module, genericFunction->parentStructName, genericFunction->parentStructPointerType, - genericArgumentTypes, + resolvedGenericArgumentTypes, genericArgumentTypeCount, genericFunction->functionDeclarationNode); @@ -757,9 +825,11 @@ static LLVMValueRef LookupGenericFunction( for (i = 0; i < genericArgumentTypeCount; i += 1) { hashArray->elements[hashArray->count].types[i] = - genericArgumentTypes[i]; + resolvedGenericArgumentTypes[i]; } hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; } *pReturnType = hashEntry->function.returnType; @@ -772,8 +842,8 @@ static LLVMValueRef LookupFunctionByType( LLVMModuleRef module, LLVMTypeRef structType, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -805,8 +875,8 @@ static LLVMValueRef LookupFunctionByType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -822,8 +892,8 @@ static LLVMValueRef LookupFunctionByPointerType( LLVMModuleRef module, LLVMTypeRef structPointerType, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -855,8 +925,8 @@ static LLVMValueRef LookupFunctionByPointerType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -872,8 +942,8 @@ static LLVMValueRef LookupFunctionByInstance( LLVMModuleRef module, LLVMValueRef structPointer, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -881,8 +951,8 @@ static LLVMValueRef LookupFunctionByInstance( module, LLVMTypeOf(structPointer), name, - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -966,40 +1036,34 @@ static LLVMValueRef CompileFunctionCallExpression( { uint32_t i; uint32_t argumentCount = 0; - uint32_t genericArgumentCount = 0; LLVMValueRef args [functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.count + 1]; - TypeTag *genericArgumentTypes[functionCallExpression->functionCallExpression - .argumentSequence - ->functionArgumentSequence.count]; + TypeTag + *argumentTypes[functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; LLVMTypeRef functionReturnType; char *returnName = ""; - /* FIXME: this is completely wrong and not how we get generic args */ for (i = 0; i < functionCallExpression->functionCallExpression .argumentSequence->functionArgumentSequence.count; i += 1) { - if (functionCallExpression->functionCallExpression.argumentSequence + argumentTypes[i] = + functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i] - ->syntaxKind == GenericArgument) - { - genericArgumentTypes[genericArgumentCount] = - functionCallExpression->functionCallExpression.argumentSequence - ->functionArgumentSequence.sequence[i] - ->declaration.identifier->typeTag; + ->typeTag; - genericArgumentCount += 1; - } + argumentCount += 1; } /* FIXME: this needs to be recursive on access chains */ - /* FIXME: this needs to be able to call same-struct functions implicitly */ + /* FIXME: this needs to be able to call same-struct functions implicitly + */ if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { @@ -1014,8 +1078,8 @@ static LLVMValueRef CompileFunctionCallExpression( typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, - genericArgumentTypes, - genericArgumentCount, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -1029,8 +1093,8 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, - genericArgumentTypes, - genericArgumentCount, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -1041,6 +1105,8 @@ static LLVMValueRef CompileFunctionCallExpression( return NULL; } + argumentCount = 0; + if (!isStatic) { args[argumentCount] = structInstance; @@ -1693,7 +1759,8 @@ static void Compile( { fprintf( stderr, - "top level declarations that are not structs are forbidden!\n"); + "top level declarations that are not structs are " + "forbidden!\n"); } } } -- 2.25.1