diff --git a/generic.w b/generic.w index 63247ee..73fb981 100644 --- a/generic.w +++ b/generic.w @@ -5,14 +5,14 @@ struct Foo { static Func(t: T): T { foo: T = t; - return Func2(foo); + return Foo.Func2(foo); } } struct Program { - static main(): int { + static Main(): int { x: int = 4; y: int = Foo.Func(x); return x; } -} \ No newline at end of file +} diff --git a/src/ast.c b/src/ast.c index d57ef1f..fddf74d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -988,22 +988,22 @@ char *TypeTagToString(TypeTag *tag) { char *inner = TypeTagToString(tag->value.referenceType); size_t innerStrLen = strlen(inner); - char *result = malloc(sizeof(char) * (innerStrLen + 5)); + char *result = malloc(sizeof(char) * (innerStrLen + 6)); sprintf(result, "Ref<%s>", inner); return result; } case Custom: { char *result = - malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); + malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); sprintf(result, "Custom<%s>", tag->value.customType); return result; } case Generic: { char *result = - malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); - sprintf(result, "Generic<%s>", tag->value.customType); + malloc(sizeof(char) * (strlen(tag->value.genericType) + 10)); + sprintf(result, "Generic<%s>", tag->value.genericType); return result; } } diff --git a/src/codegen.c b/src/codegen.c index 8f82dff..dd3bf73 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -26,6 +26,7 @@ typedef struct LocalVariable typedef struct LocalGenericType { char *name; + TypeTag *concreteTypeTag; LLVMTypeRef type; } LocalGenericType; @@ -171,7 +172,7 @@ static void PopScopeFrame(Scope *scope) { free(scope->scopeStack[index].genericTypes[i].name); } - free(scope->scopeStack[index].localVariables); + free(scope->scopeStack[index].genericTypes); } scope->scopeStackCount -= 1; @@ -201,7 +202,7 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) return NULL; } -static LLVMTypeRef LookupCustomType(char *name) +static LocalGenericType *LookupGenericType(char *name) { int32_t i, j; @@ -211,11 +212,19 @@ static LLVMTypeRef LookupCustomType(char *name) { if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) { - return scope->scopeStack[i].genericTypes[j].type; + return &scope->scopeStack[i].genericTypes[j]; } } } + fprintf(stderr, "Could not find resolved generic type!\n"); + return NULL; +} + +static LLVMTypeRef LookupCustomType(char *name) +{ + int32_t i; + for (i = 0; i < structTypeDeclarationCount; i += 1) { if (strcmp(structTypeDeclarations[i].name, name) == 0) @@ -242,6 +251,10 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) { return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); } + else if (typeTag->type == Generic) + { + return LookupGenericType(typeTag->value.genericType)->type; + } else { fprintf(stderr, "Unknown type node!\n"); @@ -277,6 +290,7 @@ static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) scopeFrame->genericTypes, sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; scopeFrame->genericTypes[index].type = ResolveType(typeTag); scopeFrame->genericTypeCount += 1; @@ -544,11 +558,12 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) return result; } +/* FIXME: lots of duplication with non-generic function compile */ static StructTypeFunction CompileGenericFunction( LLVMModuleRef module, char *parentStructName, LLVMTypeRef wStructPointerType, - TypeTag **genericArgumentTypes, + TypeTag **resolvedGenericArgumentTypes, uint32_t genericArgumentTypeCount, Node *functionDeclaration) { @@ -562,6 +577,7 @@ static StructTypeFunction CompileGenericFunction( ->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; + LLVMTypeRef returnType; PushScopeFrame(scope); @@ -569,7 +585,7 @@ static StructTypeFunction CompileGenericFunction( { AddGenericVariable( scope, - genericArgumentTypes[i], + resolvedGenericArgumentTypes[i], functionDeclaration->functionDeclaration.functionSignature ->functionSignature.genericArguments->genericArguments .arguments[i] @@ -601,7 +617,7 @@ static StructTypeFunction CompileGenericFunction( for (i = 0; i < genericArgumentTypeCount; i += 1) { - strcat(functionName, TypeTagToString(genericArgumentTypes[i])); + strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); } if (!isStatic) @@ -618,11 +634,13 @@ static StructTypeFunction CompileGenericFunction( ResolveType(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->typeTag); + paramIndex += 1; } - LLVMTypeRef returnType = + returnType = ResolveType(functionSignature->functionSignature.identifier->typeTag); + LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); @@ -698,15 +716,64 @@ static StructTypeFunction CompileGenericFunction( static LLVMValueRef LookupGenericFunction( LLVMModuleRef module, StructTypeGenericFunction *genericFunction, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - uint64_t typeHash = - HashTypeTags(genericArgumentTypes, genericArgumentTypeCount); + uint64_t typeHash; uint8_t match = 0; + uint32_t genericArgumentTypeCount = + genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.genericArguments + ->genericArguments.count; + TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + for (j = 0; + j < genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + j += 1) + { + if (genericFunction->functionDeclarationNode->functionDeclaration + .functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->type == Generic && + strcmp( + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.arguments + ->functionSignatureArguments.sequence[j] + ->declaration.identifier->typeTag->value.genericType, + genericFunction->functionDeclarationNode + ->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name) == 0) + { + resolvedGenericArgumentTypes[i] = argumentTypes[j]; + break; + } + } + } + + /* Concretize generics if we are compiling nested generic functions */ + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + if (resolvedGenericArgumentTypes[i]->type == Generic) + { + resolvedGenericArgumentTypes[i] = + LookupGenericType( + resolvedGenericArgumentTypes[i]->value.genericType) + ->concreteTypeTag; + } + } + + typeHash = + HashTypeTags(resolvedGenericArgumentTypes, genericArgumentTypeCount); MonomorphizedGenericFunctionHashArray *hashArray = &genericFunction->monomorphizedFunctions @@ -719,7 +786,8 @@ static LLVMValueRef LookupGenericFunction( for (j = 0; j < hashArray->elements[i].typeCount; j += 1) { - if (hashArray->elements[i].types[j] != genericArgumentTypes[j]) + if (hashArray->elements[i].types[j] != + resolvedGenericArgumentTypes[j]) { match = 0; break; @@ -739,7 +807,7 @@ static LLVMValueRef LookupGenericFunction( module, genericFunction->parentStructName, genericFunction->parentStructPointerType, - genericArgumentTypes, + resolvedGenericArgumentTypes, genericArgumentTypeCount, genericFunction->functionDeclarationNode); @@ -757,9 +825,11 @@ static LLVMValueRef LookupGenericFunction( for (i = 0; i < genericArgumentTypeCount; i += 1) { hashArray->elements[hashArray->count].types[i] = - genericArgumentTypes[i]; + resolvedGenericArgumentTypes[i]; } hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; } *pReturnType = hashEntry->function.returnType; @@ -772,8 +842,8 @@ static LLVMValueRef LookupFunctionByType( LLVMModuleRef module, LLVMTypeRef structType, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -805,8 +875,8 @@ static LLVMValueRef LookupFunctionByType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -822,8 +892,8 @@ static LLVMValueRef LookupFunctionByPointerType( LLVMModuleRef module, LLVMTypeRef structPointerType, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -855,8 +925,8 @@ static LLVMValueRef LookupFunctionByPointerType( return LookupGenericFunction( module, &structTypeDeclarations[i].genericFunctions[j], - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -872,8 +942,8 @@ static LLVMValueRef LookupFunctionByInstance( LLVMModuleRef module, LLVMValueRef structPointer, char *name, - TypeTag **genericArgumentTypes, - uint32_t genericArgumentTypeCount, + TypeTag **argumentTypes, + uint32_t argumentCount, LLVMTypeRef *pReturnType, uint8_t *pStatic) { @@ -881,8 +951,8 @@ static LLVMValueRef LookupFunctionByInstance( module, LLVMTypeOf(structPointer), name, - genericArgumentTypes, - genericArgumentTypeCount, + argumentTypes, + argumentCount, pReturnType, pStatic); } @@ -966,40 +1036,34 @@ static LLVMValueRef CompileFunctionCallExpression( { uint32_t i; uint32_t argumentCount = 0; - uint32_t genericArgumentCount = 0; LLVMValueRef args [functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.count + 1]; - TypeTag *genericArgumentTypes[functionCallExpression->functionCallExpression - .argumentSequence - ->functionArgumentSequence.count]; + TypeTag + *argumentTypes[functionCallExpression->functionCallExpression + .argumentSequence->functionArgumentSequence.count]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; LLVMTypeRef functionReturnType; char *returnName = ""; - /* FIXME: this is completely wrong and not how we get generic args */ for (i = 0; i < functionCallExpression->functionCallExpression .argumentSequence->functionArgumentSequence.count; i += 1) { - if (functionCallExpression->functionCallExpression.argumentSequence + argumentTypes[i] = + functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i] - ->syntaxKind == GenericArgument) - { - genericArgumentTypes[genericArgumentCount] = - functionCallExpression->functionCallExpression.argumentSequence - ->functionArgumentSequence.sequence[i] - ->declaration.identifier->typeTag; + ->typeTag; - genericArgumentCount += 1; - } + argumentCount += 1; } /* FIXME: this needs to be recursive on access chains */ - /* FIXME: this needs to be able to call same-struct functions implicitly */ + /* FIXME: this needs to be able to call same-struct functions implicitly + */ if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { @@ -1014,8 +1078,8 @@ static LLVMValueRef CompileFunctionCallExpression( typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, - genericArgumentTypes, - genericArgumentCount, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -1029,8 +1093,8 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, - genericArgumentTypes, - genericArgumentCount, + argumentTypes, + argumentCount, &functionReturnType, &isStatic); } @@ -1041,6 +1105,8 @@ static LLVMValueRef CompileFunctionCallExpression( return NULL; } + argumentCount = 0; + if (!isStatic) { args[argumentCount] = structInstance; @@ -1693,7 +1759,8 @@ static void Compile( { fprintf( stderr, - "top level declarations that are not structs are forbidden!\n"); + "top level declarations that are not structs are " + "forbidden!\n"); } } }