monomorphization

pull/7/head
cosmonaut 2021-06-01 12:56:56 -07:00
parent 3553269fb0
commit 9f52a19a58
3 changed files with 120 additions and 53 deletions

View File

@ -5,14 +5,14 @@ struct Foo {
static Func<T>(t: T): T {
foo: T = t;
return Func2(foo);
return Foo.Func2(foo);
}
}
struct Program {
static main(): int {
static Main(): int {
x: int = 4;
y: int = Foo.Func(x);
return x;
}
}
}

View File

@ -988,22 +988,22 @@ char *TypeTagToString(TypeTag *tag)
{
char *inner = TypeTagToString(tag->value.referenceType);
size_t innerStrLen = strlen(inner);
char *result = malloc(sizeof(char) * (innerStrLen + 5));
char *result = malloc(sizeof(char) * (innerStrLen + 6));
sprintf(result, "Ref<%s>", inner);
return result;
}
case Custom:
{
char *result =
malloc(sizeof(char) * (strlen(tag->value.customType) + 8));
malloc(sizeof(char) * (strlen(tag->value.customType) + 9));
sprintf(result, "Custom<%s>", tag->value.customType);
return result;
}
case Generic:
{
char *result =
malloc(sizeof(char) * (strlen(tag->value.customType) + 9));
sprintf(result, "Generic<%s>", tag->value.customType);
malloc(sizeof(char) * (strlen(tag->value.genericType) + 10));
sprintf(result, "Generic<%s>", tag->value.genericType);
return result;
}
}

View File

