From a571edcf6d90a6b97bb4d331bab4d9c9d529f969 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 2 Jun 2021 17:26:26 -0700 Subject: [PATCH] preparing for struct generics --- generators/wraith.y | 5 +- generic.w | 11 ++ src/ast.c | 36 +++++ src/codegen.c | 324 +++++++++++++++++++++++++++----------------- src/validation.c | 20 +++ 5 files changed, 273 insertions(+), 123 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index 27c710c..803d922 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -345,7 +345,10 @@ GenericDeclarationClause : LESS_THAN GenericDeclarations GREATER_THAN $$ = MakeEmptyGenericDeclarationsNode(); } -GenericArgument : Type; +GenericArgument : Type + { + $$ = MakeGenericArgumentNode($1); + } GenericArguments : GenericArgument { diff --git a/generic.w b/generic.w index 04f52eb..b7f6013 100644 --- a/generic.w +++ b/generic.w @@ -9,6 +9,17 @@ struct Foo { } } +struct MemoryBlock +{ + start: MemoryAddress; + capacity: uint; + + AddressOf(count: uint): MemoryAddress + { + return start + (count * @sizeof()); + } +} + struct Program { static Main(): int { x: int = 4; diff --git a/src/ast.c b/src/ast.c index 17f40ce..dc014bb 100644 --- a/src/ast.c +++ b/src/ast.c @@ -875,6 +875,7 @@ void Recurse(Node *node, void (*func)(Node *)) case FunctionCallExpression: func(node->functionCallExpression.identifier); func(node->functionCallExpression.argumentSequence); + func(node->functionCallExpression.genericArguments); return; case FunctionDeclaration: @@ -904,6 +905,17 @@ void Recurse(Node *node, void (*func)(Node *)) } return; + case GenericArgument: + func(node->genericArgument.type); + break; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) + { + func(node->genericArguments.arguments[i]); + } + return; + case GenericDeclaration: func(node->genericDeclaration.identifier); func(node->genericDeclaration.constraint); @@ -967,7 +979,14 @@ void Recurse(Node *node, void (*func)(Node *)) func(node->structDeclaration.declarationSequence); return; + case SystemCall: + func(node->systemCall.identifier); + func(node->systemCall.argumentSequence); + func(node->systemCall.genericArguments); + return; + case Type: + func(node->type.typeNode); return; case UnaryExpression: @@ -1195,6 +1214,17 @@ void LinkParentPointers(Node *node, Node *prev) } return; + case GenericArgument: + LinkParentPointers(node->genericArgument.type, node); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) + { + LinkParentPointers(node->genericArguments.arguments[i], node); + } + return; + case GenericDeclaration: LinkParentPointers(node->genericDeclaration.identifier, node); LinkParentPointers(node->genericDeclaration.constraint, node); @@ -1258,6 +1288,12 @@ void LinkParentPointers(Node *node, Node *prev) LinkParentPointers(node->structDeclaration.declarationSequence, node); return; + case SystemCall: + LinkParentPointers(node->systemCall.identifier, node); + LinkParentPointers(node->systemCall.argumentSequence, node); + LinkParentPointers(node->systemCall.genericArguments, node); + return; + case Type: return; diff --git a/src/codegen.c b/src/codegen.c index cc2aafa..2e67b17 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -112,6 +112,30 @@ typedef struct StructTypeDeclaration StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; +typedef struct MonomorphizedGenericStructHashEntry +{ + uint64_t key; + TypeTag **types; + uint32_t typeCount; + StructTypeDeclaration structDeclaration; +} MonomorphizedGenericStructHashEntry; + +typedef struct MonomorphizedGenericStructHashArray +{ + MonomorphizedGenericStructHashEntry *elements; + uint32_t count; +} MonomorphizedGenericStructHashArray; + +typedef struct GenericStructTypeDeclaration +{ + Node *structDeclarationNode; + MonomorphizedGenericStructHashArray + monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS]; +} GenericStructTypeDeclaration; + +GenericStructTypeDeclaration *genericStructTypeDeclarations; +uint32_t genericStructTypeDeclarationCount; + typedef struct SystemFunction { char *name; @@ -234,16 +258,14 @@ static LocalGenericType *LookupGenericType(char *name) return NULL; } -static TypeTag *ConcretizeGenericType(char *name) +static TypeTag *ConcretizeType(TypeTag *type) { - LocalGenericType *type = LookupGenericType(name); - - if (type == NULL) + if (type->type == Generic) { - return NULL; + return LookupGenericType(type->value.genericType)->concreteTypeTag; } - return type->concreteTypeTag; + return type; } static LLVMTypeRef LookupCustomType(char *name) @@ -258,7 +280,7 @@ static LLVMTypeRef LookupCustomType(char *name) } } - fprintf(stderr, "Could not find struct type!\n"); + fprintf(stderr, "Could not find custom type!\n"); return NULL; } @@ -301,9 +323,10 @@ static void AddSystemFunction( systemFunctionCount += 1; } -static SystemFunction *LookupSystemFunction(char *name) +static SystemFunction *LookupSystemFunction(Node *systemCallExpression) { uint32_t i; + char *name = systemCallExpression->systemCall.identifier->identifier.name; for (i = 0; i < systemFunctionCount; i += 1) { @@ -518,6 +541,30 @@ static void AddStructDeclaration( structTypeDeclarationCount += 1; } +static void AddGenericStructDeclaration(Node *structDeclarationNode) +{ + uint32_t i; + + genericStructTypeDeclarations = realloc( + genericStructTypeDeclarations, + sizeof(GenericStructTypeDeclaration) * + (genericStructTypeDeclarationCount + 1)); + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .structDeclarationNode = structDeclarationNode; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) + { + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .elements = NULL; + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .count = 0; + } + + genericStructTypeDeclarationCount += 1; +} + /* FIXME: pass the declaration itself */ static void DeclareStructFunction( LLVMTypeRef wStructPointerType, @@ -771,8 +818,7 @@ static StructTypeFunction CompileGenericFunction( static LLVMValueRef LookupGenericFunction( LLVMModuleRef module, StructTypeGenericFunction *genericFunction, - TypeTag **argumentTypes, - uint32_t argumentCount, + Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -784,33 +830,70 @@ static LLVMValueRef LookupGenericFunction( .functionSignature->functionSignature.genericDeclarations ->genericDeclarations.count; TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; + uint32_t argumentCount = 0; + TypeTag + *argumentTypes[functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count]; - for (i = 0; i < genericArgumentTypeCount; i += 1) + for (i = 0; i < functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count; + i += 1) { - for (j = 0; - j < genericFunction->functionDeclarationNode->functionDeclaration + argumentTypes[i] = + functionCallExpression->functionCallExpression.argumentSequence + ->functionArgumentSequence.sequence[i] + ->typeTag; + + argumentCount += 1; + } + + if (functionCallExpression->functionCallExpression.genericArguments + ->genericArguments.count > 0) + { + for (i = 0; i < functionCallExpression->functionCallExpression + .genericArguments->genericArguments.count; + i += 1) + { + resolvedGenericArgumentTypes[i] = + functionCallExpression->functionCallExpression.genericArguments + ->genericArguments.arguments[i] + ->genericArgument.type->typeTag; + } + } + else /* we have to infer the generics */ + { + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + for (j = 0; + j < + genericFunction->functionDeclarationNode->functionDeclaration .functionSignature->functionSignature.arguments ->functionSignatureArguments.count; - j += 1) - { - if (genericFunction->functionDeclarationNode->functionDeclaration - .functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[j] - ->declaration.identifier->typeTag->type == Generic && - strcmp( - genericFunction->functionDeclarationNode - ->functionDeclaration.functionSignature - ->functionSignature.arguments - ->functionSignatureArguments.sequence[j] - ->declaration.identifier->typeTag->value.genericType, - genericFunction->functionDeclarationNode - ->functionDeclaration.functionSignature - ->functionSignature.genericDeclarations - ->genericDeclarations.declarations[i] - ->genericDeclaration.identifier->identifier.name) == 0) + j += 1) { - resolvedGenericArgumentTypes[i] = argumentTypes[j]; - break; + if (genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->type == + Generic && + strcmp( + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->value + .genericType, + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.genericDeclarations + ->genericDeclarations.declarations[i] + ->genericDeclaration.identifier->identifier.name) == + 0) + { + resolvedGenericArgumentTypes[i] = argumentTypes[j]; + break; + } } } } @@ -818,11 +901,8 @@ static LLVMValueRef LookupGenericFunction( /* Concretize generics if we are compiling nested generic functions */ for (i = 0; i < genericArgumentTypeCount; i += 1) { - if (resolvedGenericArgumentTypes[i]->type == Generic) - { - resolvedGenericArgumentTypes[i] = ConcretizeGenericType( - resolvedGenericArgumentTypes[i]->value.genericType); - } + resolvedGenericArgumentTypes[i] = + ConcretizeType(resolvedGenericArgumentTypes[i]); } typeHash = @@ -894,13 +974,14 @@ static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupFunctionByType( LLVMModuleRef module, LLVMTypeRef structType, - char *name, - TypeTag **argumentTypes, - uint32_t argumentCount, + Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; + /* FIXME: hot garbage */ + char *name = functionCallExpression->functionCallExpression.identifier + ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -928,8 +1009,7 @@ static LLVMValueRef LookupFunctionByType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - argumentTypes, - argumentCount, + functionCallExpression, pReturnType, pStatic); } @@ -944,13 +1024,14 @@ static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByPointerType( LLVMModuleRef module, LLVMTypeRef structPointerType, - char *name, - TypeTag **argumentTypes, - uint32_t argumentCount, + Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; + /* FIXME: hot garbage */ + char *name = functionCallExpression->functionCallExpression.identifier + ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -978,8 +1059,7 @@ static LLVMValueRef LookupFunctionByPointerType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - argumentTypes, - argumentCount, + functionCallExpression, pReturnType, pStatic); } @@ -994,18 +1074,14 @@ static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByInstance( LLVMModuleRef module, LLVMValueRef structPointer, - char *name, - TypeTag **argumentTypes, - uint32_t argumentCount, + Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( module, LLVMTypeOf(structPointer), - name, - argumentTypes, - argumentCount, + functionCallExpression, pReturnType, pStatic); } @@ -1093,27 +1169,12 @@ static LLVMValueRef CompileFunctionCallExpression( [functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.count + 1]; - TypeTag - *argumentTypes[functionCallExpression->functionCallExpression - .argumentSequence->functionArgumentSequence.count]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; LLVMTypeRef functionReturnType; char *returnName = ""; - for (i = 0; i < functionCallExpression->functionCallExpression - .argumentSequence->functionArgumentSequence.count; - i += 1) - { - argumentTypes[i] = - functionCallExpression->functionCallExpression.argumentSequence - ->functionArgumentSequence.sequence[i] - ->typeTag; - - argumentCount += 1; - } - /* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be able to call same-struct functions implicitly */ @@ -1129,10 +1190,7 @@ static LLVMValueRef CompileFunctionCallExpression( function = LookupFunctionByType( module, typeReference, - functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name, - argumentTypes, - argumentCount, + functionCallExpression, &functionReturnType, &isStatic); } @@ -1144,10 +1202,7 @@ static LLVMValueRef CompileFunctionCallExpression( function = LookupFunctionByInstance( module, structInstance, - functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name, - argumentTypes, - argumentCount, + functionCallExpression, &functionReturnType, &isStatic); } @@ -1158,8 +1213,6 @@ static LLVMValueRef CompileFunctionCallExpression( return NULL; } - argumentCount = 0; - if (!isStatic) { args[argumentCount] = structInstance; @@ -1208,13 +1261,27 @@ static LLVMValueRef CompileSystemCallExpression( ->functionArgumentSequence.sequence[i]); } - SystemFunction *systemFunction = LookupSystemFunction( - systemCallExpression->systemCall.identifier->identifier.name); + SystemFunction *systemFunction = LookupSystemFunction(systemCallExpression); if (systemFunction == NULL) { - fprintf(stderr, "System function not found!"); - return NULL; + /* special case for sizeof */ + if (strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "sizeof") == 0) + { + TypeTag *typeTag = + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag; + + return LLVMSizeOf(ResolveType(ConcretizeType(typeTag))); + } + else + { + fprintf(stderr, "System function not found!"); + return NULL; + } } if (LLVMGetTypeKind(LLVMGetReturnType(systemFunction->functionType)) != @@ -1792,52 +1859,62 @@ static void CompileStruct( PushScopeFrame(scope); - LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); - LLVMTypeRef wStructPointerType = LLVMPointerType( - wStructType, - 0); /* FIXME: is this address space correct? */ - - /* first, build the structure definition */ - for (i = 0; i < declarationCount; i += 1) + if (node->structDeclaration.genericDeclarations->genericDeclarations + .count == 0) { - currentDeclarationNode = node->structDeclaration.declarationSequence - ->declarationSequence.sequence[i]; + LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); + LLVMTypeRef wStructPointerType = LLVMPointerType( + wStructType, + 0); /* FIXME: is this address space correct? */ - switch (currentDeclarationNode->syntaxKind) + /* first, build the structure definition */ + for (i = 0; i < declarationCount; i += 1) { - case Declaration: /* this is badly named */ - types[fieldCount] = ResolveType( - currentDeclarationNode->declaration.identifier->typeTag); - fieldDeclarations[fieldCount] = currentDeclarationNode; - fieldCount += 1; - break; + currentDeclarationNode = + node->structDeclaration.declarationSequence->declarationSequence + .sequence[i]; + + switch (currentDeclarationNode->syntaxKind) + { + case Declaration: /* this is badly named */ + types[fieldCount] = ResolveType( + currentDeclarationNode->declaration.identifier->typeTag); + fieldDeclarations[fieldCount] = currentDeclarationNode; + fieldCount += 1; + break; + } + } + + LLVMStructSetBody(wStructType, types, fieldCount, packed); + AddStructDeclaration( + wStructType, + wStructPointerType, + structName, + fieldDeclarations, + fieldCount); + + /* now we can wire up the functions */ + for (i = 0; i < declarationCount; i += 1) + { + currentDeclarationNode = + node->structDeclaration.declarationSequence->declarationSequence + .sequence[i]; + + switch (currentDeclarationNode->syntaxKind) + { + case FunctionDeclaration: + CompileFunction( + module, + structName, + wStructPointerType, + currentDeclarationNode); + break; + } } } - - LLVMStructSetBody(wStructType, types, fieldCount, packed); - AddStructDeclaration( - wStructType, - wStructPointerType, - structName, - fieldDeclarations, - fieldCount); - - /* now we can wire up the functions */ - for (i = 0; i < declarationCount; i += 1) + else { - currentDeclarationNode = node->structDeclaration.declarationSequence - ->declarationSequence.sequence[i]; - - switch (currentDeclarationNode->syntaxKind) - { - case FunctionDeclaration: - CompileFunction( - module, - structName, - wStructPointerType, - currentDeclarationNode); - break; - } + AddGenericStructDeclaration(node); } PopScopeFrame(scope); @@ -1946,6 +2023,9 @@ int Codegen(Node *node, uint32_t optimizationLevel) structTypeDeclarations = NULL; structTypeDeclarationCount = 0; + genericStructTypeDeclarations = NULL; + genericStructTypeDeclarationCount = 0; + systemFunctions = NULL; systemFunctionCount = 0; diff --git a/src/validation.c b/src/validation.c index eed89f0..d80e052 100644 --- a/src/validation.c +++ b/src/validation.c @@ -313,6 +313,10 @@ void TagIdentifierTypes(Node *node) node->genericDeclaration.identifier->typeTag = MakeTypeTag(node); break; + case Type: + node->typeTag = MakeTypeTag(node); + break; + case Identifier: { if (node->typeTag != NULL) @@ -467,6 +471,22 @@ void ConvertCustomsToGenerics(Node *node) } break; } + + case GenericArgument: + { + Node *typeNode = node->genericArgument.type; + if (typeNode->typeTag->type == Custom) + { + char *target = typeNode->typeTag->value.customType; + Node *typeLookup = LookupType(node, target); + if (typeLookup != NULL && + typeLookup->syntaxKind == GenericDeclaration) + { + typeNode->typeTag->type = Generic; + } + } + break; + } } Recurse(node, *ConvertCustomsToGenerics);