monomorphization

pull/7/head
cosmonaut 2021-06-01 12:56:56 -07:00
parent 3553269fb0
commit 9f52a19a58
3 changed files with 120 additions and 53 deletions

View File

@ -5,12 +5,12 @@ struct Foo {
static Func<T>(t: T): T { static Func<T>(t: T): T {
foo: T = t; foo: T = t;
return Func2(foo); return Foo.Func2(foo);
} }
} }
struct Program { struct Program {
static main(): int { static Main(): int {
x: int = 4; x: int = 4;
y: int = Foo.Func(x); y: int = Foo.Func(x);
return x; return x;

View File

@ -988,22 +988,22 @@ char *TypeTagToString(TypeTag *tag)
{ {
char *inner = TypeTagToString(tag->value.referenceType); char *inner = TypeTagToString(tag->value.referenceType);
size_t innerStrLen = strlen(inner); size_t innerStrLen = strlen(inner);
char *result = malloc(sizeof(char) * (innerStrLen + 5)); char *result = malloc(sizeof(char) * (innerStrLen + 6));
sprintf(result, "Ref<%s>", inner); sprintf(result, "Ref<%s>", inner);
return result; return result;
} }
case Custom: case Custom:
{ {
char *result = char *result =
malloc(sizeof(char) * (strlen(tag->value.customType) + 8)); malloc(sizeof(char) * (strlen(tag->value.customType) + 9));
sprintf(result, "Custom<%s>", tag->value.customType); sprintf(result, "Custom<%s>", tag->value.customType);
return result; return result;
} }
case Generic: case Generic:
{ {
char *result = char *result =
malloc(sizeof(char) * (strlen(tag->value.customType) + 9)); malloc(sizeof(char) * (strlen(tag->value.genericType) + 10));
sprintf(result, "Generic<%s>", tag->value.customType); sprintf(result, "Generic<%s>", tag->value.genericType);
return result; return result;
} }
} }

View File

@ -26,6 +26,7 @@ typedef struct LocalVariable
typedef struct LocalGenericType typedef struct LocalGenericType
{ {
char *name; char *name;
TypeTag *concreteTypeTag;
LLVMTypeRef type; LLVMTypeRef type;
} LocalGenericType; } LocalGenericType;
@ -171,7 +172,7 @@ static void PopScopeFrame(Scope *scope)
{ {
free(scope->scopeStack[index].genericTypes[i].name); free(scope->scopeStack[index].genericTypes[i].name);
} }
free(scope->scopeStack[index].localVariables); free(scope->scopeStack[index].genericTypes);
} }
scope->scopeStackCount -= 1; scope->scopeStackCount -= 1;
@ -201,7 +202,7 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
return NULL; return NULL;
} }
static LLVMTypeRef LookupCustomType(char *name) static LocalGenericType *LookupGenericType(char *name)
{ {
int32_t i, j; int32_t i, j;
@ -211,11 +212,19 @@ static LLVMTypeRef LookupCustomType(char *name)
{ {
if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0)
{ {
return scope->scopeStack[i].genericTypes[j].type; return &scope->scopeStack[i].genericTypes[j];
} }
} }
} }
fprintf(stderr, "Could not find resolved generic type!\n");
return NULL;
}
static LLVMTypeRef LookupCustomType(char *name)
{
int32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
if (strcmp(structTypeDeclarations[i].name, name) == 0) if (strcmp(structTypeDeclarations[i].name, name) == 0)
@ -242,6 +251,10 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag)
{ {
return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0);
} }
else if (typeTag->type == Generic)
{
return LookupGenericType(typeTag->value.genericType)->type;
}
else else
{ {
fprintf(stderr, "Unknown type node!\n"); fprintf(stderr, "Unknown type node!\n");
@ -277,6 +290,7 @@ static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name)
scopeFrame->genericTypes, scopeFrame->genericTypes,
sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1));
scopeFrame->genericTypes[index].name = strdup(name); scopeFrame->genericTypes[index].name = strdup(name);
scopeFrame->genericTypes[index].concreteTypeTag = typeTag;
scopeFrame->genericTypes[index].type = ResolveType(typeTag); scopeFrame->genericTypes[index].type = ResolveType(typeTag);
scopeFrame->genericTypeCount += 1; scopeFrame->genericTypeCount += 1;
@ -544,11 +558,12 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count)
return result; return result;
} }
/* FIXME: lots of duplication with non-generic function compile */
static StructTypeFunction CompileGenericFunction( static StructTypeFunction CompileGenericFunction(
LLVMModuleRef module, LLVMModuleRef module,
char *parentStructName, char *parentStructName,
LLVMTypeRef wStructPointerType, LLVMTypeRef wStructPointerType,
TypeTag **genericArgumentTypes, TypeTag **resolvedGenericArgumentTypes,
uint32_t genericArgumentTypeCount, uint32_t genericArgumentTypeCount,
Node *functionDeclaration) Node *functionDeclaration)
{ {
@ -562,6 +577,7 @@ static StructTypeFunction CompileGenericFunction(
->functionSignatureArguments.count; ->functionSignatureArguments.count;
LLVMTypeRef paramTypes[argumentCount + 1]; LLVMTypeRef paramTypes[argumentCount + 1];
uint32_t paramIndex = 0; uint32_t paramIndex = 0;
LLVMTypeRef returnType;
PushScopeFrame(scope); PushScopeFrame(scope);
@ -569,7 +585,7 @@ static StructTypeFunction CompileGenericFunction(
{ {
AddGenericVariable( AddGenericVariable(
scope, scope,
genericArgumentTypes[i], resolvedGenericArgumentTypes[i],
functionDeclaration->functionDeclaration.functionSignature functionDeclaration->functionDeclaration.functionSignature
->functionSignature.genericArguments->genericArguments ->functionSignature.genericArguments->genericArguments
.arguments[i] .arguments[i]
@ -601,7 +617,7 @@ static StructTypeFunction CompileGenericFunction(
for (i = 0; i < genericArgumentTypeCount; i += 1) for (i = 0; i < genericArgumentTypeCount; i += 1)
{ {
strcat(functionName, TypeTagToString(genericArgumentTypes[i])); strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i]));
} }
if (!isStatic) if (!isStatic)
@ -618,11 +634,13 @@ static StructTypeFunction CompileGenericFunction(
ResolveType(functionSignature->functionSignature.arguments ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i] ->functionSignatureArguments.sequence[i]
->declaration.identifier->typeTag); ->declaration.identifier->typeTag);
paramIndex += 1; paramIndex += 1;
} }
LLVMTypeRef returnType = returnType =
ResolveType(functionSignature->functionSignature.identifier->typeTag); ResolveType(functionSignature->functionSignature.identifier->typeTag);
LLVMTypeRef functionType = LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0); LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
@ -698,15 +716,64 @@ static StructTypeFunction CompileGenericFunction(
static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupGenericFunction(
LLVMModuleRef module, LLVMModuleRef module,
StructTypeGenericFunction *genericFunction, StructTypeGenericFunction *genericFunction,
TypeTag **genericArgumentTypes, TypeTag **argumentTypes,
uint32_t genericArgumentTypeCount, uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
uint32_t i, j; uint32_t i, j;
uint64_t typeHash = uint64_t typeHash;
HashTypeTags(genericArgumentTypes, genericArgumentTypeCount);
uint8_t match = 0; uint8_t match = 0;
uint32_t genericArgumentTypeCount =
genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.genericArguments
->genericArguments.count;
TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount];
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
for (j = 0;
j < genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
j += 1)
{
if (genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->type == Generic &&
strcmp(
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->value.genericType,
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.genericArguments->genericArguments
.arguments[i]
->genericArgument.identifier->identifier.name) == 0)
{
resolvedGenericArgumentTypes[i] = argumentTypes[j];
break;
}
}
}
/* Concretize generics if we are compiling nested generic functions */
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
if (resolvedGenericArgumentTypes[i]->type == Generic)
{
resolvedGenericArgumentTypes[i] =
LookupGenericType(
resolvedGenericArgumentTypes[i]->value.genericType)
->concreteTypeTag;
}
}
typeHash =
HashTypeTags(resolvedGenericArgumentTypes, genericArgumentTypeCount);
MonomorphizedGenericFunctionHashArray *hashArray = MonomorphizedGenericFunctionHashArray *hashArray =
&genericFunction->monomorphizedFunctions &genericFunction->monomorphizedFunctions
@ -719,7 +786,8 @@ static LLVMValueRef LookupGenericFunction(
for (j = 0; j < hashArray->elements[i].typeCount; j += 1) for (j = 0; j < hashArray->elements[i].typeCount; j += 1)
{ {
if (hashArray->elements[i].types[j] != genericArgumentTypes[j]) if (hashArray->elements[i].types[j] !=
resolvedGenericArgumentTypes[j])
{ {
match = 0; match = 0;
break; break;
@ -739,7 +807,7 @@ static LLVMValueRef LookupGenericFunction(
module, module,
genericFunction->parentStructName, genericFunction->parentStructName,
genericFunction->parentStructPointerType, genericFunction->parentStructPointerType,
genericArgumentTypes, resolvedGenericArgumentTypes,
genericArgumentTypeCount, genericArgumentTypeCount,
genericFunction->functionDeclarationNode); genericFunction->functionDeclarationNode);
@ -757,9 +825,11 @@ static LLVMValueRef LookupGenericFunction(
for (i = 0; i < genericArgumentTypeCount; i += 1) for (i = 0; i < genericArgumentTypeCount; i += 1)
{ {
hashArray->elements[hashArray->count].types[i] = hashArray->elements[hashArray->count].types[i] =
genericArgumentTypes[i]; resolvedGenericArgumentTypes[i];
} }
hashArray->count += 1; hashArray->count += 1;
hashEntry = &hashArray->elements[hashArray->count - 1];
} }
*pReturnType = hashEntry->function.returnType; *pReturnType = hashEntry->function.returnType;
@ -772,8 +842,8 @@ static LLVMValueRef LookupFunctionByType(
LLVMModuleRef module, LLVMModuleRef module,
LLVMTypeRef structType, LLVMTypeRef structType,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **argumentTypes,
uint32_t genericArgumentTypeCount, uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
@ -805,8 +875,8 @@ static LLVMValueRef LookupFunctionByType(
return LookupGenericFunction( return LookupGenericFunction(
module, module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes, argumentTypes,
genericArgumentTypeCount, argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -822,8 +892,8 @@ static LLVMValueRef LookupFunctionByPointerType(
LLVMModuleRef module, LLVMModuleRef module,
LLVMTypeRef structPointerType, LLVMTypeRef structPointerType,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **argumentTypes,
uint32_t genericArgumentTypeCount, uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
@ -855,8 +925,8 @@ static LLVMValueRef LookupFunctionByPointerType(
return LookupGenericFunction( return LookupGenericFunction(
module, module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes, argumentTypes,
genericArgumentTypeCount, argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -872,8 +942,8 @@ static LLVMValueRef LookupFunctionByInstance(
LLVMModuleRef module, LLVMModuleRef module,
LLVMValueRef structPointer, LLVMValueRef structPointer,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **argumentTypes,
uint32_t genericArgumentTypeCount, uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
@ -881,8 +951,8 @@ static LLVMValueRef LookupFunctionByInstance(
module, module,
LLVMTypeOf(structPointer), LLVMTypeOf(structPointer),
name, name,
genericArgumentTypes, argumentTypes,
genericArgumentTypeCount, argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -966,40 +1036,34 @@ static LLVMValueRef CompileFunctionCallExpression(
{ {
uint32_t i; uint32_t i;
uint32_t argumentCount = 0; uint32_t argumentCount = 0;
uint32_t genericArgumentCount = 0;
LLVMValueRef args LLVMValueRef args
[functionCallExpression->functionCallExpression.argumentSequence [functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.count + ->functionArgumentSequence.count +
1]; 1];
TypeTag *genericArgumentTypes[functionCallExpression->functionCallExpression TypeTag
.argumentSequence *argumentTypes[functionCallExpression->functionCallExpression
->functionArgumentSequence.count]; .argumentSequence->functionArgumentSequence.count];
LLVMValueRef function; LLVMValueRef function;
uint8_t isStatic; uint8_t isStatic;
LLVMValueRef structInstance; LLVMValueRef structInstance;
LLVMTypeRef functionReturnType; LLVMTypeRef functionReturnType;
char *returnName = ""; char *returnName = "";
/* FIXME: this is completely wrong and not how we get generic args */
for (i = 0; i < functionCallExpression->functionCallExpression for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count; .argumentSequence->functionArgumentSequence.count;
i += 1) i += 1)
{ {
if (functionCallExpression->functionCallExpression.argumentSequence argumentTypes[i] =
->functionArgumentSequence.sequence[i]
->syntaxKind == GenericArgument)
{
genericArgumentTypes[genericArgumentCount] =
functionCallExpression->functionCallExpression.argumentSequence functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i] ->functionArgumentSequence.sequence[i]
->declaration.identifier->typeTag; ->typeTag;
genericArgumentCount += 1; argumentCount += 1;
}
} }
/* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be recursive on access chains */
/* FIXME: this needs to be able to call same-struct functions implicitly */ /* FIXME: this needs to be able to call same-struct functions implicitly
*/
if (functionCallExpression->functionCallExpression.identifier->syntaxKind == if (functionCallExpression->functionCallExpression.identifier->syntaxKind ==
AccessExpression) AccessExpression)
{ {
@ -1014,8 +1078,8 @@ static LLVMValueRef CompileFunctionCallExpression(
typeReference, typeReference,
functionCallExpression->functionCallExpression.identifier functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name, ->accessExpression.accessor->identifier.name,
genericArgumentTypes, argumentTypes,
genericArgumentCount, argumentCount,
&functionReturnType, &functionReturnType,
&isStatic); &isStatic);
} }
@ -1029,8 +1093,8 @@ static LLVMValueRef CompileFunctionCallExpression(
structInstance, structInstance,
functionCallExpression->functionCallExpression.identifier functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name, ->accessExpression.accessor->identifier.name,
genericArgumentTypes, argumentTypes,
genericArgumentCount, argumentCount,
&functionReturnType, &functionReturnType,
&isStatic); &isStatic);
} }
@ -1041,6 +1105,8 @@ static LLVMValueRef CompileFunctionCallExpression(
return NULL; return NULL;
} }
argumentCount = 0;
if (!isStatic) if (!isStatic)
{ {
args[argumentCount] = structInstance; args[argumentCount] = structInstance;
@ -1693,7 +1759,8 @@ static void Compile(
{ {
fprintf( fprintf(
stderr, stderr,
"top level declarations that are not structs are forbidden!\n"); "top level declarations that are not structs are "
"forbidden!\n");
} }
} }
} }