@ -26,6 +26,7 @@ typedef struct LocalVariable
typedef struct LocalGenericType
{
char *name;
TypeTag *concreteTypeTag;
LLVMTypeRef type;
} LocalGenericType;
@ -171,7 +172,7 @@ static void PopScopeFrame(Scope *scope)
{
free(scope->scopeStack[index].genericTypes[i].name);
}
free(scope->scopeStack[index].localVariables);
free(scope->scopeStack[index].genericTypes);
}
scope->scopeStackCount -= 1;
@ -201,7 +202,7 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
return NULL;
}
static LLVMTypeRef LookupCustomType(char *name)
static LocalGenericType *LookupGenericType(char *name)
{
int32_t i, j;
@ -211,11 +212,19 @@ static LLVMTypeRef LookupCustomType(char *name)
{
if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0)
{
return scope->scopeStack[i].genericTypes[j].type;
return &scope->scopeStack[i].genericTypes[j];
}
}
}
fprintf(stderr, "Could not find resolved generic type!\n");
return NULL;
}
static LLVMTypeRef LookupCustomType(char *name)
{
int32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
@ -242,6 +251,10 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag)
{
return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0);
}
else if (typeTag->type == Generic)
{
return LookupGenericType(typeTag->value.genericType)->type;
}
else
{
fprintf(stderr, "Unknown type node!\n");
@ -277,6 +290,7 @@ static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name)
scopeFrame->genericTypes,
sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1));
scopeFrame->genericTypes[index].name = strdup(name);
scopeFrame->genericTypes[index].concreteTypeTag = typeTag;
scopeFrame->genericTypes[index].type = ResolveType(typeTag);
scopeFrame->genericTypeCount += 1;
@ -544,11 +558,12 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count)
return result;
}
/* FIXME: lots of duplication with non-generic function compile */
static StructTypeFunction CompileGenericFunction(
LLVMModuleRef module,
char *parentStructName,
LLVMTypeRef wStructPointerType,
TypeTag **genericArgumentTypes,
TypeTag **resolvedGenericArgumentTypes,
uint32_t genericArgumentTypeCount,
Node *functionDeclaration)
{
@ -562,6 +577,7 @@ static StructTypeFunction CompileGenericFunction(
->functionSignatureArguments.count;
LLVMTypeRef paramTypes[argumentCount + 1];
uint32_t paramIndex = 0;
LLVMTypeRef returnType;
PushScopeFrame(scope);
@ -569,7 +585,7 @@ static StructTypeFunction CompileGenericFunction(
{
AddGenericVariable(
scope,
genericArgumentTypes[i],
resolvedGenericArgumentTypes[i],
functionDeclaration->functionDeclaration.functionSignature
->functionSignature.genericArguments->genericArguments
.arguments[i]
@ -601,7 +617,7 @@ static StructTypeFunction CompileGenericFunction(
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
strcat(functionName, TypeTagToString(genericArgumentTypes[i]));
strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i]));
}
if (!isStatic)
@ -618,11 +634,13 @@ static StructTypeFunction CompileGenericFunction(
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->typeTag);
paramIndex += 1;
}
LLVMTypeRef returnType =
returnType =
ResolveType(functionSignature->functionSignature.identifier->typeTag);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
@ -698,15 +716,64 @@ static StructTypeFunction CompileGenericFunction(
static LLVMValueRef LookupGenericFunction(
LLVMModuleRef module,
StructTypeGenericFunction *genericFunction,
TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
uint32_t i, j;
uint64_t typeHash =
HashTypeTags(genericArgumentTypes, genericArgumentTypeCount);
uint64_t typeHash;
uint8_t match = 0;
uint32_t genericArgumentTypeCount =
genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.genericArguments
->genericArguments.count;
TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount];
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
for (j = 0;
j < genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
j += 1)
{
if (genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->type == Generic &&
strcmp(
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->value.genericType,
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.genericArguments->genericArguments
.arguments[i]
->genericArgument.identifier->identifier.name) == 0)
{
resolvedGenericArgumentTypes[i] = argumentTypes[j];
break;
}
}
}
/* Concretize generics if we are compiling nested generic functions */
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
if (resolvedGenericArgumentTypes[i]->type == Generic)
{
resolvedGenericArgumentTypes[i] =
LookupGenericType(
resolvedGenericArgumentTypes[i]->value.genericType)
->concreteTypeTag;
}
}
typeHash =
HashTypeTags(resolvedGenericArgumentTypes, genericArgumentTypeCount);
MonomorphizedGenericFunctionHashArray *hashArray =
&genericFunction->monomorphizedFunctions
@ -719,7 +786,8 @@ static LLVMValueRef LookupGenericFunction(
for (j = 0; j < hashArray->elements[i].typeCount; j += 1)
{
if (hashArray->elements[i].types[j] != genericArgumentTypes[j])
if (hashArray->elements[i].types[j] !=
resolvedGenericArgumentTypes[j])
{
match = 0;
break;
@ -739,7 +807,7 @@ static LLVMValueRef LookupGenericFunction(
module,
genericFunction->parentStructName,
genericFunction->parentStructPointerType,
genericArgumentTypes,
resolvedGenericArgumentTypes,
genericArgumentTypeCount,
genericFunction->functionDeclarationNode);
@ -757,9 +825,11 @@ static LLVMValueRef LookupGenericFunction(
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
hashArray->elements[hashArray->count].types[i] =
genericArgumentTypes[i];
resolvedGenericArgumentTypes[i];
}
hashArray->count += 1;
hashEntry = &hashArray->elements[hashArray->count - 1];
}
*pReturnType = hashEntry->function.returnType;
@ -772,8 +842,8 @@ static LLVMValueRef LookupFunctionByType(
LLVMModuleRef module,
LLVMTypeRef structType,
char *name,
TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
@ -805,8 +875,8 @@ static LLVMValueRef LookupFunctionByType(
return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes,
genericArgumentTypeCount,
argumentTypes,
argumentCount,
pReturnType,
pStatic);
}
@ -822,8 +892,8 @@ static LLVMValueRef LookupFunctionByPointerType(
LLVMModuleRef module,
LLVMTypeRef structPointerType,
char *name,
TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
@ -855,8 +925,8 @@ static LLVMValueRef LookupFunctionByPointerType(
return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes,
genericArgumentTypeCount,
argumentTypes,
argumentCount,
pReturnType,
pStatic);
}
@ -872,8 +942,8 @@ static LLVMValueRef LookupFunctionByInstance(
LLVMModuleRef module,
LLVMValueRef structPointer,
char *name,
TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
@ -881,8 +951,8 @@ static LLVMValueRef LookupFunctionByInstance(
module,
LLVMTypeOf(structPointer),
name,
genericArgumentTypes,
genericArgumentTypeCount,
argumentTypes,
argumentCount,
pReturnType,
pStatic);
}
@ -966,40 +1036,34 @@ static LLVMValueRef CompileFunctionCallExpression(
{
uint32_t i;
uint32_t argumentCount = 0;
uint32_t genericArgumentCount = 0;
LLVMValueRef args
[functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.count +
1];
TypeTag *genericArgumentTypes[functionCallExpression->functionCallExpression
.argumentSequence
->functionArgumentSequence.count];
TypeTag
*argumentTypes[functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count];
LLVMValueRef function;
uint8_t isStatic;
LLVMValueRef structInstance;
LLVMTypeRef functionReturnType;
char *returnName = "";
/* FIXME: this is completely wrong and not how we get generic args */
for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count;
i += 1)
{
if (functionCallExpression->functionCallExpression.argumentSequence
argumentTypes[i] =
functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->syntaxKind == GenericArgument)
{
genericArgumentTypes[genericArgumentCount] =
functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->declaration.identifier->typeTag;
->typeTag;
genericArgumentCount += 1;
}
argumentCount += 1;
}
/* FIXME: this needs to be recursive on access chains */
/* FIXME: this needs to be able to call same-struct functions implicitly */
/* FIXME: this needs to be able to call same-struct functions implicitly
*/
if (functionCallExpression->functionCallExpression.identifier->syntaxKind ==
AccessExpression)
{
@ -1014,8 +1078,8 @@ static LLVMValueRef CompileFunctionCallExpression(
typeReference,
functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name,
genericArgumentTypes,
genericArgumentCount,
argumentTypes,
argumentCount,
&functionReturnType,
&isStatic);
}
@ -1029,8 +1093,8 @@ static LLVMValueRef CompileFunctionCallExpression(
structInstance,
functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name,
genericArgumentTypes,
genericArgumentCount,
argumentTypes,
argumentCount,
&functionReturnType,
&isStatic);
}
@ -1041,6 +1105,8 @@ static LLVMValueRef CompileFunctionCallExpression(
return NULL;
}
argumentCount = 0;
if (!isStatic)
{
args[argumentCount] = structInstance;
@ -1693,7 +1759,8 @@ static void Compile(
{
fprintf(
stderr,
"top level declarations that are not structs are forbidden!\n");
"top level declarations that are not structs are "
"forbidden!\n");
}
}
}