From ea203e6c3cdda14098f7d194eeb18a877b68c868 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 2 Jun 2021 14:33:15 -0700 Subject: [PATCH] allow explicit generic arguments on calls --- generators/wraith.y | 44 +++++++++++----- src/ast.c | 126 ++++++++++++++++++++++++++++++++++---------- src/ast.h | 39 +++++++++++--- src/codegen.c | 40 +++++++++----- src/validation.c | 41 ++++++++++---- 5 files changed, 218 insertions(+), 72 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index a329697..27c710c 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -241,13 +241,13 @@ ReturnStatement : RETURN Expression $$ = MakeReturnVoidStatementNode(); } -FunctionCallExpression : AccessExpression LEFT_PAREN Arguments RIGHT_PAREN +FunctionCallExpression : AccessExpression GenericArgumentClause LEFT_PAREN Arguments RIGHT_PAREN { - $$ = MakeFunctionCallExpressionNode($1, $3); + $$ = MakeFunctionCallExpressionNode($1, $4, $2); } - | SystemCallExpression LEFT_PAREN Arguments RIGHT_PAREN + | SystemCallExpression GenericArgumentClause LEFT_PAREN Arguments RIGHT_PAREN { - $$ = MakeSystemCallExpressionNode($1, $3); + $$ = MakeSystemCallExpressionNode($1, $4, $2); } PartialStatement : FunctionCallExpression @@ -322,11 +322,31 @@ Body : LEFT_BRACE Statements RIGHT_BRACE $$ = $2; } -GenericArgument : Identifier +GenericDeclaration : Identifier { - $$ = MakeGenericArgumentNode($1, NULL); + $$ = MakeGenericDeclarationNode($1, NULL); } +GenericDeclarations : GenericDeclaration + { + $$ = StartGenericDeclarationsNode($1); + } + | GenericDeclarations COMMA GenericDeclaration + { + $$ = AddGenericDeclaration($1, $3); + } + +GenericDeclarationClause : LESS_THAN GenericDeclarations GREATER_THAN + { + $$ = $2; + } + | + { + $$ = MakeEmptyGenericDeclarationsNode(); + } + +GenericArgument : Type; + GenericArguments : GenericArgument { $$ = StartGenericArgumentsNode($1); @@ -336,7 +356,8 @@ GenericArguments : GenericArgument $$ = AddGenericArgument($1, $3); } -GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN + +GenericArgumentClause : LESS_THAN GenericArguments GREATER_THAN { $$ = $2; } @@ -345,12 +366,11 @@ GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN $$ = MakeEmptyGenericArgumentsNode(); } - -FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type +FunctionSignature : Identifier GenericDeclarationClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type { $$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2); } - | STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + | STATIC Identifier GenericDeclarationClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type { Node *modifier = MakeStaticNode(); $$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3); @@ -361,9 +381,9 @@ FunctionDeclaration : FunctionSignature Body $$ = MakeFunctionDeclarationNode($1, $2); } -StructDeclaration : STRUCT Identifier LEFT_BRACE Declarations RIGHT_BRACE +StructDeclaration : STRUCT Identifier GenericDeclarationClause LEFT_BRACE Declarations RIGHT_BRACE { - $$ = MakeStructDeclarationNode($2, $4); + $$ = MakeStructDeclarationNode($2, $5, $3); } Declaration : FunctionDeclaration diff --git a/src/ast.c b/src/ast.c index 12d5677..17f40ce 100644 --- a/src/ast.c +++ b/src/ast.c @@ -43,6 +43,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "GenericArgument"; case GenericArguments: return "GenericArguments"; + case GenericDeclaration: + return "GenericDeclaration"; + case GenericDeclarations: + return "GenericDeclarations"; case GenericTypeNode: return "GenericTypeNode"; case Identifier: @@ -288,7 +292,7 @@ Node *MakeFunctionSignatureNode( node->functionSignature.type = typeNode; node->functionSignature.arguments = arguments; node->functionSignature.modifiers = modifiersNode; - node->functionSignature.genericArguments = genericArgumentsNode; + node->functionSignature.genericDeclarations = genericArgumentsNode; return node; } @@ -305,12 +309,14 @@ Node *MakeFunctionDeclarationNode( Node *MakeStructDeclarationNode( Node *identifierNode, - Node *declarationSequenceNode) + Node *declarationSequenceNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = StructDeclaration; node->structDeclaration.identifier = identifierNode; node->structDeclaration.declarationSequence = declarationSequenceNode; + node->structDeclaration.genericDeclarations = genericArgumentsNode; return node; } @@ -369,12 +375,55 @@ Node *MakeEmptyFunctionArgumentSequenceNode() return node; } -Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode) +Node *MakeGenericDeclarationNode(Node *identifierNode, Node *constraintNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericDeclaration; + node->genericDeclaration.identifier = identifierNode; + node->genericDeclaration.constraint = constraintNode; + return node; +} + +Node *StartGenericDeclarationsNode(Node *genericArgumentNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericDeclarations; + node->genericDeclarations.declarations = (Node **)malloc(sizeof(Node *)); + node->genericDeclarations.declarations[0] = genericArgumentNode; + node->genericDeclarations.count = 1; + return node; +} + +Node *AddGenericDeclaration( + Node *genericDeclarationsNode, + Node *genericDeclarationNode) +{ + genericDeclarationsNode->genericDeclarations.declarations = + (Node **)realloc( + genericDeclarationsNode->genericDeclarations.declarations, + sizeof(Node *) * + (genericDeclarationsNode->genericDeclarations.count + 1)); + genericDeclarationsNode->genericDeclarations + .declarations[genericDeclarationsNode->genericDeclarations.count] = + genericDeclarationNode; + genericDeclarationsNode->genericDeclarations.count += 1; + return genericDeclarationsNode; +} + +Node *MakeEmptyGenericDeclarationsNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericDeclarations; + node->genericDeclarations.declarations = NULL; + node->genericDeclarations.count = 0; + return node; +} + +Node *MakeGenericArgumentNode(Node *typeNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = GenericArgument; - node->genericArgument.identifier = identifierNode; - node->genericArgument.constraint = constraintNode; + node->genericArgument.type = typeNode; return node; } @@ -390,13 +439,13 @@ Node *StartGenericArgumentsNode(Node *genericArgumentNode) Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode) { - genericArgumentsNode->genericArguments.arguments = (Node **)realloc( + genericArgumentsNode->genericArguments.arguments = realloc( genericArgumentsNode->genericArguments.arguments, sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1)); genericArgumentsNode->genericArguments .arguments[genericArgumentsNode->genericArguments.count] = genericArgumentNode; - genericArgumentsNode->genericArguments.count += 1; + genericArgumentNode->genericArguments.count += 1; return genericArgumentsNode; } @@ -419,23 +468,27 @@ Node *MakeGenericTypeNode(char *name) Node *MakeFunctionCallExpressionNode( Node *identifierNode, - Node *argumentSequenceNode) + Node *argumentSequenceNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = FunctionCallExpression; node->functionCallExpression.identifier = identifierNode; node->functionCallExpression.argumentSequence = argumentSequenceNode; + node->functionCallExpression.genericArguments = genericArgumentsNode; return node; } Node *MakeSystemCallExpressionNode( Node *identifierNode, - Node *argumentSequenceNode) + Node *argumentSequenceNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = SystemCall; node->systemCall.identifier = identifierNode; node->systemCall.argumentSequence = argumentSequenceNode; + node->systemCall.genericArguments = genericArgumentsNode; return node; } @@ -609,6 +662,7 @@ void PrintNode(Node *node, uint32_t tabCount) printf("\n"); PrintNode(node->functionCallExpression.identifier, tabCount + 1); PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1); + PrintNode(node->functionCallExpression.genericArguments, tabCount + 1); return; case FunctionDeclaration: @@ -628,7 +682,7 @@ void PrintNode(Node *node, uint32_t tabCount) case FunctionSignature: printf("\n"); PrintNode(node->functionSignature.identifier, tabCount + 1); - PrintNode(node->functionSignature.genericArguments, tabCount + 1); + PrintNode(node->functionSignature.genericDeclarations, tabCount + 1); PrintNode(node->functionSignature.arguments, tabCount + 1); PrintNode(node->functionSignature.type, tabCount + 1); PrintNode(node->functionSignature.modifiers, tabCount + 1); @@ -646,9 +700,7 @@ void PrintNode(Node *node, uint32_t tabCount) case GenericArgument: printf("\n"); - PrintNode(node->genericArgument.identifier, tabCount + 1); - /* Constraint nodes are not implemented. */ - /* PrintNode(node->genericArgument.constraint, tabCount + 1); */ + PrintNode(node->genericArgument.type, tabCount + 1); return; case GenericArguments: @@ -659,6 +711,21 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case GenericDeclaration: + printf("\n"); + PrintNode(node->genericDeclaration.identifier, tabCount + 1); + /* Constraint nodes are not implemented. */ + /* PrintNode(node->genericDeclaration.constraint, tabCount + 1); */ + return; + + case GenericDeclarations: + printf("\n"); + for (i = 0; i < node->genericDeclarations.count; i += 1) + { + PrintNode(node->genericDeclarations.declarations[i], tabCount + 1); + } + return; + case GenericTypeNode: printf("%s\n", node->genericType.name); return; @@ -734,6 +801,7 @@ void PrintNode(Node *node, uint32_t tabCount) printf("\n"); PrintNode(node->systemCall.identifier, tabCount + 1); PrintNode(node->systemCall.argumentSequence, tabCount + 1); + PrintNode(node->systemCall.genericArguments, tabCount + 1); return; case Type: @@ -826,7 +894,7 @@ void Recurse(Node *node, void (*func)(Node *)) func(node->functionSignature.type); func(node->functionSignature.arguments); func(node->functionSignature.modifiers); - func(node->functionSignature.genericArguments); + func(node->functionSignature.genericDeclarations); return; case FunctionSignatureArguments: @@ -836,15 +904,15 @@ void Recurse(Node *node, void (*func)(Node *)) } return; - case GenericArgument: - func(node->genericArgument.identifier); - func(node->genericArgument.constraint); + case GenericDeclaration: + func(node->genericDeclaration.identifier); + func(node->genericDeclaration.constraint); return; - case GenericArguments: - for (i = 0; i < node->genericArguments.count; i += 1) + case GenericDeclarations: + for (i = 0; i < node->genericDeclarations.count; i += 1) { - func(node->genericArguments.arguments[i]); + func(node->genericDeclarations.declarations[i]); } return; @@ -966,10 +1034,10 @@ TypeTag *MakeTypeTag(Node *node) tag = MakeTypeTag(node->allocExpression.type); break; - case GenericArgument: + case GenericDeclaration: tag->type = Generic; tag->value.genericType = - strdup(node->genericArgument.identifier->identifier.name); + strdup(node->genericDeclaration.identifier->identifier.name); break; case GenericTypeNode: @@ -1115,7 +1183,7 @@ void LinkParentPointers(Node *node, Node *prev) LinkParentPointers(node->functionSignature.type, node); LinkParentPointers(node->functionSignature.arguments, node); LinkParentPointers(node->functionSignature.modifiers, node); - LinkParentPointers(node->functionSignature.genericArguments, node); + LinkParentPointers(node->functionSignature.genericDeclarations, node); return; case FunctionSignatureArguments: @@ -1127,15 +1195,15 @@ void LinkParentPointers(Node *node, Node *prev) } return; - case GenericArgument: - LinkParentPointers(node->genericArgument.identifier, node); - LinkParentPointers(node->genericArgument.constraint, node); + case GenericDeclaration: + LinkParentPointers(node->genericDeclaration.identifier, node); + LinkParentPointers(node->genericDeclaration.constraint, node); return; - case GenericArguments: - for (i = 0; i < node->genericArguments.count; i += 1) + case GenericDeclarations: + for (i = 0; i < node->genericDeclarations.count; i += 1) { - LinkParentPointers(node->genericArguments.arguments[i], node); + LinkParentPointers(node->genericDeclarations.declarations[i], node); } return; diff --git a/src/ast.h b/src/ast.h index 65392a8..c7fd974 100644 --- a/src/ast.h +++ b/src/ast.h @@ -31,6 +31,8 @@ typedef enum FunctionSignatureArguments, GenericArgument, GenericArguments, + GenericDeclaration, + GenericDeclarations, GenericTypeNode, Identifier, IfStatement, @@ -179,6 +181,7 @@ struct Node { Node *identifier; /* FIXME: need better name */ Node *argumentSequence; + Node *genericArguments; } functionCallExpression; struct @@ -199,7 +202,7 @@ struct Node Node *type; Node *arguments; Node *modifiers; - Node *genericArguments; + Node *genericDeclarations; } functionSignature; struct @@ -210,8 +213,7 @@ struct Node struct { - Node *identifier; - Node *constraint; + Node *type; } genericArgument; struct @@ -220,6 +222,18 @@ struct Node uint32_t count; } genericArguments; + struct + { + Node *identifier; + Node *constraint; + } genericDeclaration; + + struct + { + Node **declarations; + uint32_t count; + } genericDeclarations; + struct { char *name; @@ -287,12 +301,14 @@ struct Node { Node *identifier; Node *declarationSequence; + Node *genericDeclarations; } structDeclaration; struct { Node *identifier; Node *argumentSequence; + Node *genericArguments; } systemCall; struct @@ -341,14 +357,21 @@ Node *MakeFunctionSignatureNode( Node *MakeFunctionDeclarationNode( Node *functionSignatureNode, Node *functionBodyNode); -Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode); +Node *MakeGenericDeclarationNode(Node *identifierNode, Node *constraintNode); +Node *MakeEmptyGenericDeclarationsNode(); +Node *StartGenericDeclarationsNode(Node *genericDeclarationNode); +Node *AddGenericDeclaration( + Node *genericDeclarationsNode, + Node *genericDeclarationNode); +Node *MakeGenericArgumentNode(Node *typeNode); Node *MakeEmptyGenericArgumentsNode(); Node *StartGenericArgumentsNode(Node *genericArgumentNode); Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode); Node *MakeGenericTypeNode(char *name); Node *MakeStructDeclarationNode( Node *identifierNode, - Node *declarationSequenceNode); + Node *declarationSequenceNode, + Node *genericArgumentsNode); Node *StartDeclarationSequenceNode(Node *declarationNode); Node *AddDeclarationNode(Node *declarationSequenceNode, Node *declarationNode); Node *StartFunctionArgumentSequenceNode(Node *argumentNode); @@ -356,10 +379,12 @@ Node *AddFunctionArgumentNode(Node *argumentSequenceNode, Node *argumentNode); Node *MakeEmptyFunctionArgumentSequenceNode(); Node *MakeFunctionCallExpressionNode( Node *identifierNode, - Node *argumentSequenceNode); + Node *argumentSequenceNode, + Node *genericArgumentsNode); Node *MakeSystemCallExpressionNode( Node *identifierNode, - Node *argumentSequenceNode); + Node *argumentSequenceNode, + Node *genericArgumentsNode); Node *MakeAccessExpressionNode(Node *accessee, Node *accessor); Node *MakeAllocNode(Node *typeNode); Node *MakeIfNode(Node *expressionNode, Node *statementSequenceNode); diff --git a/src/codegen.c b/src/codegen.c index a1421b2..cc2aafa 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -234,6 +234,18 @@ static LocalGenericType *LookupGenericType(char *name) return NULL; } +static TypeTag *ConcretizeGenericType(char *name) +{ + LocalGenericType *type = LookupGenericType(name); + + if (type == NULL) + { + return NULL; + } + + return type->concreteTypeTag; +} + static LLVMTypeRef LookupCustomType(char *name) { int32_t i; @@ -630,9 +642,9 @@ static StructTypeFunction CompileGenericFunction( scope, resolvedGenericArgumentTypes[i], functionDeclaration->functionDeclaration.functionSignature - ->functionSignature.genericArguments->genericArguments - .arguments[i] - ->genericArgument.identifier->identifier.name); + ->functionSignature.genericDeclarations->genericDeclarations + .declarations[i] + ->genericDeclaration.identifier->identifier.name); } if (functionSignature->functionSignature.modifiers->functionModifiers @@ -769,8 +781,8 @@ static LLVMValueRef LookupGenericFunction( uint8_t match = 0; uint32_t genericArgumentTypeCount = genericFunction->functionDeclarationNode->functionDeclaration - .functionSignature->functionSignature.genericArguments - ->genericArguments.count; + .functionSignature->functionSignature.genericDeclarations + ->genericDeclarations.count; TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; for (i = 0; i < genericArgumentTypeCount; i += 1) @@ -793,9 +805,9 @@ static LLVMValueRef LookupGenericFunction( ->declaration.identifier->typeTag->value.genericType, genericFunction->functionDeclarationNode ->functionDeclaration.functionSignature - ->functionSignature.genericArguments->genericArguments - .arguments[i] - ->genericArgument.identifier->identifier.name) == 0) + ->functionSignature.genericDeclarations + ->genericDeclarations.declarations[i] + ->genericDeclaration.identifier->identifier.name) == 0) { resolvedGenericArgumentTypes[i] = argumentTypes[j]; break; @@ -808,10 +820,8 @@ static LLVMValueRef LookupGenericFunction( { if (resolvedGenericArgumentTypes[i]->type == Generic) { - resolvedGenericArgumentTypes[i] = - LookupGenericType( - resolvedGenericArgumentTypes[i]->value.genericType) - ->concreteTypeTag; + resolvedGenericArgumentTypes[i] = ConcretizeGenericType( + resolvedGenericArgumentTypes[i]->value.genericType); } } @@ -1657,8 +1667,8 @@ static void CompileFunction( functionName, functionSignature->functionSignature.identifier->identifier.name); - if (functionSignature->functionSignature.genericArguments->genericArguments - .count == 0) + if (functionSignature->functionSignature.genericDeclarations + ->genericDeclarations.count == 0) { PushScopeFrame(scope); @@ -1925,6 +1935,8 @@ static void RegisterLibraryFunctions( LLVMSetLinkage(freeFunction, LLVMExternalLinkage); AddSystemFunction("free", freeFunctionType, freeFunction); + + LLVMDisposeBuilder(builder); } int Codegen(Node *node, uint32_t optimizationLevel) diff --git a/src/validation.c b/src/validation.c index e62c4dd..eed89f0 100644 --- a/src/validation.c +++ b/src/validation.c @@ -261,7 +261,7 @@ void ValidateIdentifiers(Node *node) /* Skip over generic arguments. They contain Identifiers but are not * actually identifiers, they declare types. */ - if (node->syntaxKind == GenericArguments) + if (node->syntaxKind == GenericDeclarations) return; if (node->syntaxKind != Identifier) @@ -309,8 +309,8 @@ void TagIdentifierTypes(Node *node) node->structDeclaration.identifier->typeTag = MakeTypeTag(node); break; - case GenericArgument: - node->genericArgument.identifier->typeTag = MakeTypeTag(node); + case GenericDeclaration: + node->genericDeclaration.identifier->typeTag = MakeTypeTag(node); break; case Identifier: @@ -351,17 +351,19 @@ Node *LookupType(Node *current, char *target) case FunctionDeclaration: { Node *typeArgs = current->functionDeclaration.functionSignature - ->functionSignature.genericArguments; + ->functionSignature.genericDeclarations; uint32_t i; - for (i = 0; i < typeArgs->genericArguments.count; i += 1) + for (i = 0; i < typeArgs->genericDeclarations.count; i += 1) { - Node *arg = typeArgs->genericArguments.arguments[i]; - Node *argId = arg->genericArgument.identifier; + Node *arg = typeArgs->genericDeclarations.declarations[i]; + Node *argId = arg->genericDeclaration.identifier; char *argName = argId->identifier.name; - /* note: return the GenericArgument, not the Identifier, so that + /* note: return the GenericDeclaration, not the Identifier, so that * the caller can differentiate between generics and customs. */ if (strcmp(target, argName) == 0) + { return arg; + } } return LookupType(current->parent, target); @@ -369,9 +371,26 @@ Node *LookupType(Node *current, char *target) case StructDeclaration: { + uint32_t i; + Node *typeArgs = current->structDeclaration.genericDeclarations; + for (i = 0; i < typeArgs->genericDeclarations.count; i += 1) + { + Node *arg = typeArgs->genericDeclarations.declarations[i]; + Node *argId = arg->genericDeclaration.identifier; + char *argName = argId->identifier.name; + /* note: return the GenericDeclaration, not the Identifier, so that + * the caller can differentiate between generics and customs. */ + if (strcmp(target, argName) == 0) + { + return arg; + } + } + Node *structId = GetIdFromStruct(current); if (strcmp(target, structId->identifier.name) == 0) + { return structId; + } return LookupType(current->parent, target); } @@ -417,7 +436,8 @@ void ConvertCustomsToGenerics(Node *node) { char *target = id->typeTag->value.customType; Node *typeLookup = LookupType(node, target); - if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + if (typeLookup != NULL && + typeLookup->syntaxKind == GenericDeclaration) { id->typeTag->type = Generic; free(node->declaration.type); @@ -436,7 +456,8 @@ void ConvertCustomsToGenerics(Node *node) { char *target = id->typeTag->value.customType; Node *typeLookup = LookupType(node, target); - if (typeLookup != NULL && typeLookup->syntaxKind == GenericArgument) + if (typeLookup != NULL && + typeLookup->syntaxKind == GenericDeclaration) { id->typeTag->type = Generic; free(node->functionSignature.type);