From a870a2c32e42cbc4acef1c7a0743d91484f68e03 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 3 Jun 2021 14:40:14 -0700 Subject: [PATCH] monomorphizing generic structs --- src/codegen.c | 757 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 484 insertions(+), 273 deletions(-) diff --git a/src/codegen.c b/src/codegen.c index 2e67b17..8919f79 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -83,10 +83,11 @@ typedef struct MonomorphizedGenericFunctionHashArray #define NUM_MONOMORPHIZED_HASH_BUCKETS 1031 +typedef struct StructTypeDeclaration StructTypeDeclaration; + typedef struct StructTypeGenericFunction { - char *parentStructName; - LLVMTypeRef parentStructPointerType; + StructTypeDeclaration *parentStruct; char *name; Node *functionDeclarationNode; uint8_t isStatic; @@ -94,8 +95,9 @@ typedef struct StructTypeGenericFunction monomorphizedFunctions[NUM_MONOMORPHIZED_HASH_BUCKETS]; } StructTypeGenericFunction; -typedef struct StructTypeDeclaration +struct StructTypeDeclaration { + LLVMModuleRef module; char *name; LLVMTypeRef structType; LLVMTypeRef structPointerType; @@ -107,7 +109,7 @@ typedef struct StructTypeDeclaration StructTypeGenericFunction *genericFunctions; uint32_t genericFunctionCount; -} StructTypeDeclaration; +}; StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; @@ -128,6 +130,7 @@ typedef struct MonomorphizedGenericStructHashArray typedef struct GenericStructTypeDeclaration { + LLVMModuleRef module; Node *structDeclarationNode; MonomorphizedGenericStructHashArray monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS]; @@ -148,16 +151,22 @@ uint32_t systemFunctionCount; /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMValueRef CompileExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *expression); +static void CompileFunction( + StructTypeDeclaration *structTypeDeclaration, + Node *functionDeclaration); + +static LLVMTypeRef ResolveType(TypeTag *typeTag); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -215,6 +224,122 @@ static void PopScopeFrame(Scope *scope) realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } +static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) +{ + const uint64_t HASH_FACTOR = 97; + uint64_t result = 1; + uint32_t i; + + for (i = 0; i < count; i += 1) + { + result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + } + + return result; +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + StructTypeDeclaration *structTypeDeclaration, + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclaration->fieldCount; i += 1) + { + char *ptrName = strdup(structTypeDeclaration->fields[i].name); + strcat(ptrName, "_ptr"); /* FIXME: needs to be realloc'd */ + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclaration->fields[i].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclaration->fields[i].name); + } +} + +static void AddFieldToStructDeclaration( + StructTypeDeclaration *structTypeDeclaration, + char *name) +{ + structTypeDeclaration->fields = realloc( + structTypeDeclaration->fields, + sizeof(StructTypeField) * (structTypeDeclaration->fieldCount + 1)); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].name = + strdup(name); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].index = + structTypeDeclaration->fieldCount; + structTypeDeclaration->fieldCount += 1; +} + +static void AddGenericStructDeclaration( + LLVMModuleRef module, + Node *structDeclarationNode) +{ + uint32_t i; + + genericStructTypeDeclarations = realloc( + genericStructTypeDeclarations, + sizeof(GenericStructTypeDeclaration) * + (genericStructTypeDeclarationCount + 1)); + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .structDeclarationNode = structDeclarationNode; + genericStructTypeDeclarations[genericStructTypeDeclarationCount].module = + module; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) + { + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .elements = NULL; + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .count = 0; + } + + genericStructTypeDeclarationCount += 1; +} + static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -284,6 +409,195 @@ static LLVMTypeRef LookupCustomType(char *name) return NULL; } +static StructTypeDeclaration CompileMonomorphizedGenericStruct( + GenericStructTypeDeclaration *genericStructTypeDeclaration, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount) +{ + uint32_t i = 0; + uint32_t nameLen; + uint32_t fieldCount = 0; + Node *structDeclarationNode = + genericStructTypeDeclaration->structDeclarationNode; + uint32_t declarationCount = + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.count; + LLVMTypeRef types[declarationCount]; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + genericArgumentTypes[i], + structDeclarationNode->structDeclaration.genericDeclarations + ->genericDeclarations.declarations[i] + ->genericDeclaration.identifier->identifier.name); + } + + char *structName = strdup( + structDeclarationNode->structDeclaration.identifier->identifier.name); + nameLen = strlen(structName); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + char *inner = TypeTagToString(genericArgumentTypes[i]); + nameLen += 2 + strlen(inner); + structName = realloc(structName, sizeof(char) * nameLen); + strcat(structName, "_"); + strcat(structName, inner); + } + + LLVMContextRef context = + LLVMGetGlobalContext(); /* FIXME: should we pass a context? */ + LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); + LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); + + StructTypeDeclaration declaration; + declaration.module = genericStructTypeDeclaration->module; + declaration.name = structName; + declaration.structType = wStructType; + declaration.structPointerType = wStructPointerType; + declaration.genericFunctions = NULL; + declaration.genericFunctionCount = 0; + declaration.functions = NULL; + declaration.functionCount = 0; + declaration.fields = NULL; + declaration.fieldCount = 0; + + /* first build the structure def */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case Declaration: + types[fieldCount] = ResolveType( + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->typeTag); + AddFieldToStructDeclaration( + &declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->identifier.name); + fieldCount += 1; + break; + } + } + + LLVMStructSetBody(wStructType, types, fieldCount, 1); + + /* now we wire up the functions */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case FunctionDeclaration: + CompileFunction( + &declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i]); + break; + } + } + + PopScopeFrame(scope); + + return declaration; +} + +static StructTypeDeclaration *LookupGenericStructType( + ConcreteGenericTypeTag *typeTag) +{ + uint32_t i, j, k; + uint64_t typeHash; + uint8_t match; + TypeTag *genericTypeTags[typeTag->genericArgumentCount]; + + for (i = 0; i < typeTag->genericArgumentCount; i += 1) + { + genericTypeTags[i] = ConcretizeType(typeTag->genericArguments[i]); + } + + for (i = 0; i < genericStructTypeDeclarationCount; i += 1) + { + if (strcmp( + genericStructTypeDeclarations[i] + .structDeclarationNode->structDeclaration.identifier + ->identifier.name, + typeTag->name) == 0) + { + typeHash = + HashTypeTags(genericTypeTags, typeTag->genericArgumentCount); + + MonomorphizedGenericStructHashArray *hashArray = + &genericStructTypeDeclarations[i].monomorphizedStructs + [typeHash % NUM_MONOMORPHIZED_HASH_BUCKETS]; + + MonomorphizedGenericStructHashEntry *hashEntry = NULL; + + for (j = 0; j < hashArray->count; j += 1) + { + match = 1; + + for (k = 0; k < hashArray->elements[j].typeCount; k += 1) + { + if (hashArray->elements[j].types[k] != genericTypeTags[k]) + { + match = 0; + break; + } + } + + if (match) + { + hashEntry = &hashArray->elements[i]; + break; + } + } + + if (hashEntry == NULL) + { + StructTypeDeclaration structTypeDeclaration = + CompileMonomorphizedGenericStruct( + &genericStructTypeDeclarations[i], + genericTypeTags, + typeTag->genericArgumentCount); + + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericStructHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * typeTag->genericArgumentCount); + hashArray->elements[hashArray->count].typeCount = + typeTag->genericArgumentCount; + hashArray->elements[hashArray->count].structDeclaration = + structTypeDeclaration; + for (j = 0; j < typeTag->genericArgumentCount; j += 1) + { + hashArray->elements[hashArray->count].types[j] = + genericTypeTags[j]; + } + hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; + } + + return &hashEntry->structDeclaration; + } + } + + fprintf(stderr, "Could not find generic struct declaration!"); + return NULL; +} + static LLVMTypeRef ResolveType(TypeTag *typeTag) { if (typeTag->type == Primitive) @@ -302,6 +616,11 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) { return LookupGenericType(typeTag->value.genericType)->type; } + else if (typeTag->type == ConcreteGeneric) + { + return LookupGenericStructType(&typeTag->value.concreteGenericType) + ->structType; + } else { fprintf(stderr, "Unknown type node!\n"); @@ -340,73 +659,6 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression) return NULL; } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - -static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->genericTypeCount; - - scopeFrame->genericTypes = realloc( - scopeFrame->genericTypes, - sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); - scopeFrame->genericTypes[index].name = strdup(name); - scopeFrame->genericTypes[index].concreteTypeTag = typeTag; - scopeFrame->genericTypes[index].type = ResolveType(typeTag); - - scopeFrame->genericTypeCount += 1; -} - -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - static LLVMTypeRef FindStructType(char *name) { uint32_t i; @@ -504,18 +756,17 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) return NULL; } -static void AddStructDeclaration( +static StructTypeDeclaration *AddStructDeclaration( + LLVMModuleRef module, LLVMTypeRef wStructType, LLVMTypeRef wStructPointerType, - char *name, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount) + char *name) { - uint32_t i; uint32_t index = structTypeDeclarationCount; structTypeDeclarations = realloc( structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); + structTypeDeclarations[index].module = module; structTypeDeclarations[index].structType = wStructType; structTypeDeclarations[index].structPointerType = wStructPointerType; structTypeDeclarations[index].name = strdup(name); @@ -526,145 +777,67 @@ static void AddStructDeclaration( structTypeDeclarations[index].genericFunctions = NULL; structTypeDeclarations[index].genericFunctionCount = 0; - for (i = 0; i < fieldDeclarationCount; i += 1) - { - structTypeDeclarations[index].fields = realloc( - structTypeDeclarations[index].fields, - sizeof(StructTypeField) * - (structTypeDeclarations[index].fieldCount + 1)); - structTypeDeclarations[index].fields[i].name = strdup( - fieldDeclarations[i]->declaration.identifier->identifier.name); - structTypeDeclarations[index].fields[i].index = i; - structTypeDeclarations[index].fieldCount += 1; - } - structTypeDeclarationCount += 1; + + return &structTypeDeclarations[index]; } -static void AddGenericStructDeclaration(Node *structDeclarationNode) -{ - uint32_t i; - - genericStructTypeDeclarations = realloc( - genericStructTypeDeclarations, - sizeof(GenericStructTypeDeclaration) * - (genericStructTypeDeclarationCount + 1)); - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .structDeclarationNode = structDeclarationNode; - - for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) - { - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .elements = NULL; - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .count = 0; - } - - genericStructTypeDeclarationCount += 1; -} - -/* FIXME: pass the declaration itself */ static void DeclareStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, LLVMValueRef function, LLVMTypeRef returnType, uint8_t isStatic, char *name) { - uint32_t i, index; + uint32_t index = structTypeDeclaration->functionCount; - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = structTypeDeclarations[i].functionCount; - structTypeDeclarations[i].functions = realloc( - structTypeDeclarations[i].functions, - sizeof(StructTypeFunction) * - (structTypeDeclarations[i].functionCount + 1)); - structTypeDeclarations[i].functions[index].name = strdup(name); - structTypeDeclarations[i].functions[index].function = function; - structTypeDeclarations[i].functions[index].returnType = returnType; - structTypeDeclarations[i].functions[index].isStatic = isStatic; - structTypeDeclarations[i].functionCount += 1; - - return; - } - } - - fprintf(stderr, "Could not find struct type for function!\n"); + structTypeDeclaration->functions = realloc( + structTypeDeclaration->functions, + sizeof(StructTypeFunction) * + (structTypeDeclaration->functionCount + 1)); + structTypeDeclaration->functions[index].name = strdup(name); + structTypeDeclaration->functions[index].function = function; + structTypeDeclaration->functions[index].returnType = returnType; + structTypeDeclaration->functions[index].isStatic = isStatic; + structTypeDeclaration->functionCount += 1; } -/* FIXME: pass the declaration itself */ static void DeclareGenericStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclarationNode, uint8_t isStatic, - char *parentStructName, char *name) { - uint32_t i, j, index; + uint32_t i, index; - for (i = 0; i < structTypeDeclarationCount; i += 1) + index = structTypeDeclaration->genericFunctionCount; + structTypeDeclaration->genericFunctions = realloc( + structTypeDeclaration->genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclaration->genericFunctionCount + 1)); + structTypeDeclaration->genericFunctions[index].name = strdup(name); + structTypeDeclaration->genericFunctions[index].parentStruct = + structTypeDeclaration; + structTypeDeclaration->genericFunctions[index].functionDeclarationNode = + functionDeclarationNode; + structTypeDeclaration->genericFunctions[index].isStatic = isStatic; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = structTypeDeclarations[i].genericFunctionCount; - structTypeDeclarations[i].genericFunctions = realloc( - structTypeDeclarations[i].genericFunctions, - sizeof(StructTypeGenericFunction) * - (structTypeDeclarations[i].genericFunctionCount + 1)); - structTypeDeclarations[i].genericFunctions[index].name = - strdup(name); - structTypeDeclarations[i].genericFunctions[index].parentStructName = - parentStructName; - structTypeDeclarations[i].structPointerType = wStructPointerType; - structTypeDeclarations[i] - .genericFunctions[index] - .functionDeclarationNode = functionDeclarationNode; - structTypeDeclarations[i].genericFunctions[index].isStatic = - isStatic; - - for (j = 0; j < NUM_MONOMORPHIZED_HASH_BUCKETS; j += 1) - { - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .elements = NULL; - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .count = 0; - } - - structTypeDeclarations[i].genericFunctionCount += 1; - - return; - } - } -} - -static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) -{ - const uint64_t HASH_FACTOR = 97; - uint64_t result = 1; - uint32_t i; - - for (i = 0; i < count; i += 1) - { - result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .elements = NULL; + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .count = 0; } - return result; + structTypeDeclaration->genericFunctionCount += 1; } /* FIXME: lots of duplication with non-generic function compile */ static StructTypeFunction CompileGenericFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, TypeTag **resolvedGenericArgumentTypes, uint32_t genericArgumentTypeCount, Node *functionDeclaration) @@ -711,7 +884,8 @@ static StructTypeFunction CompileGenericFunction( } } - char *functionName = strdup(parentStructName); + /* FIXME: these cats need to be realloc'd */ + char *functionName = strdup(structTypeDeclaration->name); strcat(functionName, "_"); strcat( functionName, @@ -724,7 +898,7 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { - paramTypes[paramIndex] = wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -746,7 +920,10 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); @@ -755,7 +932,10 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -783,7 +963,7 @@ static StructTypeFunction CompileGenericFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, functionBody->statementSequence.sequence[i]); @@ -816,7 +996,6 @@ static StructTypeFunction CompileGenericFunction( } static LLVMValueRef LookupGenericFunction( - LLVMModuleRef module, StructTypeGenericFunction *genericFunction, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -937,14 +1116,11 @@ static LLVMValueRef LookupGenericFunction( if (hashEntry == NULL) { StructTypeFunction function = CompileGenericFunction( - module, - genericFunction->parentStructName, - genericFunction->parentStructPointerType, + genericFunction->parentStruct, resolvedGenericArgumentTypes, genericArgumentTypeCount, genericFunction->functionDeclarationNode); - /* TODO: add to hash */ hashArray->elements = realloc( hashArray->elements, sizeof(MonomorphizedGenericFunctionHashEntry) * @@ -972,7 +1148,6 @@ static LLVMValueRef LookupGenericFunction( } static LLVMValueRef LookupFunctionByType( - LLVMModuleRef module, LLVMTypeRef structType, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -1007,7 +1182,6 @@ static LLVMValueRef LookupFunctionByType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1022,7 +1196,6 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( - LLVMModuleRef module, LLVMTypeRef structPointerType, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -1057,7 +1230,6 @@ static LLVMValueRef LookupFunctionByPointerType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1072,14 +1244,12 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( - LLVMModuleRef module, LLVMValueRef structPointer, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( - module, LLVMTypeOf(structPointer), functionCallExpression, pReturnType, @@ -1102,17 +1272,17 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *binaryExpression) { LLVMValueRef left = CompileExpression( - module, + structTypeDeclaration, builder, binaryExpression->binaryExpression.left); LLVMValueRef right = CompileExpression( - module, + structTypeDeclaration, builder, binaryExpression->binaryExpression.right); @@ -1159,7 +1329,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -1188,7 +1358,6 @@ static LLVMValueRef CompileFunctionCallExpression( if (typeReference != NULL) { function = LookupFunctionByType( - module, typeReference, functionCallExpression, &functionReturnType, @@ -1200,7 +1369,6 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); function = LookupFunctionByInstance( - module, structInstance, functionCallExpression, &functionReturnType, @@ -1224,7 +1392,7 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( - module, + structTypeDeclaration, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1240,7 +1408,7 @@ static LLVMValueRef CompileFunctionCallExpression( } static LLVMValueRef CompileSystemCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *systemCallExpression) { @@ -1255,7 +1423,7 @@ static LLVMValueRef CompileSystemCallExpression( i += 1) { args[i] = CompileExpression( - module, + structTypeDeclaration, builder, systemCallExpression->systemCall.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1335,7 +1503,7 @@ static LLVMValueRef CompileAllocExpression( } static LLVMValueRef CompileExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *expression) { @@ -1348,10 +1516,16 @@ static LLVMValueRef CompileExpression( return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(module, builder, expression); + return CompileBinaryExpression( + structTypeDeclaration, + builder, + expression); case FunctionCallExpression: - return CompileFunctionCallExpression(module, builder, expression); + return CompileFunctionCallExpression( + structTypeDeclaration, + builder, + expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -1363,7 +1537,10 @@ static LLVMValueRef CompileExpression( return CompileString(builder, expression); case SystemCall: - return CompileSystemCallExpression(module, builder, expression); + return CompileSystemCallExpression( + structTypeDeclaration, + builder, + expression); } fprintf(stderr, "Unknown expression kind!\n"); @@ -1371,13 +1548,13 @@ static LLVMValueRef CompileExpression( } static LLVMBasicBlockRef CompileReturn( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( - module, + structTypeDeclaration, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -1417,16 +1594,18 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( - module, + structTypeDeclaration, builder, assignmentStatement->assignmentStatement.right); + LLVMValueRef identifier; + if (assignmentStatement->assignmentStatement.left->syntaxKind == AccessExpression) { @@ -1461,14 +1640,16 @@ static LLVMBasicBlockRef CompileAssignment( } static LLVMBasicBlockRef CompileIfStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; - LLVMValueRef conditional = - CompileExpression(module, builder, ifStatement->ifStatement.expression); + LLVMValueRef conditional = CompileExpression( + structTypeDeclaration, + builder, + ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -1483,7 +1664,7 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1497,14 +1678,14 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( - module, + structTypeDeclaration, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1521,7 +1702,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1540,7 +1721,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1550,7 +1731,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1563,7 +1744,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1623,7 +1804,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( - module, + structTypeDeclaration, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1652,7 +1833,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1660,33 +1841,56 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(module, builder, function, statement); + return CompileAssignment( + structTypeDeclaration, + builder, + function, + statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(module, builder, function, statement); + return CompileForLoopStatement( + structTypeDeclaration, + builder, + function, + statement); case FunctionCallExpression: - CompileFunctionCallExpression(module, builder, statement); + CompileFunctionCallExpression( + structTypeDeclaration, + builder, + statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(module, builder, function, statement); + return CompileIfStatement( + structTypeDeclaration, + builder, + function, + statement); case IfElseStatement: - return CompileIfElseStatement(module, builder, function, statement); + return CompileIfElseStatement( + structTypeDeclaration, + builder, + function, + statement); case Return: - return CompileReturn(module, builder, function, statement); + return CompileReturn( + structTypeDeclaration, + builder, + function, + statement); case ReturnVoid: return CompileReturnVoid(builder, function); case SystemCall: - CompileSystemCallExpression(module, builder, statement); + CompileSystemCallExpression(structTypeDeclaration, builder, statement); return LLVMGetLastBasicBlock(function); } @@ -1695,9 +1899,7 @@ static LLVMBasicBlockRef CompileStatement( } static void CompileFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclaration) { uint32_t i; @@ -1728,7 +1930,13 @@ static void CompileFunction( } } - char *functionName = strdup(parentStructName); + char *functionName = strdup(structTypeDeclaration->name); + uint32_t nameLen = strlen(functionName); + nameLen += + 2 + + strlen( + functionSignature->functionSignature.identifier->identifier.name); + functionName = realloc(functionName, sizeof(char) * nameLen); strcat(functionName, "_"); strcat( functionName, @@ -1741,7 +1949,7 @@ static void CompileFunction( if (!isStatic) { - paramTypes[paramIndex] = wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -1761,11 +1969,13 @@ static void CompileFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = - LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); DeclareStructFunction( - wStructPointerType, + structTypeDeclaration, function, returnType, isStatic, @@ -1778,7 +1988,10 @@ static void CompileFunction( if (!isStatic) { LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -1807,7 +2020,7 @@ static void CompileFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, functionBody->statementSequence.sequence[i]); @@ -1832,10 +2045,9 @@ static void CompileFunction( else { DeclareGenericStructFunction( - wStructPointerType, + structTypeDeclaration, functionDeclaration, isStatic, - parentStructName, functionSignature->functionSignature.identifier->identifier.name); } @@ -1854,7 +2066,6 @@ static void CompileStruct( uint8_t packed = 1; LLVMTypeRef types[declarationCount]; Node *currentDeclarationNode; - Node *fieldDeclarations[declarationCount]; char *structName = node->structDeclaration.identifier->identifier.name; PushScopeFrame(scope); @@ -1867,6 +2078,12 @@ static void CompileStruct( wStructType, 0); /* FIXME: is this address space correct? */ + StructTypeDeclaration *structTypeDeclaration = AddStructDeclaration( + module, + wStructType, + wStructPointerType, + structName); + /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) { @@ -1876,22 +2093,19 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { - case Declaration: /* this is badly named */ + case Declaration: /* FIXME: this is badly named */ types[fieldCount] = ResolveType( currentDeclarationNode->declaration.identifier->typeTag); - fieldDeclarations[fieldCount] = currentDeclarationNode; + AddFieldToStructDeclaration( + structTypeDeclaration, + currentDeclarationNode->declaration.identifier->identifier + .name); fieldCount += 1; break; } } LLVMStructSetBody(wStructType, types, fieldCount, packed); - AddStructDeclaration( - wStructType, - wStructPointerType, - structName, - fieldDeclarations, - fieldCount); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) @@ -1903,18 +2117,14 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case FunctionDeclaration: - CompileFunction( - module, - structName, - wStructPointerType, - currentDeclarationNode); + CompileFunction(structTypeDeclaration, currentDeclarationNode); break; } } } else { - AddGenericStructDeclaration(node); + AddGenericStructDeclaration(module, node); } PopScopeFrame(scope); @@ -1954,7 +2164,8 @@ static void RegisterLibraryFunctions( { LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console"); LLVMTypeRef structPointerType = LLVMPointerType(structType, 0); - AddStructDeclaration(structType, structPointerType, "Console", NULL, 0); + StructTypeDeclaration *structTypeDeclaration = + AddStructDeclaration(module, structType, structPointerType, "Console"); LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0); LLVMTypeRef printfFunctionType = @@ -1989,7 +2200,7 @@ static void RegisterLibraryFunctions( LLVMBuildAnd(builder, stringPrint, newlinePrint, "and")); DeclareStructFunction( - structPointerType, + structTypeDeclaration, printLineFunction, LLVMInt8Type(), 1,