From f3435f865940b34ba72dd4c3b054de5030af8972 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Wed, 2 Jun 2021 19:03:58 -0700 Subject: [PATCH 1/6] groundwork for struct generics --- generators/wraith.y | 16 ++++++--- generic.w | 1 + src/ast.c | 79 +++++++++++++++++++++++++++++++++++++++++++-- src/ast.h | 30 ++++++++++++++--- 4 files changed, 115 insertions(+), 11 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index 803d922..53678f9 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -113,9 +113,13 @@ BaseType : VOID { $$ = MakePrimitiveTypeNode(MemoryAddress); } + | Identifier GenericArgumentClauseNonEmpty + { + $$ = MakeConcreteGenericTypeNode($1, $2); + } | Identifier { - $$ = MakeCustomTypeNode(yytext); + $$ = MakeCustomTypeNode($1); } | REFERENCE LESS_THAN Type GREATER_THAN { @@ -359,11 +363,13 @@ GenericArguments : GenericArgument $$ = AddGenericArgument($1, $3); } +GenericArgumentClauseNonEmpty : LESS_THAN GenericArguments GREATER_THAN + { + $$ = $2; + } + ; -GenericArgumentClause : LESS_THAN GenericArguments GREATER_THAN - { - $$ = $2; - } +GenericArgumentClause : GenericArgumentClauseNonEmpty | { $$ = MakeEmptyGenericArgumentsNode(); diff --git a/generic.w b/generic.w index b7f6013..8007da7 100644 --- a/generic.w +++ b/generic.w @@ -24,6 +24,7 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); + block: MemoryBlock; addr: MemoryAddress = @malloc(y); @free(addr); return x; diff --git a/src/ast.c b/src/ast.c index dc014bb..d761850 100644 --- a/src/ast.c +++ b/src/ast.c @@ -19,6 +19,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "BinaryExpression"; case Comment: return "Comment"; + case ConcreteGenericTypeNode: + return "ConcreteGenericTypeNode"; case CustomTypeNode: return "CustomTypeNode"; case Declaration: @@ -95,11 +97,12 @@ Node *MakePrimitiveTypeNode(PrimitiveType type) return node; } -Node *MakeCustomTypeNode(char *name) +Node *MakeCustomTypeNode(Node *identifierNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = CustomTypeNode; - node->customType.name = strdup(name); + node->customType.name = strdup(identifierNode->identifier.name); + free(identifierNode); return node; } @@ -111,6 +114,18 @@ Node *MakeReferenceTypeNode(Node *typeNode) return node; } +Node *MakeConcreteGenericTypeNode( + Node *identifierNode, + Node *genericArgumentsNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = ConcreteGenericTypeNode; + node->concreteGenericType.name = strdup(identifierNode->identifier.name); + node->concreteGenericType.genericArguments = genericArgumentsNode; + free(identifierNode); + return node; +} + Node *MakeTypeNode(Node *typeNode) { Node *node = (Node *)malloc(sizeof(Node)); @@ -624,6 +639,11 @@ void PrintNode(Node *node, uint32_t tabCount) PrintNode(node->binaryExpression.right, tabCount + 1); return; + case ConcreteGenericTypeNode: + printf("%s\n", node->concreteGenericType.name); + PrintNode(node->concreteGenericType.genericArguments, tabCount + 1); + return; + case CustomTypeNode: printf("%s\n", node->customType.name); return; @@ -843,6 +863,10 @@ void Recurse(Node *node, void (*func)(Node *)) case Comment: return; + case ConcreteGenericTypeNode: + func(node->concreteGenericType.genericArguments); + return; + case CustomTypeNode: return; @@ -1004,6 +1028,8 @@ void Recurse(Node *node, void (*func)(Node *)) TypeTag *MakeTypeTag(Node *node) { + uint32_t i; + if (node == NULL) { fprintf( @@ -1034,6 +1060,28 @@ TypeTag *MakeTypeTag(Node *node) tag->value.customType = strdup(node->customType.name); break; + case ConcreteGenericTypeNode: + tag->type = ConcreteGeneric; + tag->value.concreteGenericType.name = + strdup(node->concreteGenericType.name); + tag->value.concreteGenericType.genericArgumentCount = + node->concreteGenericType.genericArguments->genericArguments.count; + tag->value.concreteGenericType.genericArguments = malloc( + sizeof(TypeTag *) * + tag->value.concreteGenericType.genericArgumentCount); + + for (i = 0; + i < + node->concreteGenericType.genericArguments->genericArguments.count; + i += 1) + { + tag->value.concreteGenericType.genericArguments[i] = MakeTypeTag( + node->concreteGenericType.genericArguments->genericArguments + .arguments[i] + ->genericArgument.type); + } + break; + case Declaration: tag = MakeTypeTag(node->declaration.type); break; @@ -1078,6 +1126,8 @@ TypeTag *MakeTypeTag(Node *node) char *TypeTagToString(TypeTag *tag) { + uint32_t i; + if (tag == NULL) { fprintf( @@ -1114,6 +1164,31 @@ char *TypeTagToString(TypeTag *tag) sprintf(result, "Generic<%s>", tag->value.genericType); return result; } + + case ConcreteGeneric: + { + char *result = strdup(tag->value.concreteGenericType.name); + uint32_t len = strlen(result); + result = realloc(result, len + 2); + strcat(result, "<"); + + for (i = 0; i < tag->value.concreteGenericType.genericArgumentCount; + i += 1) + { + char *inner = TypeTagToString( + tag->value.concreteGenericType.genericArguments[i]); + len += strlen(inner); + result = realloc(result, sizeof(char) * (len + 3)); + if (i != tag->value.concreteGenericType.genericArgumentCount - 1) + { + strcat(result, ", "); + } + strcat(result, inner); + } + result = realloc(result, sizeof(char) * (len + 2)); + strcat(result, ">"); + return result; + } } } diff --git a/src/ast.h b/src/ast.h index c7fd974..531064b 100644 --- a/src/ast.h +++ b/src/ast.h @@ -19,6 +19,7 @@ typedef enum Assignment, BinaryExpression, Comment, + ConcreteGenericTypeNode, CustomTypeNode, Declaration, DeclarationSequence, @@ -86,7 +87,16 @@ typedef union BinaryOperator binaryOperator; } Operator; -typedef struct TypeTag +typedef struct TypeTag TypeTag; + +typedef struct ConcreteGenericTypeTag +{ + char *name; + TypeTag **genericArguments; + uint32_t genericArgumentCount; +} ConcreteGenericTypeTag; + +struct TypeTag { enum Type { @@ -94,7 +104,8 @@ typedef struct TypeTag Primitive, Reference, Custom, - Generic + Generic, + ConcreteGeneric } type; union { @@ -106,8 +117,10 @@ typedef struct TypeTag char *customType; /* Valid when type = Generic. */ char *genericType; + /* Valid when type = ConcreteGeneric */ + ConcreteGenericTypeTag concreteGenericType; } value; -} TypeTag; +}; typedef struct Node Node; @@ -146,6 +159,12 @@ struct Node } comment; + struct + { + char *name; + Node *genericArguments; + } concreteGenericType; + struct { char *name; @@ -329,8 +348,11 @@ const char *SyntaxKindString(SyntaxKind syntaxKind); uint8_t IsPrimitiveType(Node *typeNode); Node *MakePrimitiveTypeNode(PrimitiveType type); -Node *MakeCustomTypeNode(char *string); +Node *MakeCustomTypeNode(Node *identifierNode); Node *MakeReferenceTypeNode(Node *typeNode); +Node *MakeConcreteGenericTypeNode( + Node *identifierNode, + Node *genericArgumentsNode); Node *MakeTypeNode(Node *typeNode); Node *MakeIdentifierNode(const char *id); Node *MakeNumberNode(const char *numberString); -- 2.25.1 From a870a2c32e42cbc4acef1c7a0743d91484f68e03 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 3 Jun 2021 14:40:14 -0700 Subject: [PATCH 2/6] monomorphizing generic structs --- src/codegen.c | 757 ++++++++++++++++++++++++++++++++------------------ 1 file changed, 484 insertions(+), 273 deletions(-) diff --git a/src/codegen.c b/src/codegen.c index 2e67b17..8919f79 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -83,10 +83,11 @@ typedef struct MonomorphizedGenericFunctionHashArray #define NUM_MONOMORPHIZED_HASH_BUCKETS 1031 +typedef struct StructTypeDeclaration StructTypeDeclaration; + typedef struct StructTypeGenericFunction { - char *parentStructName; - LLVMTypeRef parentStructPointerType; + StructTypeDeclaration *parentStruct; char *name; Node *functionDeclarationNode; uint8_t isStatic; @@ -94,8 +95,9 @@ typedef struct StructTypeGenericFunction monomorphizedFunctions[NUM_MONOMORPHIZED_HASH_BUCKETS]; } StructTypeGenericFunction; -typedef struct StructTypeDeclaration +struct StructTypeDeclaration { + LLVMModuleRef module; char *name; LLVMTypeRef structType; LLVMTypeRef structPointerType; @@ -107,7 +109,7 @@ typedef struct StructTypeDeclaration StructTypeGenericFunction *genericFunctions; uint32_t genericFunctionCount; -} StructTypeDeclaration; +}; StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; @@ -128,6 +130,7 @@ typedef struct MonomorphizedGenericStructHashArray typedef struct GenericStructTypeDeclaration { + LLVMModuleRef module; Node *structDeclarationNode; MonomorphizedGenericStructHashArray monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS]; @@ -148,16 +151,22 @@ uint32_t systemFunctionCount; /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMValueRef CompileExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *expression); +static void CompileFunction( + StructTypeDeclaration *structTypeDeclaration, + Node *functionDeclaration); + +static LLVMTypeRef ResolveType(TypeTag *typeTag); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -215,6 +224,122 @@ static void PopScopeFrame(Scope *scope) realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } +static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) +{ + const uint64_t HASH_FACTOR = 97; + uint64_t result = 1; + uint32_t i; + + for (i = 0; i < count; i += 1) + { + result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + } + + return result; +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + StructTypeDeclaration *structTypeDeclaration, + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclaration->fieldCount; i += 1) + { + char *ptrName = strdup(structTypeDeclaration->fields[i].name); + strcat(ptrName, "_ptr"); /* FIXME: needs to be realloc'd */ + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclaration->fields[i].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclaration->fields[i].name); + } +} + +static void AddFieldToStructDeclaration( + StructTypeDeclaration *structTypeDeclaration, + char *name) +{ + structTypeDeclaration->fields = realloc( + structTypeDeclaration->fields, + sizeof(StructTypeField) * (structTypeDeclaration->fieldCount + 1)); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].name = + strdup(name); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].index = + structTypeDeclaration->fieldCount; + structTypeDeclaration->fieldCount += 1; +} + +static void AddGenericStructDeclaration( + LLVMModuleRef module, + Node *structDeclarationNode) +{ + uint32_t i; + + genericStructTypeDeclarations = realloc( + genericStructTypeDeclarations, + sizeof(GenericStructTypeDeclaration) * + (genericStructTypeDeclarationCount + 1)); + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .structDeclarationNode = structDeclarationNode; + genericStructTypeDeclarations[genericStructTypeDeclarationCount].module = + module; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) + { + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .elements = NULL; + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .count = 0; + } + + genericStructTypeDeclarationCount += 1; +} + static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -284,6 +409,195 @@ static LLVMTypeRef LookupCustomType(char *name) return NULL; } +static StructTypeDeclaration CompileMonomorphizedGenericStruct( + GenericStructTypeDeclaration *genericStructTypeDeclaration, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount) +{ + uint32_t i = 0; + uint32_t nameLen; + uint32_t fieldCount = 0; + Node *structDeclarationNode = + genericStructTypeDeclaration->structDeclarationNode; + uint32_t declarationCount = + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.count; + LLVMTypeRef types[declarationCount]; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + genericArgumentTypes[i], + structDeclarationNode->structDeclaration.genericDeclarations + ->genericDeclarations.declarations[i] + ->genericDeclaration.identifier->identifier.name); + } + + char *structName = strdup( + structDeclarationNode->structDeclaration.identifier->identifier.name); + nameLen = strlen(structName); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + char *inner = TypeTagToString(genericArgumentTypes[i]); + nameLen += 2 + strlen(inner); + structName = realloc(structName, sizeof(char) * nameLen); + strcat(structName, "_"); + strcat(structName, inner); + } + + LLVMContextRef context = + LLVMGetGlobalContext(); /* FIXME: should we pass a context? */ + LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); + LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); + + StructTypeDeclaration declaration; + declaration.module = genericStructTypeDeclaration->module; + declaration.name = structName; + declaration.structType = wStructType; + declaration.structPointerType = wStructPointerType; + declaration.genericFunctions = NULL; + declaration.genericFunctionCount = 0; + declaration.functions = NULL; + declaration.functionCount = 0; + declaration.fields = NULL; + declaration.fieldCount = 0; + + /* first build the structure def */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case Declaration: + types[fieldCount] = ResolveType( + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->typeTag); + AddFieldToStructDeclaration( + &declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->identifier.name); + fieldCount += 1; + break; + } + } + + LLVMStructSetBody(wStructType, types, fieldCount, 1); + + /* now we wire up the functions */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case FunctionDeclaration: + CompileFunction( + &declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i]); + break; + } + } + + PopScopeFrame(scope); + + return declaration; +} + +static StructTypeDeclaration *LookupGenericStructType( + ConcreteGenericTypeTag *typeTag) +{ + uint32_t i, j, k; + uint64_t typeHash; + uint8_t match; + TypeTag *genericTypeTags[typeTag->genericArgumentCount]; + + for (i = 0; i < typeTag->genericArgumentCount; i += 1) + { + genericTypeTags[i] = ConcretizeType(typeTag->genericArguments[i]); + } + + for (i = 0; i < genericStructTypeDeclarationCount; i += 1) + { + if (strcmp( + genericStructTypeDeclarations[i] + .structDeclarationNode->structDeclaration.identifier + ->identifier.name, + typeTag->name) == 0) + { + typeHash = + HashTypeTags(genericTypeTags, typeTag->genericArgumentCount); + + MonomorphizedGenericStructHashArray *hashArray = + &genericStructTypeDeclarations[i].monomorphizedStructs + [typeHash % NUM_MONOMORPHIZED_HASH_BUCKETS]; + + MonomorphizedGenericStructHashEntry *hashEntry = NULL; + + for (j = 0; j < hashArray->count; j += 1) + { + match = 1; + + for (k = 0; k < hashArray->elements[j].typeCount; k += 1) + { + if (hashArray->elements[j].types[k] != genericTypeTags[k]) + { + match = 0; + break; + } + } + + if (match) + { + hashEntry = &hashArray->elements[i]; + break; + } + } + + if (hashEntry == NULL) + { + StructTypeDeclaration structTypeDeclaration = + CompileMonomorphizedGenericStruct( + &genericStructTypeDeclarations[i], + genericTypeTags, + typeTag->genericArgumentCount); + + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericStructHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * typeTag->genericArgumentCount); + hashArray->elements[hashArray->count].typeCount = + typeTag->genericArgumentCount; + hashArray->elements[hashArray->count].structDeclaration = + structTypeDeclaration; + for (j = 0; j < typeTag->genericArgumentCount; j += 1) + { + hashArray->elements[hashArray->count].types[j] = + genericTypeTags[j]; + } + hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; + } + + return &hashEntry->structDeclaration; + } + } + + fprintf(stderr, "Could not find generic struct declaration!"); + return NULL; +} + static LLVMTypeRef ResolveType(TypeTag *typeTag) { if (typeTag->type == Primitive) @@ -302,6 +616,11 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) { return LookupGenericType(typeTag->value.genericType)->type; } + else if (typeTag->type == ConcreteGeneric) + { + return LookupGenericStructType(&typeTag->value.concreteGenericType) + ->structType; + } else { fprintf(stderr, "Unknown type node!\n"); @@ -340,73 +659,6 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression) return NULL; } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - -static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->genericTypeCount; - - scopeFrame->genericTypes = realloc( - scopeFrame->genericTypes, - sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); - scopeFrame->genericTypes[index].name = strdup(name); - scopeFrame->genericTypes[index].concreteTypeTag = typeTag; - scopeFrame->genericTypes[index].type = ResolveType(typeTag); - - scopeFrame->genericTypeCount += 1; -} - -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - static LLVMTypeRef FindStructType(char *name) { uint32_t i; @@ -504,18 +756,17 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) return NULL; } -static void AddStructDeclaration( +static StructTypeDeclaration *AddStructDeclaration( + LLVMModuleRef module, LLVMTypeRef wStructType, LLVMTypeRef wStructPointerType, - char *name, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount) + char *name) { - uint32_t i; uint32_t index = structTypeDeclarationCount; structTypeDeclarations = realloc( structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); + structTypeDeclarations[index].module = module; structTypeDeclarations[index].structType = wStructType; structTypeDeclarations[index].structPointerType = wStructPointerType; structTypeDeclarations[index].name = strdup(name); @@ -526,145 +777,67 @@ static void AddStructDeclaration( structTypeDeclarations[index].genericFunctions = NULL; structTypeDeclarations[index].genericFunctionCount = 0; - for (i = 0; i < fieldDeclarationCount; i += 1) - { - structTypeDeclarations[index].fields = realloc( - structTypeDeclarations[index].fields, - sizeof(StructTypeField) * - (structTypeDeclarations[index].fieldCount + 1)); - structTypeDeclarations[index].fields[i].name = strdup( - fieldDeclarations[i]->declaration.identifier->identifier.name); - structTypeDeclarations[index].fields[i].index = i; - structTypeDeclarations[index].fieldCount += 1; - } - structTypeDeclarationCount += 1; + + return &structTypeDeclarations[index]; } -static void AddGenericStructDeclaration(Node *structDeclarationNode) -{ - uint32_t i; - - genericStructTypeDeclarations = realloc( - genericStructTypeDeclarations, - sizeof(GenericStructTypeDeclaration) * - (genericStructTypeDeclarationCount + 1)); - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .structDeclarationNode = structDeclarationNode; - - for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) - { - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .elements = NULL; - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .count = 0; - } - - genericStructTypeDeclarationCount += 1; -} - -/* FIXME: pass the declaration itself */ static void DeclareStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, LLVMValueRef function, LLVMTypeRef returnType, uint8_t isStatic, char *name) { - uint32_t i, index; + uint32_t index = structTypeDeclaration->functionCount; - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = structTypeDeclarations[i].functionCount; - structTypeDeclarations[i].functions = realloc( - structTypeDeclarations[i].functions, - sizeof(StructTypeFunction) * - (structTypeDeclarations[i].functionCount + 1)); - structTypeDeclarations[i].functions[index].name = strdup(name); - structTypeDeclarations[i].functions[index].function = function; - structTypeDeclarations[i].functions[index].returnType = returnType; - structTypeDeclarations[i].functions[index].isStatic = isStatic; - structTypeDeclarations[i].functionCount += 1; - - return; - } - } - - fprintf(stderr, "Could not find struct type for function!\n"); + structTypeDeclaration->functions = realloc( + structTypeDeclaration->functions, + sizeof(StructTypeFunction) * + (structTypeDeclaration->functionCount + 1)); + structTypeDeclaration->functions[index].name = strdup(name); + structTypeDeclaration->functions[index].function = function; + structTypeDeclaration->functions[index].returnType = returnType; + structTypeDeclaration->functions[index].isStatic = isStatic; + structTypeDeclaration->functionCount += 1; } -/* FIXME: pass the declaration itself */ static void DeclareGenericStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclarationNode, uint8_t isStatic, - char *parentStructName, char *name) { - uint32_t i, j, index; + uint32_t i, index; - for (i = 0; i < structTypeDeclarationCount; i += 1) + index = structTypeDeclaration->genericFunctionCount; + structTypeDeclaration->genericFunctions = realloc( + structTypeDeclaration->genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclaration->genericFunctionCount + 1)); + structTypeDeclaration->genericFunctions[index].name = strdup(name); + structTypeDeclaration->genericFunctions[index].parentStruct = + structTypeDeclaration; + structTypeDeclaration->genericFunctions[index].functionDeclarationNode = + functionDeclarationNode; + structTypeDeclaration->genericFunctions[index].isStatic = isStatic; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = structTypeDeclarations[i].genericFunctionCount; - structTypeDeclarations[i].genericFunctions = realloc( - structTypeDeclarations[i].genericFunctions, - sizeof(StructTypeGenericFunction) * - (structTypeDeclarations[i].genericFunctionCount + 1)); - structTypeDeclarations[i].genericFunctions[index].name = - strdup(name); - structTypeDeclarations[i].genericFunctions[index].parentStructName = - parentStructName; - structTypeDeclarations[i].structPointerType = wStructPointerType; - structTypeDeclarations[i] - .genericFunctions[index] - .functionDeclarationNode = functionDeclarationNode; - structTypeDeclarations[i].genericFunctions[index].isStatic = - isStatic; - - for (j = 0; j < NUM_MONOMORPHIZED_HASH_BUCKETS; j += 1) - { - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .elements = NULL; - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .count = 0; - } - - structTypeDeclarations[i].genericFunctionCount += 1; - - return; - } - } -} - -static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) -{ - const uint64_t HASH_FACTOR = 97; - uint64_t result = 1; - uint32_t i; - - for (i = 0; i < count; i += 1) - { - result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .elements = NULL; + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .count = 0; } - return result; + structTypeDeclaration->genericFunctionCount += 1; } /* FIXME: lots of duplication with non-generic function compile */ static StructTypeFunction CompileGenericFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, TypeTag **resolvedGenericArgumentTypes, uint32_t genericArgumentTypeCount, Node *functionDeclaration) @@ -711,7 +884,8 @@ static StructTypeFunction CompileGenericFunction( } } - char *functionName = strdup(parentStructName); + /* FIXME: these cats need to be realloc'd */ + char *functionName = strdup(structTypeDeclaration->name); strcat(functionName, "_"); strcat( functionName, @@ -724,7 +898,7 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { - paramTypes[paramIndex] = wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -746,7 +920,10 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); @@ -755,7 +932,10 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -783,7 +963,7 @@ static StructTypeFunction CompileGenericFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, functionBody->statementSequence.sequence[i]); @@ -816,7 +996,6 @@ static StructTypeFunction CompileGenericFunction( } static LLVMValueRef LookupGenericFunction( - LLVMModuleRef module, StructTypeGenericFunction *genericFunction, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -937,14 +1116,11 @@ static LLVMValueRef LookupGenericFunction( if (hashEntry == NULL) { StructTypeFunction function = CompileGenericFunction( - module, - genericFunction->parentStructName, - genericFunction->parentStructPointerType, + genericFunction->parentStruct, resolvedGenericArgumentTypes, genericArgumentTypeCount, genericFunction->functionDeclarationNode); - /* TODO: add to hash */ hashArray->elements = realloc( hashArray->elements, sizeof(MonomorphizedGenericFunctionHashEntry) * @@ -972,7 +1148,6 @@ static LLVMValueRef LookupGenericFunction( } static LLVMValueRef LookupFunctionByType( - LLVMModuleRef module, LLVMTypeRef structType, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -1007,7 +1182,6 @@ static LLVMValueRef LookupFunctionByType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1022,7 +1196,6 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( - LLVMModuleRef module, LLVMTypeRef structPointerType, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -1057,7 +1230,6 @@ static LLVMValueRef LookupFunctionByPointerType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1072,14 +1244,12 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( - LLVMModuleRef module, LLVMValueRef structPointer, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( - module, LLVMTypeOf(structPointer), functionCallExpression, pReturnType, @@ -1102,17 +1272,17 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *binaryExpression) { LLVMValueRef left = CompileExpression( - module, + structTypeDeclaration, builder, binaryExpression->binaryExpression.left); LLVMValueRef right = CompileExpression( - module, + structTypeDeclaration, builder, binaryExpression->binaryExpression.right); @@ -1159,7 +1329,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -1188,7 +1358,6 @@ static LLVMValueRef CompileFunctionCallExpression( if (typeReference != NULL) { function = LookupFunctionByType( - module, typeReference, functionCallExpression, &functionReturnType, @@ -1200,7 +1369,6 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); function = LookupFunctionByInstance( - module, structInstance, functionCallExpression, &functionReturnType, @@ -1224,7 +1392,7 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( - module, + structTypeDeclaration, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1240,7 +1408,7 @@ static LLVMValueRef CompileFunctionCallExpression( } static LLVMValueRef CompileSystemCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *systemCallExpression) { @@ -1255,7 +1423,7 @@ static LLVMValueRef CompileSystemCallExpression( i += 1) { args[i] = CompileExpression( - module, + structTypeDeclaration, builder, systemCallExpression->systemCall.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1335,7 +1503,7 @@ static LLVMValueRef CompileAllocExpression( } static LLVMValueRef CompileExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, Node *expression) { @@ -1348,10 +1516,16 @@ static LLVMValueRef CompileExpression( return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(module, builder, expression); + return CompileBinaryExpression( + structTypeDeclaration, + builder, + expression); case FunctionCallExpression: - return CompileFunctionCallExpression(module, builder, expression); + return CompileFunctionCallExpression( + structTypeDeclaration, + builder, + expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -1363,7 +1537,10 @@ static LLVMValueRef CompileExpression( return CompileString(builder, expression); case SystemCall: - return CompileSystemCallExpression(module, builder, expression); + return CompileSystemCallExpression( + structTypeDeclaration, + builder, + expression); } fprintf(stderr, "Unknown expression kind!\n"); @@ -1371,13 +1548,13 @@ static LLVMValueRef CompileExpression( } static LLVMBasicBlockRef CompileReturn( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( - module, + structTypeDeclaration, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -1417,16 +1594,18 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( - module, + structTypeDeclaration, builder, assignmentStatement->assignmentStatement.right); + LLVMValueRef identifier; + if (assignmentStatement->assignmentStatement.left->syntaxKind == AccessExpression) { @@ -1461,14 +1640,16 @@ static LLVMBasicBlockRef CompileAssignment( } static LLVMBasicBlockRef CompileIfStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; - LLVMValueRef conditional = - CompileExpression(module, builder, ifStatement->ifStatement.expression); + LLVMValueRef conditional = CompileExpression( + structTypeDeclaration, + builder, + ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -1483,7 +1664,7 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1497,14 +1678,14 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( - module, + structTypeDeclaration, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1521,7 +1702,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1540,7 +1721,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1550,7 +1731,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( - module, + structTypeDeclaration, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1563,7 +1744,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1623,7 +1804,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( - module, + structTypeDeclaration, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1652,7 +1833,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1660,33 +1841,56 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(module, builder, function, statement); + return CompileAssignment( + structTypeDeclaration, + builder, + function, + statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(module, builder, function, statement); + return CompileForLoopStatement( + structTypeDeclaration, + builder, + function, + statement); case FunctionCallExpression: - CompileFunctionCallExpression(module, builder, statement); + CompileFunctionCallExpression( + structTypeDeclaration, + builder, + statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(module, builder, function, statement); + return CompileIfStatement( + structTypeDeclaration, + builder, + function, + statement); case IfElseStatement: - return CompileIfElseStatement(module, builder, function, statement); + return CompileIfElseStatement( + structTypeDeclaration, + builder, + function, + statement); case Return: - return CompileReturn(module, builder, function, statement); + return CompileReturn( + structTypeDeclaration, + builder, + function, + statement); case ReturnVoid: return CompileReturnVoid(builder, function); case SystemCall: - CompileSystemCallExpression(module, builder, statement); + CompileSystemCallExpression(structTypeDeclaration, builder, statement); return LLVMGetLastBasicBlock(function); } @@ -1695,9 +1899,7 @@ static LLVMBasicBlockRef CompileStatement( } static void CompileFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclaration) { uint32_t i; @@ -1728,7 +1930,13 @@ static void CompileFunction( } } - char *functionName = strdup(parentStructName); + char *functionName = strdup(structTypeDeclaration->name); + uint32_t nameLen = strlen(functionName); + nameLen += + 2 + + strlen( + functionSignature->functionSignature.identifier->identifier.name); + functionName = realloc(functionName, sizeof(char) * nameLen); strcat(functionName, "_"); strcat( functionName, @@ -1741,7 +1949,7 @@ static void CompileFunction( if (!isStatic) { - paramTypes[paramIndex] = wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -1761,11 +1969,13 @@ static void CompileFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = - LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); DeclareStructFunction( - wStructPointerType, + structTypeDeclaration, function, returnType, isStatic, @@ -1778,7 +1988,10 @@ static void CompileFunction( if (!isStatic) { LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -1807,7 +2020,7 @@ static void CompileFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, builder, function, functionBody->statementSequence.sequence[i]); @@ -1832,10 +2045,9 @@ static void CompileFunction( else { DeclareGenericStructFunction( - wStructPointerType, + structTypeDeclaration, functionDeclaration, isStatic, - parentStructName, functionSignature->functionSignature.identifier->identifier.name); } @@ -1854,7 +2066,6 @@ static void CompileStruct( uint8_t packed = 1; LLVMTypeRef types[declarationCount]; Node *currentDeclarationNode; - Node *fieldDeclarations[declarationCount]; char *structName = node->structDeclaration.identifier->identifier.name; PushScopeFrame(scope); @@ -1867,6 +2078,12 @@ static void CompileStruct( wStructType, 0); /* FIXME: is this address space correct? */ + StructTypeDeclaration *structTypeDeclaration = AddStructDeclaration( + module, + wStructType, + wStructPointerType, + structName); + /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) { @@ -1876,22 +2093,19 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { - case Declaration: /* this is badly named */ + case Declaration: /* FIXME: this is badly named */ types[fieldCount] = ResolveType( currentDeclarationNode->declaration.identifier->typeTag); - fieldDeclarations[fieldCount] = currentDeclarationNode; + AddFieldToStructDeclaration( + structTypeDeclaration, + currentDeclarationNode->declaration.identifier->identifier + .name); fieldCount += 1; break; } } LLVMStructSetBody(wStructType, types, fieldCount, packed); - AddStructDeclaration( - wStructType, - wStructPointerType, - structName, - fieldDeclarations, - fieldCount); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) @@ -1903,18 +2117,14 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case FunctionDeclaration: - CompileFunction( - module, - structName, - wStructPointerType, - currentDeclarationNode); + CompileFunction(structTypeDeclaration, currentDeclarationNode); break; } } } else { - AddGenericStructDeclaration(node); + AddGenericStructDeclaration(module, node); } PopScopeFrame(scope); @@ -1954,7 +2164,8 @@ static void RegisterLibraryFunctions( { LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console"); LLVMTypeRef structPointerType = LLVMPointerType(structType, 0); - AddStructDeclaration(structType, structPointerType, "Console", NULL, 0); + StructTypeDeclaration *structTypeDeclaration = + AddStructDeclaration(module, structType, structPointerType, "Console"); LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0); LLVMTypeRef printfFunctionType = @@ -1989,7 +2200,7 @@ static void RegisterLibraryFunctions( LLVMBuildAnd(builder, stringPrint, newlinePrint, "and")); DeclareStructFunction( - structPointerType, + structTypeDeclaration, printLineFunction, LLVMInt8Type(), 1, -- 2.25.1 From ca053585ac34dbd2edb95e9142dc17ff217d1e65 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 3 Jun 2021 15:08:01 -0700 Subject: [PATCH 3/6] fix var lookups on generic structs --- generators/wraith.y | 4 +-- generic.w | 10 ++++-- src/codegen.c | 80 +++++++++++++++++++++------------------------ 3 files changed, 46 insertions(+), 48 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index 53678f9..a3cd8c5 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -294,11 +294,11 @@ Statements : Statement $$ = AddStatement($1, $2); } -Arguments : PrimaryExpression +Arguments : Expression { $$ = StartFunctionArgumentSequenceNode($1); } - | Arguments COMMA PrimaryExpression + | Arguments COMMA Expression { $$ = AddFunctionArgumentNode($1, $3); } diff --git a/generic.w b/generic.w index 8007da7..51fb64f 100644 --- a/generic.w +++ b/generic.w @@ -25,8 +25,12 @@ struct Program { x: int = 4; y: int = Foo.Func(x); block: MemoryBlock; - addr: MemoryAddress = @malloc(y); - @free(addr); - return x; + block.capacity = y; + block.start = @malloc(y * @sizeof()); + z: MemoryAddress = block.AddressOf(2); + Console.PrintLine("%p", block.start); + Console.PrintLine("%p", z); + @free(block.start); + return 0; } } diff --git a/src/codegen.c b/src/codegen.c index 8919f79..fad7286 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -119,7 +119,7 @@ typedef struct MonomorphizedGenericStructHashEntry uint64_t key; TypeTag **types; uint32_t typeCount; - StructTypeDeclaration structDeclaration; + StructTypeDeclaration *structDeclaration; } MonomorphizedGenericStructHashEntry; typedef struct MonomorphizedGenericStructHashArray @@ -409,7 +409,33 @@ static LLVMTypeRef LookupCustomType(char *name) return NULL; } -static StructTypeDeclaration CompileMonomorphizedGenericStruct( +static StructTypeDeclaration *AddStructDeclaration( + LLVMModuleRef module, + LLVMTypeRef wStructType, + LLVMTypeRef wStructPointerType, + char *name) +{ + uint32_t index = structTypeDeclarationCount; + structTypeDeclarations = realloc( + structTypeDeclarations, + sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); + structTypeDeclarations[index].module = module; + structTypeDeclarations[index].structType = wStructType; + structTypeDeclarations[index].structPointerType = wStructPointerType; + structTypeDeclarations[index].name = strdup(name); + structTypeDeclarations[index].fields = NULL; + structTypeDeclarations[index].fieldCount = 0; + structTypeDeclarations[index].functions = NULL; + structTypeDeclarations[index].functionCount = 0; + structTypeDeclarations[index].genericFunctions = NULL; + structTypeDeclarations[index].genericFunctionCount = 0; + + structTypeDeclarationCount += 1; + + return &structTypeDeclarations[index]; +} + +static StructTypeDeclaration *CompileMonomorphizedGenericStruct( GenericStructTypeDeclaration *genericStructTypeDeclaration, TypeTag **genericArgumentTypes, uint32_t genericArgumentTypeCount) @@ -454,17 +480,11 @@ static StructTypeDeclaration CompileMonomorphizedGenericStruct( LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); - StructTypeDeclaration declaration; - declaration.module = genericStructTypeDeclaration->module; - declaration.name = structName; - declaration.structType = wStructType; - declaration.structPointerType = wStructPointerType; - declaration.genericFunctions = NULL; - declaration.genericFunctionCount = 0; - declaration.functions = NULL; - declaration.functionCount = 0; - declaration.fields = NULL; - declaration.fieldCount = 0; + StructTypeDeclaration *declaration = AddStructDeclaration( + genericStructTypeDeclaration->module, + wStructType, + wStructPointerType, + structName); /* first build the structure def */ for (i = 0; i < declarationCount; i += 1) @@ -479,7 +499,7 @@ static StructTypeDeclaration CompileMonomorphizedGenericStruct( ->declarationSequence.sequence[i] ->declaration.identifier->typeTag); AddFieldToStructDeclaration( - &declaration, + declaration, structDeclarationNode->structDeclaration.declarationSequence ->declarationSequence.sequence[i] ->declaration.identifier->identifier.name); @@ -499,7 +519,7 @@ static StructTypeDeclaration CompileMonomorphizedGenericStruct( { case FunctionDeclaration: CompileFunction( - &declaration, + declaration, structDeclarationNode->structDeclaration.declarationSequence ->declarationSequence.sequence[i]); break; @@ -563,7 +583,7 @@ static StructTypeDeclaration *LookupGenericStructType( if (hashEntry == NULL) { - StructTypeDeclaration structTypeDeclaration = + StructTypeDeclaration *structTypeDeclaration = CompileMonomorphizedGenericStruct( &genericStructTypeDeclarations[i], genericTypeTags, @@ -590,7 +610,7 @@ static StructTypeDeclaration *LookupGenericStructType( hashEntry = &hashArray->elements[hashArray->count - 1]; } - return &hashEntry->structDeclaration; + return hashEntry->structDeclaration; } } @@ -756,32 +776,6 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) return NULL; } -static StructTypeDeclaration *AddStructDeclaration( - LLVMModuleRef module, - LLVMTypeRef wStructType, - LLVMTypeRef wStructPointerType, - char *name) -{ - uint32_t index = structTypeDeclarationCount; - structTypeDeclarations = realloc( - structTypeDeclarations, - sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); - structTypeDeclarations[index].module = module; - structTypeDeclarations[index].structType = wStructType; - structTypeDeclarations[index].structPointerType = wStructPointerType; - structTypeDeclarations[index].name = strdup(name); - structTypeDeclarations[index].fields = NULL; - structTypeDeclarations[index].fieldCount = 0; - structTypeDeclarations[index].functions = NULL; - structTypeDeclarations[index].functionCount = 0; - structTypeDeclarations[index].genericFunctions = NULL; - structTypeDeclarations[index].genericFunctionCount = 0; - - structTypeDeclarationCount += 1; - - return &structTypeDeclarations[index]; -} - static void DeclareStructFunction( StructTypeDeclaration *structTypeDeclaration, LLVMValueRef function, -- 2.25.1 From 12ac9cc9808de6ac0e65d2fff835fe38e8c829f4 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 3 Jun 2021 17:48:16 -0700 Subject: [PATCH 4/6] allow function self reference --- generic.w | 19 +++++-- src/ast.c | 5 +- src/codegen.c | 151 ++++++++++++++++++++++++++++++++++++++------------ src/util.c | 11 +++- src/util.h | 1 + 5 files changed, 146 insertions(+), 41 deletions(-) diff --git a/generic.w b/generic.w index 51fb64f..c127e17 100644 --- a/generic.w +++ b/generic.w @@ -14,9 +14,19 @@ struct MemoryBlock start: MemoryAddress; capacity: uint; - AddressOf(count: uint): MemoryAddress + AddressOf(index: uint): MemoryAddress { - return start + (count * @sizeof()); + return start + (index * @sizeof()); + } + + Get(index: uint): T + { + return @bitcast(AddressOf(index)); + } + + Free(): void + { + @free(start); } } @@ -24,13 +34,14 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - block: MemoryBlock; + block: MemoryBlock; block.capacity = y; block.start = @malloc(y * @sizeof()); z: MemoryAddress = block.AddressOf(2); + Console.PrintLine("%u", block.Get(0)); Console.PrintLine("%p", block.start); Console.PrintLine("%p", z); - @free(block.start); + block.Free(); return 0; } } diff --git a/src/ast.c b/src/ast.c index d761850..313a51d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -1169,7 +1169,8 @@ char *TypeTagToString(TypeTag *tag) { char *result = strdup(tag->value.concreteGenericType.name); uint32_t len = strlen(result); - result = realloc(result, len + 2); + len += 2; + result = realloc(result, sizeof(char) * len); strcat(result, "<"); for (i = 0; i < tag->value.concreteGenericType.genericArgumentCount; @@ -1185,7 +1186,7 @@ char *TypeTagToString(TypeTag *tag) } strcat(result, inner); } - result = realloc(result, sizeof(char) * (len + 2)); + result = realloc(result, sizeof(char) * (len + 1)); strcat(result, ">"); return result; } diff --git a/src/codegen.c b/src/codegen.c index fad7286..991098f 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -152,12 +152,14 @@ uint32_t systemFunctionCount; /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression); @@ -282,7 +284,7 @@ static void AddStructVariablesToScope( for (i = 0; i < structTypeDeclaration->fieldCount; i += 1) { char *ptrName = strdup(structTypeDeclaration->fields[i].name); - strcat(ptrName, "_ptr"); /* FIXME: needs to be realloc'd */ + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef elementPointer = LLVMBuildStructGEP( builder, structPointer, @@ -441,7 +443,6 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( uint32_t genericArgumentTypeCount) { uint32_t i = 0; - uint32_t nameLen; uint32_t fieldCount = 0; Node *structDeclarationNode = genericStructTypeDeclaration->structDeclarationNode; @@ -464,15 +465,12 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( char *structName = strdup( structDeclarationNode->structDeclaration.identifier->identifier.name); - nameLen = strlen(structName); for (i = 0; i < genericArgumentTypeCount; i += 1) { char *inner = TypeTagToString(genericArgumentTypes[i]); - nameLen += 2 + strlen(inner); - structName = realloc(structName, sizeof(char) * nameLen); - strcat(structName, "_"); - strcat(structName, inner); + structName = w_strcat(structName, "_"); + structName = w_strcat(structName, inner); } LLVMContextRef context = @@ -486,6 +484,8 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( wStructPointerType, structName); + free(structName); + /* first build the structure def */ for (i = 0; i < declarationCount; i += 1) { @@ -712,7 +712,7 @@ static LLVMValueRef FindStructFieldPointer( if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0) { char *ptrName = strdup(name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); return LLVMBuildStructGEP( builder, structPointer, @@ -847,6 +847,7 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; LLVMTypeRef returnType; + LLVMValueRef wStructPointer = NULL; PushScopeFrame(scope); @@ -878,16 +879,17 @@ static StructTypeFunction CompileGenericFunction( } } - /* FIXME: these cats need to be realloc'd */ char *functionName = strdup(structTypeDeclaration->name); - strcat(functionName, "_"); - strcat( + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); for (i = 0; i < genericArgumentTypeCount; i += 1) { - strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); + functionName = w_strcat( + functionName, + TypeTagToString(resolvedGenericArgumentTypes[i])); } if (!isStatic) @@ -925,7 +927,7 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + wStructPointer = LLVMGetParam(function, 0); AddStructVariablesToScope( structTypeDeclaration, builder, @@ -939,7 +941,7 @@ static StructTypeFunction CompileGenericFunction( char *ptrName = strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -958,6 +960,7 @@ static StructTypeFunction CompileGenericFunction( { CompileStatement( structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); @@ -1143,14 +1146,12 @@ static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupFunctionByType( LLVMTypeRef structType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1191,14 +1192,12 @@ static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByPointerType( LLVMTypeRef structPointerType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1239,12 +1238,14 @@ static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByInstance( LLVMValueRef structPointer, + char *functionName, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( LLVMTypeOf(structPointer), + functionName, functionCallExpression, pReturnType, pStatic); @@ -1267,16 +1268,19 @@ static LLVMValueRef CompileString( static LLVMValueRef CompileBinaryExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *binaryExpression) { LLVMValueRef left = CompileExpression( structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.left); LLVMValueRef right = CompileExpression( structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.right); @@ -1324,6 +1328,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -1335,7 +1340,7 @@ static LLVMValueRef CompileFunctionCallExpression( 1]; LLVMValueRef function; uint8_t isStatic; - LLVMValueRef structInstance; + LLVMValueRef structInstance = NULL; LLVMTypeRef functionReturnType; char *returnName = ""; @@ -1349,10 +1354,15 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + char *functionName = + functionCallExpression->functionCallExpression.identifier + ->accessExpression.accessor->identifier.name; + if (typeReference != NULL) { function = LookupFunctionByType( typeReference, + functionName, functionCallExpression, &functionReturnType, &isStatic); @@ -1362,16 +1372,38 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance = FindVariablePointer( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + function = LookupFunctionByInstance( structInstance, + functionName, functionCallExpression, &functionReturnType, &isStatic); } } + else if ( + functionCallExpression->functionCallExpression.identifier->syntaxKind == + Identifier) + { + LLVMTypeRef structType = structTypeDeclaration->structType; + char *functionName = functionCallExpression->functionCallExpression + .identifier->identifier.name; + + function = LookupFunctionByType( + structType, + functionName, + functionCallExpression, + &functionReturnType, + &isStatic); + + if (!isStatic) + { + structInstance = selfParam; + } + } else { - fprintf(stderr, "Failed to find function!\n"); + fprintf(stderr, "Function identifier syntax kind not recognized!\n"); return NULL; } @@ -1387,6 +1419,7 @@ static LLVMValueRef CompileFunctionCallExpression( { args[argumentCount] = CompileExpression( structTypeDeclaration, + selfParam, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1403,6 +1436,7 @@ static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileSystemCallExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *systemCallExpression) { @@ -1418,6 +1452,7 @@ static LLVMValueRef CompileSystemCallExpression( { args[i] = CompileExpression( structTypeDeclaration, + selfParam, builder, systemCallExpression->systemCall.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1439,6 +1474,29 @@ static LLVMValueRef CompileSystemCallExpression( return LLVMSizeOf(ResolveType(ConcretizeType(typeTag))); } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "bitcast") == 0) + { + TypeTag *typeTag = ConcretizeType( + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag); + + LLVMValueRef expression = CompileExpression( + structTypeDeclaration, + selfParam, + builder, + systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.sequence[0]); + + return LLVMBuildBitCast( + builder, + expression, + ResolveType(typeTag), + "castResult"); + } else { fprintf(stderr, "System function not found!"); @@ -1498,6 +1556,7 @@ static LLVMValueRef CompileAllocExpression( static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression) { @@ -1512,12 +1571,14 @@ static LLVMValueRef CompileExpression( case BinaryExpression: return CompileBinaryExpression( structTypeDeclaration, + selfParam, builder, expression); case FunctionCallExpression: return CompileFunctionCallExpression( structTypeDeclaration, + selfParam, builder, expression); @@ -1533,6 +1594,7 @@ static LLVMValueRef CompileExpression( case SystemCall: return CompileSystemCallExpression( structTypeDeclaration, + selfParam, builder, expression); } @@ -1543,12 +1605,14 @@ static LLVMValueRef CompileExpression( static LLVMBasicBlockRef CompileReturn( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( structTypeDeclaration, + selfParam, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -1573,7 +1637,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( char *variableName = variableDeclaration->declaration.identifier->identifier.name; char *ptrName = strdup(variableName); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); variable = LLVMBuildAlloca( builder, @@ -1589,12 +1653,14 @@ static LLVMValueRef CompileFunctionVariableDeclaration( static LLVMBasicBlockRef CompileAssignment( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( structTypeDeclaration, + selfParam, builder, assignmentStatement->assignmentStatement.right); @@ -1635,6 +1701,7 @@ static LLVMBasicBlockRef CompileAssignment( static LLVMBasicBlockRef CompileIfStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) @@ -1642,6 +1709,7 @@ static LLVMBasicBlockRef CompileIfStatement( uint32_t i; LLVMValueRef conditional = CompileExpression( structTypeDeclaration, + selfParam, builder, ifStatement->ifStatement.expression); @@ -1659,6 +1727,7 @@ static LLVMBasicBlockRef CompileIfStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1673,6 +1742,7 @@ static LLVMBasicBlockRef CompileIfStatement( static LLVMBasicBlockRef CompileIfElseStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) @@ -1680,6 +1750,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( uint32_t i; LLVMValueRef conditional = CompileExpression( structTypeDeclaration, + selfParam, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1697,6 +1768,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1716,6 +1788,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1726,6 +1799,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1739,6 +1813,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( static LLVMBasicBlockRef CompileForLoopStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1799,6 +1874,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( { lastBlock = CompileStatement( structTypeDeclaration, + selfParam, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1828,6 +1904,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( static LLVMBasicBlockRef CompileStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, /* can be NULL for statics */ LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1837,6 +1914,7 @@ static LLVMBasicBlockRef CompileStatement( case Assignment: return CompileAssignment( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1848,6 +1926,7 @@ static LLVMBasicBlockRef CompileStatement( case ForLoop: return CompileForLoopStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1855,6 +1934,7 @@ static LLVMBasicBlockRef CompileStatement( case FunctionCallExpression: CompileFunctionCallExpression( structTypeDeclaration, + selfParam, builder, statement); return LLVMGetLastBasicBlock(function); @@ -1862,6 +1942,7 @@ static LLVMBasicBlockRef CompileStatement( case IfStatement: return CompileIfStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1869,6 +1950,7 @@ static LLVMBasicBlockRef CompileStatement( case IfElseStatement: return CompileIfElseStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1876,6 +1958,7 @@ static LLVMBasicBlockRef CompileStatement( case Return: return CompileReturn( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1884,7 +1967,11 @@ static LLVMBasicBlockRef CompileStatement( return CompileReturnVoid(builder, function); case SystemCall: - CompileSystemCallExpression(structTypeDeclaration, builder, statement); + CompileSystemCallExpression( + structTypeDeclaration, + selfParam, + builder, + statement); return LLVMGetLastBasicBlock(function); } @@ -1906,6 +1993,7 @@ static void CompileFunction( ->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; + LLVMValueRef wStructPointer = NULL; if (functionSignature->functionSignature.modifiers->functionModifiers .count > 0) @@ -1925,14 +2013,8 @@ static void CompileFunction( } char *functionName = strdup(structTypeDeclaration->name); - uint32_t nameLen = strlen(functionName); - nameLen += - 2 + - strlen( - functionSignature->functionSignature.identifier->identifier.name); - functionName = realloc(functionName, sizeof(char) * nameLen); - strcat(functionName, "_"); - strcat( + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); @@ -1981,7 +2063,7 @@ static void CompileFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + wStructPointer = LLVMGetParam(function, 0); AddStructVariablesToScope( structTypeDeclaration, builder, @@ -1996,7 +2078,7 @@ static void CompileFunction( strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -2015,6 +2097,7 @@ static void CompileFunction( { CompileStatement( structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); diff --git a/src/util.c b/src/util.c index 42911e7..491bb81 100644 --- a/src/util.c +++ b/src/util.c @@ -5,7 +5,7 @@ char *strdup(const char *s) { size_t slen = strlen(s); - char *result = (char *)malloc(slen + 1); + char *result = (char *)malloc(sizeof(char) * (slen + 1)); if (result == NULL) { return NULL; @@ -15,6 +15,15 @@ char *strdup(const char *s) return result; } +char *w_strcat(char *s, char *s2) +{ + size_t slen = strlen(s); + size_t slen2 = strlen(s2); + s = realloc(s, sizeof(char) * (slen + slen2 + 1)); + strcat(s, s2); + return s; +} + uint64_t str_hash(char *str) { uint64_t hash = 5381; diff --git a/src/util.h b/src/util.h index 108211b..94a64e2 100644 --- a/src/util.h +++ b/src/util.h @@ -5,6 +5,7 @@ #include char *strdup(const char *s); +char *w_strcat(char *s, char *s2); uint64_t str_hash(char *str); #endif /* WRAITH_UTIL_H */ -- 2.25.1 From 01da2dc3774085caf30260e34bc0aeb5a990628e Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Fri, 4 Jun 2021 00:41:30 -0700 Subject: [PATCH 5/6] add some memory sys calls --- generic.w | 19 +++++++++++--- src/codegen.c | 69 ++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 78 insertions(+), 10 deletions(-) diff --git a/generic.w b/generic.w index c127e17..ba8d511 100644 --- a/generic.w +++ b/generic.w @@ -21,7 +21,12 @@ struct MemoryBlock Get(index: uint): T { - return @bitcast(AddressOf(index)); + return @dereference(AddressOf(index)); + } + + Set(index: uint, value: T): void + { + @memcpy(AddressOf(index), @addr(value), @sizeof()); } Free(): void @@ -34,13 +39,19 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - block: MemoryBlock; + block: MemoryBlock; block.capacity = y; block.start = @malloc(y * @sizeof()); z: MemoryAddress = block.AddressOf(2); - Console.PrintLine("%u", block.Get(0)); Console.PrintLine("%p", block.start); - Console.PrintLine("%p", z); + block.Set(0, 5); + block.Set(1, 3); + block.Set(2, 9); + block.Set(3, 100); + Console.PrintLine("%i", block.Get(0)); + Console.PrintLine("%i", block.Get(1)); + Console.PrintLine("%i", block.Get(2)); + Console.PrintLine("%i", block.Get(3)); block.Free(); return 0; } diff --git a/src/codegen.c b/src/codegen.c index 991098f..f9cbf7b 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -1474,6 +1474,39 @@ static LLVMValueRef CompileSystemCallExpression( return LLVMSizeOf(ResolveType(ConcretizeType(typeTag))); } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "addr") == 0) + { + return LLVMBuildPtrToInt( + builder, + FindVariablePointer( + systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.sequence[0] + ->identifier.name), + LLVMInt64Type(), + "addrResult"); + } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "dereference") == 0) + { + TypeTag *typeTag = ConcretizeType( + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag); + + return LLVMBuildLoad( + builder, + LLVMBuildIntToPtr( + builder, + args[0], + LLVMPointerType(ResolveType(typeTag), 0), + "deref_ptr"), + "deref"); + } else if ( strcmp( systemCallExpression->systemCall.identifier->identifier.name, @@ -1484,12 +1517,7 @@ static LLVMValueRef CompileSystemCallExpression( ->genericArguments.arguments[0] ->type.typeNode->typeTag); - LLVMValueRef expression = CompileExpression( - structTypeDeclaration, - selfParam, - builder, - systemCallExpression->systemCall.argumentSequence - ->functionArgumentSequence.sequence[0]); + LLVMValueRef expression = args[0]; return LLVMBuildBitCast( builder, @@ -2301,6 +2329,35 @@ static void RegisterLibraryFunctions( AddSystemFunction("free", freeFunctionType, freeFunction); + LLVMTypeRef memcpyParams[3]; + memcpyParams[0] = LLVMInt64Type(); + memcpyParams[1] = LLVMInt64Type(); + memcpyParams[2] = LLVMInt64Type(); + LLVMTypeRef memcpyFunctionType = + LLVMFunctionType(LLVMVoidType(), memcpyParams, 3, 0); + LLVMValueRef memcpyFunction = + LLVMAddFunction(module, "memcopy", memcpyFunctionType); + + LLVMBasicBlockRef memcpyEntry = + LLVMAppendBasicBlock(memcpyFunction, "entry"); + LLVMPositionBuilderAtEnd(builder, memcpyEntry); + LLVMValueRef dest = LLVMBuildIntToPtr( + builder, + LLVMGetParam(memcpyFunction, 0), + LLVMPointerType(LLVMInt64Type(), 0), + "dest"); + LLVMValueRef src = LLVMBuildIntToPtr( + builder, + LLVMGetParam(memcpyFunction, 1), + LLVMPointerType(LLVMInt64Type(), 0), + "src"); + + LLVMBuildMemCpy(builder, dest, 8, src, 8, LLVMGetParam(memcpyFunction, 2)); + + LLVMBuildRetVoid(builder); + + AddSystemFunction("memcpy", memcpyFunctionType, memcpyFunction); + LLVMDisposeBuilder(builder); } -- 2.25.1 From e2332349b7ccbd7451319e6e293e5a5807b817af Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Sun, 6 Jun 2021 14:35:54 -0700 Subject: [PATCH 6/6] add struct initializer --- generators/wraith.y | 25 ++++++++ generic.w | 11 ++-- src/ast.c | 141 ++++++++++++++++++++++++++++++++++++++++++++ src/ast.h | 27 +++++++++ src/codegen.c | 102 +++++++++++++++++++++++++++++--- 5 files changed, 294 insertions(+), 12 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index a3cd8c5..d9a637f 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -161,6 +161,30 @@ Number : NUMBER $$ = MakeNumberNode(yytext); } +FieldInit : Identifier COLON Expression + { + $$ = MakeFieldInitNode($1, $3); + } + +StructInitFields : FieldInit + { + $$ = StartStructInitFieldsNode($1); + } + | StructInitFields COMMA FieldInit + { + $$ = AddFieldInitNode($1, $3); + } + | + { + $$ = MakeEmptyFieldInitNode(); + } + ; + +StructInitExpression : Type LEFT_BRACE StructInitFields RIGHT_BRACE + { + $$ = MakeStructInitExpressionNode($1, $3); + } + PrimaryExpression : Number | STRING_LITERAL { @@ -172,6 +196,7 @@ PrimaryExpression : Number } | FunctionCallExpression | AccessExpression + | StructInitExpression ; UnaryExpression : BANG Expression diff --git a/generic.w b/generic.w index ba8d511..b85d7e8 100644 --- a/generic.w +++ b/generic.w @@ -39,15 +39,16 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - block: MemoryBlock; - block.capacity = y; - block.start = @malloc(y * @sizeof()); - z: MemoryAddress = block.AddressOf(2); - Console.PrintLine("%p", block.start); + block: MemoryBlock = MemoryBlock + { + capacity: y, + start: @malloc(y * @sizeof()) + }; block.Set(0, 5); block.Set(1, 3); block.Set(2, 9); block.Set(3, 100); + Console.PrintLine("%p", block.start); Console.PrintLine("%i", block.Get(0)); Console.PrintLine("%i", block.Get(1)); Console.PrintLine("%i", block.Get(2)); diff --git a/src/ast.c b/src/ast.c index 313a51d..a4cd091 100644 --- a/src/ast.c +++ b/src/ast.c @@ -29,6 +29,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "ForLoop"; case DeclarationSequence: return "DeclarationSequence"; + case FieldInit: + return "FieldInit"; case FunctionArgumentSequence: return "FunctionArgumentSequence"; case FunctionCallExpression: @@ -73,6 +75,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "StringLiteral"; case StructDeclaration: return "StructDeclaration"; + case StructInit: + return "StructInit"; + case StructInitFields: + return "StructInitFields"; case SystemCall: return "SystemCall"; case Type: @@ -557,6 +563,55 @@ Node *MakeForLoopNode( return node; } +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = FieldInit; + node->fieldInit.identifier = identifierNode; + node->fieldInit.expression = expressionNode; + return node; +} + +Node *StartStructInitFieldsNode(Node *fieldInitNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = (Node **)malloc(sizeof(Node *)); + node->structInitFields.fieldInits[0] = fieldInitNode; + node->structInitFields.count = 1; + return node; +} + +Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode) +{ + structInitFieldsNode->structInitFields.fieldInits = realloc( + structInitFieldsNode->structInitFields.fieldInits, + sizeof(Node *) * (structInitFieldsNode->structInitFields.count + 1)); + structInitFieldsNode->structInitFields + .fieldInits[structInitFieldsNode->structInitFields.count] = + fieldInitNode; + structInitFieldsNode->structInitFields.count += 1; + return structInitFieldsNode; +} + +Node *MakeEmptyFieldInitNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = NULL; + node->structInitFields.count = 0; + return node; +} + +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInit; + node->structInit.type = typeNode; + node->structInit.initFields = structInitFieldsNode; + return node; +} + static const char *PrimitiveTypeToString(PrimitiveType type) { switch (type) @@ -662,6 +717,12 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case FieldInit: + printf("\n"); + PrintNode(node->fieldInit.identifier, tabCount + 1); + PrintNode(node->fieldInit.expression, tabCount + 1); + return; + case ForLoop: printf("\n"); PrintNode(node->forLoop.declaration, tabCount + 1); @@ -817,6 +878,20 @@ void PrintNode(Node *node, uint32_t tabCount) PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); return; + case StructInit: + printf("\n"); + PrintNode(node->structInit.type, tabCount + 1); + PrintNode(node->structInit.initFields, tabCount + 1); + return; + + case StructInitFields: + printf("\n"); + for (i = 0; i < node->structInitFields.count; i += 1) + { + PrintNode(node->structInitFields.fieldInits[i], tabCount + 1); + } + return; + case SystemCall: printf("\n"); PrintNode(node->systemCall.identifier, tabCount + 1); @@ -882,6 +957,11 @@ void Recurse(Node *node, void (*func)(Node *)) } return; + case FieldInit: + func(node->fieldInit.identifier); + func(node->fieldInit.expression); + return; + case ForLoop: func(node->forLoop.declaration); func(node->forLoop.startNumber); @@ -1003,6 +1083,18 @@ void Recurse(Node *node, void (*func)(Node *)) func(node->structDeclaration.declarationSequence); return; + case StructInit: + func(node->structInit.type); + func(node->structInit.initFields); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + func(node->structInitFields.fieldInits[i]); + } + return; + case SystemCall: func(node->systemCall.identifier); func(node->systemCall.argumentSequence); @@ -1193,6 +1285,38 @@ char *TypeTagToString(TypeTag *tag) } } +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB) +{ + if (typeTagA->type != typeTagB->type) + { + return 0; + } + + switch (typeTagA->type) + { + case Primitive: + return typeTagA->value.primitiveType == typeTagB->value.primitiveType; + + case Reference: + return TypeTagEqual( + typeTagA->value.referenceType, + typeTagB->value.referenceType); + + case Custom: + return strcmp(typeTagA->value.customType, typeTagB->value.customType) == + 0; + + case Generic: + return strcmp( + typeTagA->value.genericType, + typeTagB->value.genericType) == 0; + + default: + fprintf(stderr, "Invalid type comparison!"); + return 0; + } +} + void LinkParentPointers(Node *node, Node *prev) { if (node == NULL) @@ -1240,6 +1364,11 @@ void LinkParentPointers(Node *node, Node *prev) } return; + case FieldInit: + LinkParentPointers(node->fieldInit.identifier, node); + LinkParentPointers(node->fieldInit.expression, node); + return; + case ForLoop: LinkParentPointers(node->forLoop.declaration, node); LinkParentPointers(node->forLoop.startNumber, node); @@ -1364,6 +1493,18 @@ void LinkParentPointers(Node *node, Node *prev) LinkParentPointers(node->structDeclaration.declarationSequence, node); return; + case StructInit: + LinkParentPointers(node->structInit.type, node); + LinkParentPointers(node->structInit.initFields, node); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + LinkParentPointers(node->structInitFields.fieldInits[i], node); + } + return; + case SystemCall: LinkParentPointers(node->systemCall.identifier, node); LinkParentPointers(node->systemCall.argumentSequence, node); diff --git a/src/ast.h b/src/ast.h index 531064b..3fdedf9 100644 --- a/src/ast.h +++ b/src/ast.h @@ -23,6 +23,7 @@ typedef enum CustomTypeNode, Declaration, DeclarationSequence, + FieldInit, ForLoop, FunctionArgumentSequence, FunctionCallExpression, @@ -47,6 +48,8 @@ typedef enum StaticModifier, StringLiteral, StructDeclaration, + StructInit, + StructInitFields, SystemCall, Type, UnaryExpression @@ -182,6 +185,12 @@ struct Node uint32_t count; } declarationSequence; + struct + { + Node *identifier; + Node *expression; + } fieldInit; + struct { Node *declaration; @@ -323,6 +332,18 @@ struct Node Node *genericDeclarations; } structDeclaration; + struct + { + Node *type; + Node *initFields; + } structInit; + + struct + { + Node **fieldInits; + uint32_t count; + } structInitFields; + struct { Node *identifier; @@ -419,6 +440,11 @@ Node *MakeForLoopNode( Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode); +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode); +Node *StartStructInitFieldsNode(Node *fieldInitNode); +Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode); +Node *MakeEmptyFieldInitNode(); +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode); void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); @@ -434,6 +460,7 @@ void LinkParentPointers(Node *node, Node *prev); TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB); Node *LookupIdNode(Node *current, Node *prev, char *target); diff --git a/src/codegen.c b/src/codegen.c index f9cbf7b..b2421f9 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -567,7 +567,9 @@ static StructTypeDeclaration *LookupGenericStructType( for (k = 0; k < hashArray->elements[j].typeCount; k += 1) { - if (hashArray->elements[j].types[k] != genericTypeTags[k]) + if (!TypeTagEqual( + hashArray->elements[j].types[k], + genericTypeTags[k])) { match = 0; break; @@ -679,7 +681,8 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression) return NULL; } -static LLVMTypeRef FindStructType(char *name) +/* FIXME: this is awkward, should just resolve the type */ +static LLVMTypeRef LookupStructTypeByName(char *name) { uint32_t i; @@ -694,6 +697,22 @@ static LLVMTypeRef FindStructType(char *name) return NULL; } +static StructTypeDeclaration *LookupStructDeclaration(LLVMTypeRef structType) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structType == structType) + { + return &structTypeDeclarations[i]; + } + } + + fprintf(stderr, "Struct type not found!"); + return NULL; +} + static LLVMValueRef FindStructFieldPointer( LLVMBuilderRef builder, LLVMValueRef structPointer, @@ -1095,8 +1114,9 @@ static LLVMValueRef LookupGenericFunction( for (j = 0; j < hashArray->elements[i].typeCount; j += 1) { - if (hashArray->elements[i].types[j] != - resolvedGenericArgumentTypes[j]) + if (!TypeTagEqual( + hashArray->elements[i].types[j], + resolvedGenericArgumentTypes[j])) { match = 0; break; @@ -1350,7 +1370,7 @@ static LLVMValueRef CompileFunctionCallExpression( if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { - LLVMTypeRef typeReference = FindStructType( + LLVMTypeRef typeReference = LookupStructTypeByName( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); @@ -1582,6 +1602,47 @@ static LLVMValueRef CompileAllocExpression( return LLVMBuildMalloc(builder, type, "allocation"); } +static LLVMValueRef CompileStructInitExpression( + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, + LLVMBuilderRef builder, + Node *structInitExpression) +{ + uint32_t i = 0; + + LLVMTypeRef structType = ResolveType( + ConcretizeType(structInitExpression->structInit.type->typeTag)); + + LLVMValueRef structPointer = + LLVMBuildAlloca(builder, structType, "structInit"); + + for (i = 0; + i < + structInitExpression->structInit.initFields->structInitFields.count; + i += 1) + { + LLVMValueRef structFieldPointer = FindStructFieldPointer( + builder, + structPointer, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.identifier->identifier.name); + + LLVMBuildStore( + builder, + CompileExpression( + structTypeDeclaration, + selfParam, + builder, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.expression), + structFieldPointer); + } + + return structPointer; +} + static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, LLVMValueRef selfParam, @@ -1619,6 +1680,13 @@ static LLVMValueRef CompileExpression( case StringLiteral: return CompileString(builder, expression); + case StructInit: + return CompileStructInitExpression( + structTypeDeclaration, + selfParam, + builder, + expression); + case SystemCall: return CompileSystemCallExpression( structTypeDeclaration, @@ -1722,7 +1790,21 @@ static LLVMBasicBlockRef CompileAssignment( return LLVMGetLastBasicBlock(function); } - LLVMBuildStore(builder, result, identifier); + if (assignmentStatement->assignmentStatement.right->syntaxKind == + StructInit) + { + LLVMBuildMemCpy( + builder, + identifier, + LLVMGetAlignment(identifier), + result, + LLVMGetAlignment(result), + LLVMSizeOf(LLVMTypeOf(result))); + } + else + { + LLVMBuildStore(builder, result, identifier); + } return LLVMGetLastBasicBlock(function); } @@ -2352,7 +2434,13 @@ static void RegisterLibraryFunctions( LLVMPointerType(LLVMInt64Type(), 0), "src"); - LLVMBuildMemCpy(builder, dest, 8, src, 8, LLVMGetParam(memcpyFunction, 2)); + LLVMBuildMemCpy( + builder, + dest, + LLVMGetAlignment(dest), + src, + LLVMGetAlignment(src), + LLVMGetParam(memcpyFunction, 2)); LLVMBuildRetVoid(builder); -- 2.25.1