From 9adfaed54cd880116acca077135f883a0355e753 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Mon, 7 Jun 2021 18:51:33 +0000 Subject: [PATCH] Generic Structs (#11) Reviewed-on: https://gitea.moonside.games/cosmonaut/wraith-lang/pulls/11 Co-authored-by: cosmonaut Co-committed-by: cosmonaut --- generators/wraith.y | 45 +- generic.w | 38 +- src/ast.c | 221 ++++++++- src/ast.h | 57 ++- src/codegen.c | 1069 ++++++++++++++++++++++++++++++------------- src/util.c | 11 +- src/util.h | 1 + 7 files changed, 1105 insertions(+), 337 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index 803d922..d9a637f 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -113,9 +113,13 @@ BaseType : VOID { $$ = MakePrimitiveTypeNode(MemoryAddress); } + | Identifier GenericArgumentClauseNonEmpty + { + $$ = MakeConcreteGenericTypeNode($1, $2); + } | Identifier { - $$ = MakeCustomTypeNode(yytext); + $$ = MakeCustomTypeNode($1); } | REFERENCE LESS_THAN Type GREATER_THAN { @@ -157,6 +161,30 @@ Number : NUMBER $$ = MakeNumberNode(yytext); } +FieldInit : Identifier COLON Expression + { + $$ = MakeFieldInitNode($1, $3); + } + +StructInitFields : FieldInit + { + $$ = StartStructInitFieldsNode($1); + } + | StructInitFields COMMA FieldInit + { + $$ = AddFieldInitNode($1, $3); + } + | + { + $$ = MakeEmptyFieldInitNode(); + } + ; + +StructInitExpression : Type LEFT_BRACE StructInitFields RIGHT_BRACE + { + $$ = MakeStructInitExpressionNode($1, $3); + } + PrimaryExpression : Number | STRING_LITERAL { @@ -168,6 +196,7 @@ PrimaryExpression : Number } | FunctionCallExpression | AccessExpression + | StructInitExpression ; UnaryExpression : BANG Expression @@ -290,11 +319,11 @@ Statements : Statement $$ = AddStatement($1, $2); } -Arguments : PrimaryExpression +Arguments : Expression { $$ = StartFunctionArgumentSequenceNode($1); } - | Arguments COMMA PrimaryExpression + | Arguments COMMA Expression { $$ = AddFunctionArgumentNode($1, $3); } @@ -359,11 +388,13 @@ GenericArguments : 
GenericArgument $$ = AddGenericArgument($1, $3); } +GenericArgumentClauseNonEmpty : LESS_THAN GenericArguments GREATER_THAN + { + $$ = $2; + } + ; -GenericArgumentClause : LESS_THAN GenericArguments GREATER_THAN - { - $$ = $2; - } +GenericArgumentClause : GenericArgumentClauseNonEmpty | { $$ = MakeEmptyGenericArgumentsNode(); diff --git a/generic.w b/generic.w index b7f6013..b85d7e8 100644 --- a/generic.w +++ b/generic.w @@ -14,9 +14,24 @@ struct MemoryBlock start: MemoryAddress; capacity: uint; - AddressOf(count: uint): MemoryAddress + AddressOf(index: uint): MemoryAddress { - return start + (count * @sizeof()); + return start + (index * @sizeof()); + } + + Get(index: uint): T + { + return @dereference(AddressOf(index)); + } + + Set(index: uint, value: T): void + { + @memcpy(AddressOf(index), @addr(value), @sizeof()); + } + + Free(): void + { + @free(start); } } @@ -24,8 +39,21 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - addr: MemoryAddress = @malloc(y); - @free(addr); - return x; + block: MemoryBlock = MemoryBlock + { + capacity: y, + start: @malloc(y * @sizeof()) + }; + block.Set(0, 5); + block.Set(1, 3); + block.Set(2, 9); + block.Set(3, 100); + Console.PrintLine("%p", block.start); + Console.PrintLine("%i", block.Get(0)); + Console.PrintLine("%i", block.Get(1)); + Console.PrintLine("%i", block.Get(2)); + Console.PrintLine("%i", block.Get(3)); + block.Free(); + return 0; } } diff --git a/src/ast.c b/src/ast.c index dc014bb..a4cd091 100644 --- a/src/ast.c +++ b/src/ast.c @@ -19,6 +19,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "BinaryExpression"; case Comment: return "Comment"; + case ConcreteGenericTypeNode: + return "ConcreteGenericTypeNode"; case CustomTypeNode: return "CustomTypeNode"; case Declaration: @@ -27,6 +29,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "ForLoop"; case DeclarationSequence: return "DeclarationSequence"; + case FieldInit: + return "FieldInit"; case 
FunctionArgumentSequence: return "FunctionArgumentSequence"; case FunctionCallExpression: @@ -71,6 +75,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "StringLiteral"; case StructDeclaration: return "StructDeclaration"; + case StructInit: + return "StructInit"; + case StructInitFields: + return "StructInitFields"; case SystemCall: return "SystemCall"; case Type: @@ -95,11 +103,12 @@ Node *MakePrimitiveTypeNode(PrimitiveType type) return node; } -Node *MakeCustomTypeNode(char *name) +Node *MakeCustomTypeNode(Node *identifierNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = CustomTypeNode; - node->customType.name = strdup(name); + node->customType.name = strdup(identifierNode->identifier.name); + free(identifierNode); return node; } @@ -111,6 +120,18 @@ Node *MakeReferenceTypeNode(Node *typeNode) return node; } +Node *MakeConcreteGenericTypeNode( + Node *identifierNode, + Node *genericArgumentsNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = ConcreteGenericTypeNode; + node->concreteGenericType.name = strdup(identifierNode->identifier.name); + node->concreteGenericType.genericArguments = genericArgumentsNode; + free(identifierNode); + return node; +} + Node *MakeTypeNode(Node *typeNode) { Node *node = (Node *)malloc(sizeof(Node)); @@ -542,6 +563,55 @@ Node *MakeForLoopNode( return node; } +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = FieldInit; + node->fieldInit.identifier = identifierNode; + node->fieldInit.expression = expressionNode; + return node; +} + +Node *StartStructInitFieldsNode(Node *fieldInitNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = (Node **)malloc(sizeof(Node *)); + node->structInitFields.fieldInits[0] = fieldInitNode; + node->structInitFields.count = 1; + return node; +} + +Node *AddFieldInitNode(Node 
*structInitFieldsNode, Node *fieldInitNode) +{ + structInitFieldsNode->structInitFields.fieldInits = realloc( + structInitFieldsNode->structInitFields.fieldInits, + sizeof(Node *) * (structInitFieldsNode->structInitFields.count + 1)); + structInitFieldsNode->structInitFields + .fieldInits[structInitFieldsNode->structInitFields.count] = + fieldInitNode; + structInitFieldsNode->structInitFields.count += 1; + return structInitFieldsNode; +} + +Node *MakeEmptyFieldInitNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = NULL; + node->structInitFields.count = 0; + return node; +} + +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInit; + node->structInit.type = typeNode; + node->structInit.initFields = structInitFieldsNode; + return node; +} + static const char *PrimitiveTypeToString(PrimitiveType type) { switch (type) @@ -624,6 +694,11 @@ void PrintNode(Node *node, uint32_t tabCount) PrintNode(node->binaryExpression.right, tabCount + 1); return; + case ConcreteGenericTypeNode: + printf("%s\n", node->concreteGenericType.name); + PrintNode(node->concreteGenericType.genericArguments, tabCount + 1); + return; + case CustomTypeNode: printf("%s\n", node->customType.name); return; @@ -642,6 +717,12 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case FieldInit: + printf("\n"); + PrintNode(node->fieldInit.identifier, tabCount + 1); + PrintNode(node->fieldInit.expression, tabCount + 1); + return; + case ForLoop: printf("\n"); PrintNode(node->forLoop.declaration, tabCount + 1); @@ -797,6 +878,20 @@ void PrintNode(Node *node, uint32_t tabCount) PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); return; + case StructInit: + printf("\n"); + PrintNode(node->structInit.type, tabCount + 1); + PrintNode(node->structInit.initFields, tabCount + 1); + return; + + case 
StructInitFields: + printf("\n"); + for (i = 0; i < node->structInitFields.count; i += 1) + { + PrintNode(node->structInitFields.fieldInits[i], tabCount + 1); + } + return; + case SystemCall: printf("\n"); PrintNode(node->systemCall.identifier, tabCount + 1); @@ -843,6 +938,10 @@ void Recurse(Node *node, void (*func)(Node *)) case Comment: return; + case ConcreteGenericTypeNode: + func(node->concreteGenericType.genericArguments); + return; + case CustomTypeNode: return; @@ -858,6 +957,11 @@ void Recurse(Node *node, void (*func)(Node *)) } return; + case FieldInit: + func(node->fieldInit.identifier); + func(node->fieldInit.expression); + return; + case ForLoop: func(node->forLoop.declaration); func(node->forLoop.startNumber); @@ -979,6 +1083,18 @@ void Recurse(Node *node, void (*func)(Node *)) func(node->structDeclaration.declarationSequence); return; + case StructInit: + func(node->structInit.type); + func(node->structInit.initFields); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + func(node->structInitFields.fieldInits[i]); + } + return; + case SystemCall: func(node->systemCall.identifier); func(node->systemCall.argumentSequence); @@ -1004,6 +1120,8 @@ void Recurse(Node *node, void (*func)(Node *)) TypeTag *MakeTypeTag(Node *node) { + uint32_t i; + if (node == NULL) { fprintf( @@ -1034,6 +1152,28 @@ TypeTag *MakeTypeTag(Node *node) tag->value.customType = strdup(node->customType.name); break; + case ConcreteGenericTypeNode: + tag->type = ConcreteGeneric; + tag->value.concreteGenericType.name = + strdup(node->concreteGenericType.name); + tag->value.concreteGenericType.genericArgumentCount = + node->concreteGenericType.genericArguments->genericArguments.count; + tag->value.concreteGenericType.genericArguments = malloc( + sizeof(TypeTag *) * + tag->value.concreteGenericType.genericArgumentCount); + + for (i = 0; + i < + node->concreteGenericType.genericArguments->genericArguments.count; + i += 1) + { + 
tag->value.concreteGenericType.genericArguments[i] = MakeTypeTag( + node->concreteGenericType.genericArguments->genericArguments + .arguments[i] + ->genericArgument.type); + } + break; + case Declaration: tag = MakeTypeTag(node->declaration.type); break; @@ -1078,6 +1218,8 @@ TypeTag *MakeTypeTag(Node *node) char *TypeTagToString(TypeTag *tag) { + uint32_t i; + if (tag == NULL) { fprintf( @@ -1114,6 +1256,64 @@ char *TypeTagToString(TypeTag *tag) sprintf(result, "Generic<%s>", tag->value.genericType); return result; } + + case ConcreteGeneric: + { + char *result = strdup(tag->value.concreteGenericType.name); + uint32_t len = strlen(result); + len += 2; + result = realloc(result, sizeof(char) * len); + strcat(result, "<"); + + for (i = 0; i < tag->value.concreteGenericType.genericArgumentCount; + i += 1) + { + char *inner = TypeTagToString( + tag->value.concreteGenericType.genericArguments[i]); + len += strlen(inner); + result = realloc(result, sizeof(char) * (len + 3)); + if (i != tag->value.concreteGenericType.genericArgumentCount - 1) + { + strcat(result, ", "); + } + strcat(result, inner); + } + result = realloc(result, sizeof(char) * (len + 1)); + strcat(result, ">"); + return result; + } + } +} + +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB) +{ + if (typeTagA->type != typeTagB->type) + { + return 0; + } + + switch (typeTagA->type) + { + case Primitive: + return typeTagA->value.primitiveType == typeTagB->value.primitiveType; + + case Reference: + return TypeTagEqual( + typeTagA->value.referenceType, + typeTagB->value.referenceType); + + case Custom: + return strcmp(typeTagA->value.customType, typeTagB->value.customType) == + 0; + + case Generic: + return strcmp( + typeTagA->value.genericType, + typeTagB->value.genericType) == 0; + + default: + fprintf(stderr, "Invalid type comparison!"); + return 0; } } @@ -1164,6 +1364,11 @@ void LinkParentPointers(Node *node, Node *prev) } return; + case FieldInit: + 
LinkParentPointers(node->fieldInit.identifier, node); + LinkParentPointers(node->fieldInit.expression, node); + return; + case ForLoop: LinkParentPointers(node->forLoop.declaration, node); LinkParentPointers(node->forLoop.startNumber, node); @@ -1288,6 +1493,18 @@ void LinkParentPointers(Node *node, Node *prev) LinkParentPointers(node->structDeclaration.declarationSequence, node); return; + case StructInit: + LinkParentPointers(node->structInit.type, node); + LinkParentPointers(node->structInit.initFields, node); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + LinkParentPointers(node->structInitFields.fieldInits[i], node); + } + return; + case SystemCall: LinkParentPointers(node->systemCall.identifier, node); LinkParentPointers(node->systemCall.argumentSequence, node); diff --git a/src/ast.h b/src/ast.h index c7fd974..3fdedf9 100644 --- a/src/ast.h +++ b/src/ast.h @@ -19,9 +19,11 @@ typedef enum Assignment, BinaryExpression, Comment, + ConcreteGenericTypeNode, CustomTypeNode, Declaration, DeclarationSequence, + FieldInit, ForLoop, FunctionArgumentSequence, FunctionCallExpression, @@ -46,6 +48,8 @@ typedef enum StaticModifier, StringLiteral, StructDeclaration, + StructInit, + StructInitFields, SystemCall, Type, UnaryExpression @@ -86,7 +90,16 @@ typedef union BinaryOperator binaryOperator; } Operator; -typedef struct TypeTag +typedef struct TypeTag TypeTag; + +typedef struct ConcreteGenericTypeTag +{ + char *name; + TypeTag **genericArguments; + uint32_t genericArgumentCount; +} ConcreteGenericTypeTag; + +struct TypeTag { enum Type { @@ -94,7 +107,8 @@ typedef struct TypeTag Primitive, Reference, Custom, - Generic + Generic, + ConcreteGeneric } type; union { @@ -106,8 +120,10 @@ typedef struct TypeTag char *customType; /* Valid when type = Generic. 
*/ char *genericType; + /* Valid when type = ConcreteGeneric */ + ConcreteGenericTypeTag concreteGenericType; } value; -} TypeTag; +}; typedef struct Node Node; @@ -146,6 +162,12 @@ struct Node } comment; + struct + { + char *name; + Node *genericArguments; + } concreteGenericType; + struct { char *name; @@ -163,6 +185,12 @@ struct Node uint32_t count; } declarationSequence; + struct + { + Node *identifier; + Node *expression; + } fieldInit; + struct { Node *declaration; @@ -304,6 +332,18 @@ struct Node Node *genericDeclarations; } structDeclaration; + struct + { + Node *type; + Node *initFields; + } structInit; + + struct + { + Node **fieldInits; + uint32_t count; + } structInitFields; + struct { Node *identifier; @@ -329,8 +369,11 @@ const char *SyntaxKindString(SyntaxKind syntaxKind); uint8_t IsPrimitiveType(Node *typeNode); Node *MakePrimitiveTypeNode(PrimitiveType type); -Node *MakeCustomTypeNode(char *string); +Node *MakeCustomTypeNode(Node *identifierNode); Node *MakeReferenceTypeNode(Node *typeNode); +Node *MakeConcreteGenericTypeNode( + Node *identifierNode, + Node *genericArgumentsNode); Node *MakeTypeNode(Node *typeNode); Node *MakeIdentifierNode(const char *id); Node *MakeNumberNode(const char *numberString); @@ -397,6 +440,11 @@ Node *MakeForLoopNode( Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode); +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode); +Node *StartStructInitFieldsNode(Node *fieldInitNode); +Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode); +Node *MakeEmptyFieldInitNode(); +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode); void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); @@ -412,6 +460,7 @@ void LinkParentPointers(Node *node, Node *prev); TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB); Node *LookupIdNode(Node 
*current, Node *prev, char *target); diff --git a/src/codegen.c b/src/codegen.c index 2e67b17..b2421f9 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -83,10 +83,11 @@ typedef struct MonomorphizedGenericFunctionHashArray #define NUM_MONOMORPHIZED_HASH_BUCKETS 1031 +typedef struct StructTypeDeclaration StructTypeDeclaration; + typedef struct StructTypeGenericFunction { - char *parentStructName; - LLVMTypeRef parentStructPointerType; + StructTypeDeclaration *parentStruct; char *name; Node *functionDeclarationNode; uint8_t isStatic; @@ -94,8 +95,9 @@ typedef struct StructTypeGenericFunction monomorphizedFunctions[NUM_MONOMORPHIZED_HASH_BUCKETS]; } StructTypeGenericFunction; -typedef struct StructTypeDeclaration +struct StructTypeDeclaration { + LLVMModuleRef module; char *name; LLVMTypeRef structType; LLVMTypeRef structPointerType; @@ -107,7 +109,7 @@ typedef struct StructTypeDeclaration StructTypeGenericFunction *genericFunctions; uint32_t genericFunctionCount; -} StructTypeDeclaration; +}; StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; @@ -117,7 +119,7 @@ typedef struct MonomorphizedGenericStructHashEntry uint64_t key; TypeTag **types; uint32_t typeCount; - StructTypeDeclaration structDeclaration; + StructTypeDeclaration *structDeclaration; } MonomorphizedGenericStructHashEntry; typedef struct MonomorphizedGenericStructHashArray @@ -128,6 +130,7 @@ typedef struct MonomorphizedGenericStructHashArray typedef struct GenericStructTypeDeclaration { + LLVMModuleRef module; Node *structDeclarationNode; MonomorphizedGenericStructHashArray monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS]; @@ -148,16 +151,24 @@ uint32_t systemFunctionCount; /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMValueRef CompileExpression( - LLVMModuleRef 
module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression); +static void CompileFunction( + StructTypeDeclaration *structTypeDeclaration, + Node *functionDeclaration); + +static LLVMTypeRef ResolveType(TypeTag *typeTag); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -215,6 +226,122 @@ static void PopScopeFrame(Scope *scope) realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } +static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) +{ + const uint64_t HASH_FACTOR = 97; + uint64_t result = 1; + uint32_t i; + + for (i = 0; i < count; i += 1) + { + result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + } + + return result; +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].concreteTypeTag = typeTag; + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + 
StructTypeDeclaration *structTypeDeclaration, + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclaration->fieldCount; i += 1) + { + char *ptrName = strdup(structTypeDeclaration->fields[i].name); + ptrName = w_strcat(ptrName, "_ptr"); + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclaration->fields[i].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclaration->fields[i].name); + } +} + +static void AddFieldToStructDeclaration( + StructTypeDeclaration *structTypeDeclaration, + char *name) +{ + structTypeDeclaration->fields = realloc( + structTypeDeclaration->fields, + sizeof(StructTypeField) * (structTypeDeclaration->fieldCount + 1)); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].name = + strdup(name); + structTypeDeclaration->fields[structTypeDeclaration->fieldCount].index = + structTypeDeclaration->fieldCount; + structTypeDeclaration->fieldCount += 1; +} + +static void AddGenericStructDeclaration( + LLVMModuleRef module, + Node *structDeclarationNode) +{ + uint32_t i; + + genericStructTypeDeclarations = realloc( + genericStructTypeDeclarations, + sizeof(GenericStructTypeDeclaration) * + (genericStructTypeDeclarationCount + 1)); + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .structDeclarationNode = structDeclarationNode; + genericStructTypeDeclarations[genericStructTypeDeclarationCount].module = + module; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) + { + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .elements = NULL; + genericStructTypeDeclarations[genericStructTypeDeclarationCount] + .monomorphizedStructs[i] + .count = 0; + } + + genericStructTypeDeclarationCount += 1; +} + static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -284,6 +411,215 @@ static LLVMTypeRef 
LookupCustomType(char *name) return NULL; } +static StructTypeDeclaration *AddStructDeclaration( + LLVMModuleRef module, + LLVMTypeRef wStructType, + LLVMTypeRef wStructPointerType, + char *name) +{ + uint32_t index = structTypeDeclarationCount; + structTypeDeclarations = realloc( + structTypeDeclarations, + sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); + structTypeDeclarations[index].module = module; + structTypeDeclarations[index].structType = wStructType; + structTypeDeclarations[index].structPointerType = wStructPointerType; + structTypeDeclarations[index].name = strdup(name); + structTypeDeclarations[index].fields = NULL; + structTypeDeclarations[index].fieldCount = 0; + structTypeDeclarations[index].functions = NULL; + structTypeDeclarations[index].functionCount = 0; + structTypeDeclarations[index].genericFunctions = NULL; + structTypeDeclarations[index].genericFunctionCount = 0; + + structTypeDeclarationCount += 1; + + return &structTypeDeclarations[index]; +} + +static StructTypeDeclaration *CompileMonomorphizedGenericStruct( + GenericStructTypeDeclaration *genericStructTypeDeclaration, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount) +{ + uint32_t i = 0; + uint32_t fieldCount = 0; + Node *structDeclarationNode = + genericStructTypeDeclaration->structDeclarationNode; + uint32_t declarationCount = + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.count; + LLVMTypeRef types[declarationCount]; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + genericArgumentTypes[i], + structDeclarationNode->structDeclaration.genericDeclarations + ->genericDeclarations.declarations[i] + ->genericDeclaration.identifier->identifier.name); + } + + char *structName = strdup( + structDeclarationNode->structDeclaration.identifier->identifier.name); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + char *inner = 
TypeTagToString(genericArgumentTypes[i]); + structName = w_strcat(structName, "_"); + structName = w_strcat(structName, inner); + } + + LLVMContextRef context = + LLVMGetGlobalContext(); /* FIXME: should we pass a context? */ + LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); + LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); + + StructTypeDeclaration *declaration = AddStructDeclaration( + genericStructTypeDeclaration->module, + wStructType, + wStructPointerType, + structName); + + free(structName); + + /* first build the structure def */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case Declaration: + types[fieldCount] = ResolveType( + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->typeTag); + AddFieldToStructDeclaration( + declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->declaration.identifier->identifier.name); + fieldCount += 1; + break; + } + } + + LLVMStructSetBody(wStructType, types, fieldCount, 1); + + /* now we wire up the functions */ + for (i = 0; i < declarationCount; i += 1) + { + switch (structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i] + ->syntaxKind) + { + case FunctionDeclaration: + CompileFunction( + declaration, + structDeclarationNode->structDeclaration.declarationSequence + ->declarationSequence.sequence[i]); + break; + } + } + + PopScopeFrame(scope); + + return declaration; +} + +static StructTypeDeclaration *LookupGenericStructType( + ConcreteGenericTypeTag *typeTag) +{ + uint32_t i, j, k; + uint64_t typeHash; + uint8_t match; + TypeTag *genericTypeTags[typeTag->genericArgumentCount]; + + for (i = 0; i < typeTag->genericArgumentCount; i += 1) + { + genericTypeTags[i] = 
ConcretizeType(typeTag->genericArguments[i]); + } + + for (i = 0; i < genericStructTypeDeclarationCount; i += 1) + { + if (strcmp( + genericStructTypeDeclarations[i] + .structDeclarationNode->structDeclaration.identifier + ->identifier.name, + typeTag->name) == 0) + { + typeHash = + HashTypeTags(genericTypeTags, typeTag->genericArgumentCount); + + MonomorphizedGenericStructHashArray *hashArray = + &genericStructTypeDeclarations[i].monomorphizedStructs + [typeHash % NUM_MONOMORPHIZED_HASH_BUCKETS]; + + MonomorphizedGenericStructHashEntry *hashEntry = NULL; + + for (j = 0; j < hashArray->count; j += 1) + { + match = 1; + + for (k = 0; k < hashArray->elements[j].typeCount; k += 1) + { + if (!TypeTagEqual( + hashArray->elements[j].types[k], + genericTypeTags[k])) + { + match = 0; + break; + } + } + + if (match) + { + hashEntry = &hashArray->elements[i]; + break; + } + } + + if (hashEntry == NULL) + { + StructTypeDeclaration *structTypeDeclaration = + CompileMonomorphizedGenericStruct( + &genericStructTypeDeclarations[i], + genericTypeTags, + typeTag->genericArgumentCount); + + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericStructHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * typeTag->genericArgumentCount); + hashArray->elements[hashArray->count].typeCount = + typeTag->genericArgumentCount; + hashArray->elements[hashArray->count].structDeclaration = + structTypeDeclaration; + for (j = 0; j < typeTag->genericArgumentCount; j += 1) + { + hashArray->elements[hashArray->count].types[j] = + genericTypeTags[j]; + } + hashArray->count += 1; + + hashEntry = &hashArray->elements[hashArray->count - 1]; + } + + return hashEntry->structDeclaration; + } + } + + fprintf(stderr, "Could not find generic struct declaration!"); + return NULL; +} + static LLVMTypeRef ResolveType(TypeTag *typeTag) { if (typeTag->type == Primitive) 
@@ -302,6 +638,11 @@ static LLVMTypeRef ResolveType(TypeTag *typeTag) { return LookupGenericType(typeTag->value.genericType)->type; } + else if (typeTag->type == ConcreteGeneric) + { + return LookupGenericStructType(&typeTag->value.concreteGenericType) + ->structType; + } else { fprintf(stderr, "Unknown type node!\n"); @@ -340,74 +681,8 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression) return NULL; } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - -static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->genericTypeCount; - - scopeFrame->genericTypes = realloc( - scopeFrame->genericTypes, - sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); - scopeFrame->genericTypes[index].name = strdup(name); - scopeFrame->genericTypes[index].concreteTypeTag = typeTag; - scopeFrame->genericTypes[index].type = ResolveType(typeTag); - - scopeFrame->genericTypeCount += 1; -} - -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - 
strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - -static LLVMTypeRef FindStructType(char *name) +/* FIXME: this is awkward, should just resolve the type */ +static LLVMTypeRef LookupStructTypeByName(char *name) { uint32_t i; @@ -422,6 +697,22 @@ static LLVMTypeRef FindStructType(char *name) return NULL; } +static StructTypeDeclaration *LookupStructDeclaration(LLVMTypeRef structType) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structType == structType) + { + return &structTypeDeclarations[i]; + } + } + + fprintf(stderr, "Struct type not found!"); + return NULL; +} + static LLVMValueRef FindStructFieldPointer( LLVMBuilderRef builder, LLVMValueRef structPointer, @@ -440,7 +731,7 @@ static LLVMValueRef FindStructFieldPointer( if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0) { char *ptrName = strdup(name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); return LLVMBuildStructGEP( builder, structPointer, @@ -504,167 +795,62 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) return NULL; } -static void AddStructDeclaration( - LLVMTypeRef wStructType, - LLVMTypeRef wStructPointerType, - char *name, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount) -{ - uint32_t i; - uint32_t index = structTypeDeclarationCount; - structTypeDeclarations = realloc( - structTypeDeclarations, - sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); - structTypeDeclarations[index].structType = wStructType; - structTypeDeclarations[index].structPointerType = wStructPointerType; - structTypeDeclarations[index].name = strdup(name); - 
structTypeDeclarations[index].fields = NULL; - structTypeDeclarations[index].fieldCount = 0; - structTypeDeclarations[index].functions = NULL; - structTypeDeclarations[index].functionCount = 0; - structTypeDeclarations[index].genericFunctions = NULL; - structTypeDeclarations[index].genericFunctionCount = 0; - - for (i = 0; i < fieldDeclarationCount; i += 1) - { - structTypeDeclarations[index].fields = realloc( - structTypeDeclarations[index].fields, - sizeof(StructTypeField) * - (structTypeDeclarations[index].fieldCount + 1)); - structTypeDeclarations[index].fields[i].name = strdup( - fieldDeclarations[i]->declaration.identifier->identifier.name); - structTypeDeclarations[index].fields[i].index = i; - structTypeDeclarations[index].fieldCount += 1; - } - - structTypeDeclarationCount += 1; -} - -static void AddGenericStructDeclaration(Node *structDeclarationNode) -{ - uint32_t i; - - genericStructTypeDeclarations = realloc( - genericStructTypeDeclarations, - sizeof(GenericStructTypeDeclaration) * - (genericStructTypeDeclarationCount + 1)); - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .structDeclarationNode = structDeclarationNode; - - for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) - { - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .elements = NULL; - genericStructTypeDeclarations[genericStructTypeDeclarationCount] - .monomorphizedStructs[i] - .count = 0; - } - - genericStructTypeDeclarationCount += 1; -} - -/* FIXME: pass the declaration itself */ static void DeclareStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, LLVMValueRef function, LLVMTypeRef returnType, uint8_t isStatic, char *name) { - uint32_t i, index; + uint32_t index = structTypeDeclaration->functionCount; - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = 
structTypeDeclarations[i].functionCount; - structTypeDeclarations[i].functions = realloc( - structTypeDeclarations[i].functions, - sizeof(StructTypeFunction) * - (structTypeDeclarations[i].functionCount + 1)); - structTypeDeclarations[i].functions[index].name = strdup(name); - structTypeDeclarations[i].functions[index].function = function; - structTypeDeclarations[i].functions[index].returnType = returnType; - structTypeDeclarations[i].functions[index].isStatic = isStatic; - structTypeDeclarations[i].functionCount += 1; - - return; - } - } - - fprintf(stderr, "Could not find struct type for function!\n"); + structTypeDeclaration->functions = realloc( + structTypeDeclaration->functions, + sizeof(StructTypeFunction) * + (structTypeDeclaration->functionCount + 1)); + structTypeDeclaration->functions[index].name = strdup(name); + structTypeDeclaration->functions[index].function = function; + structTypeDeclaration->functions[index].returnType = returnType; + structTypeDeclaration->functions[index].isStatic = isStatic; + structTypeDeclaration->functionCount += 1; } -/* FIXME: pass the declaration itself */ static void DeclareGenericStructFunction( - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclarationNode, uint8_t isStatic, - char *parentStructName, char *name) { - uint32_t i, j, index; + uint32_t i, index; - for (i = 0; i < structTypeDeclarationCount; i += 1) + index = structTypeDeclaration->genericFunctionCount; + structTypeDeclaration->genericFunctions = realloc( + structTypeDeclaration->genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclaration->genericFunctionCount + 1)); + structTypeDeclaration->genericFunctions[index].name = strdup(name); + structTypeDeclaration->genericFunctions[index].parentStruct = + structTypeDeclaration; + structTypeDeclaration->genericFunctions[index].functionDeclarationNode = + functionDeclarationNode; + structTypeDeclaration->genericFunctions[index].isStatic = 
isStatic; + + for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1) { - if (structTypeDeclarations[i].structPointerType == wStructPointerType) - { - index = structTypeDeclarations[i].genericFunctionCount; - structTypeDeclarations[i].genericFunctions = realloc( - structTypeDeclarations[i].genericFunctions, - sizeof(StructTypeGenericFunction) * - (structTypeDeclarations[i].genericFunctionCount + 1)); - structTypeDeclarations[i].genericFunctions[index].name = - strdup(name); - structTypeDeclarations[i].genericFunctions[index].parentStructName = - parentStructName; - structTypeDeclarations[i].structPointerType = wStructPointerType; - structTypeDeclarations[i] - .genericFunctions[index] - .functionDeclarationNode = functionDeclarationNode; - structTypeDeclarations[i].genericFunctions[index].isStatic = - isStatic; - - for (j = 0; j < NUM_MONOMORPHIZED_HASH_BUCKETS; j += 1) - { - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .elements = NULL; - structTypeDeclarations[i] - .genericFunctions[index] - .monomorphizedFunctions[j] - .count = 0; - } - - structTypeDeclarations[i].genericFunctionCount += 1; - - return; - } - } -} - -static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) -{ - const uint64_t HASH_FACTOR = 97; - uint64_t result = 1; - uint32_t i; - - for (i = 0; i < count; i += 1) - { - result *= HASH_FACTOR + str_hash(TypeTagToString(tags[i])); + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .elements = NULL; + structTypeDeclaration->genericFunctions[index] + .monomorphizedFunctions[i] + .count = 0; } - return result; + structTypeDeclaration->genericFunctionCount += 1; } /* FIXME: lots of duplication with non-generic function compile */ static StructTypeFunction CompileGenericFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, TypeTag **resolvedGenericArgumentTypes, uint32_t 
genericArgumentTypeCount, Node *functionDeclaration) @@ -680,6 +866,7 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; LLVMTypeRef returnType; + LLVMValueRef wStructPointer = NULL; PushScopeFrame(scope); @@ -711,20 +898,22 @@ static StructTypeFunction CompileGenericFunction( } } - char *functionName = strdup(parentStructName); - strcat(functionName, "_"); - strcat( + char *functionName = strdup(structTypeDeclaration->name); + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); for (i = 0; i < genericArgumentTypeCount; i += 1) { - strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); + functionName = w_strcat( + functionName, + TypeTagToString(resolvedGenericArgumentTypes[i])); } if (!isStatic) { - paramTypes[paramIndex] = wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -746,7 +935,10 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); @@ -754,8 +946,11 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -765,7 +960,7 @@ static StructTypeFunction CompileGenericFunction( char *ptrName = 
strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -783,7 +978,8 @@ static StructTypeFunction CompileGenericFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); @@ -816,7 +1012,6 @@ static StructTypeFunction CompileGenericFunction( } static LLVMValueRef LookupGenericFunction( - LLVMModuleRef module, StructTypeGenericFunction *genericFunction, Node *functionCallExpression, LLVMTypeRef *pReturnType, @@ -919,8 +1114,9 @@ static LLVMValueRef LookupGenericFunction( for (j = 0; j < hashArray->elements[i].typeCount; j += 1) { - if (hashArray->elements[i].types[j] != - resolvedGenericArgumentTypes[j]) + if (!TypeTagEqual( + hashArray->elements[i].types[j], + resolvedGenericArgumentTypes[j])) { match = 0; break; @@ -937,14 +1133,11 @@ static LLVMValueRef LookupGenericFunction( if (hashEntry == NULL) { StructTypeFunction function = CompileGenericFunction( - module, - genericFunction->parentStructName, - genericFunction->parentStructPointerType, + genericFunction->parentStruct, resolvedGenericArgumentTypes, genericArgumentTypeCount, genericFunction->functionDeclarationNode); - /* TODO: add to hash */ hashArray->elements = realloc( hashArray->elements, sizeof(MonomorphizedGenericFunctionHashEntry) * @@ -972,16 +1165,13 @@ static LLVMValueRef LookupGenericFunction( } static LLVMValueRef LookupFunctionByType( - LLVMModuleRef module, LLVMTypeRef structType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = 
functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1007,7 +1197,6 @@ static LLVMValueRef LookupFunctionByType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1022,16 +1211,13 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( - LLVMModuleRef module, LLVMTypeRef structPointerType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1057,7 +1243,6 @@ static LLVMValueRef LookupFunctionByPointerType( name) == 0) { return LookupGenericFunction( - module, &structTypeDeclarations[i].genericFunctions[j], functionCallExpression, pReturnType, @@ -1072,15 +1257,15 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( - LLVMModuleRef module, LLVMValueRef structPointer, + char *functionName, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( - module, LLVMTypeOf(structPointer), + functionName, functionCallExpression, pReturnType, pStatic); @@ -1102,17 +1287,20 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *binaryExpression) { LLVMValueRef left = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.left); LLVMValueRef right = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.right); @@ 
-1159,7 +1347,8 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -1171,7 +1360,7 @@ static LLVMValueRef CompileFunctionCallExpression( 1]; LLVMValueRef function; uint8_t isStatic; - LLVMValueRef structInstance; + LLVMValueRef structInstance = NULL; LLVMTypeRef functionReturnType; char *returnName = ""; @@ -1181,15 +1370,19 @@ static LLVMValueRef CompileFunctionCallExpression( if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { - LLVMTypeRef typeReference = FindStructType( + LLVMTypeRef typeReference = LookupStructTypeByName( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + char *functionName = + functionCallExpression->functionCallExpression.identifier + ->accessExpression.accessor->identifier.name; + if (typeReference != NULL) { function = LookupFunctionByType( - module, typeReference, + functionName, functionCallExpression, &functionReturnType, &isStatic); @@ -1199,17 +1392,38 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance = FindVariablePointer( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + function = LookupFunctionByInstance( - module, structInstance, + functionName, functionCallExpression, &functionReturnType, &isStatic); } } + else if ( + functionCallExpression->functionCallExpression.identifier->syntaxKind == + Identifier) + { + LLVMTypeRef structType = structTypeDeclaration->structType; + char *functionName = functionCallExpression->functionCallExpression + .identifier->identifier.name; + + function = LookupFunctionByType( + structType, + functionName, + functionCallExpression, + &functionReturnType, + &isStatic); + + if (!isStatic) + { + 
structInstance = selfParam; + } + } else { - fprintf(stderr, "Failed to find function!\n"); + fprintf(stderr, "Function identifier syntax kind not recognized!\n"); return NULL; } @@ -1224,7 +1438,8 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1240,7 +1455,8 @@ static LLVMValueRef CompileFunctionCallExpression( } static LLVMValueRef CompileSystemCallExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *systemCallExpression) { @@ -1255,7 +1471,8 @@ static LLVMValueRef CompileSystemCallExpression( i += 1) { args[i] = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, systemCallExpression->systemCall.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1277,6 +1494,57 @@ static LLVMValueRef CompileSystemCallExpression( return LLVMSizeOf(ResolveType(ConcretizeType(typeTag))); } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "addr") == 0) + { + return LLVMBuildPtrToInt( + builder, + FindVariablePointer( + systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.sequence[0] + ->identifier.name), + LLVMInt64Type(), + "addrResult"); + } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "dereference") == 0) + { + TypeTag *typeTag = ConcretizeType( + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag); + + return LLVMBuildLoad( + builder, + LLVMBuildIntToPtr( + builder, + args[0], + LLVMPointerType(ResolveType(typeTag), 0), + "deref_ptr"), + "deref"); + } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "bitcast") == 0) + { + TypeTag 
*typeTag = ConcretizeType( + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag); + + LLVMValueRef expression = args[0]; + + return LLVMBuildBitCast( + builder, + expression, + ResolveType(typeTag), + "castResult"); + } else { fprintf(stderr, "System function not found!"); @@ -1334,8 +1602,50 @@ static LLVMValueRef CompileAllocExpression( return LLVMBuildMalloc(builder, type, "allocation"); } +static LLVMValueRef CompileStructInitExpression( + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, + LLVMBuilderRef builder, + Node *structInitExpression) +{ + uint32_t i = 0; + + LLVMTypeRef structType = ResolveType( + ConcretizeType(structInitExpression->structInit.type->typeTag)); + + LLVMValueRef structPointer = + LLVMBuildAlloca(builder, structType, "structInit"); + + for (i = 0; + i < + structInitExpression->structInit.initFields->structInitFields.count; + i += 1) + { + LLVMValueRef structFieldPointer = FindStructFieldPointer( + builder, + structPointer, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.identifier->identifier.name); + + LLVMBuildStore( + builder, + CompileExpression( + structTypeDeclaration, + selfParam, + builder, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.expression), + structFieldPointer); + } + + return structPointer; +} + static LLVMValueRef CompileExpression( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression) { @@ -1348,10 +1658,18 @@ static LLVMValueRef CompileExpression( return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(module, builder, expression); + return CompileBinaryExpression( + structTypeDeclaration, + selfParam, + builder, + expression); case FunctionCallExpression: - return CompileFunctionCallExpression(module, 
builder, expression); + return CompileFunctionCallExpression( + structTypeDeclaration, + selfParam, + builder, + expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -1362,8 +1680,19 @@ static LLVMValueRef CompileExpression( case StringLiteral: return CompileString(builder, expression); + case StructInit: + return CompileStructInitExpression( + structTypeDeclaration, + selfParam, + builder, + expression); + case SystemCall: - return CompileSystemCallExpression(module, builder, expression); + return CompileSystemCallExpression( + structTypeDeclaration, + selfParam, + builder, + expression); } fprintf(stderr, "Unknown expression kind!\n"); @@ -1371,13 +1700,15 @@ static LLVMValueRef CompileExpression( } static LLVMBasicBlockRef CompileReturn( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -1402,7 +1733,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( char *variableName = variableDeclaration->declaration.identifier->identifier.name; char *ptrName = strdup(variableName); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); variable = LLVMBuildAlloca( builder, @@ -1417,16 +1748,20 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, assignmentStatement->assignmentStatement.right); + LLVMValueRef identifier; + if (assignmentStatement->assignmentStatement.left->syntaxKind == 
AccessExpression) { @@ -1455,20 +1790,38 @@ static LLVMBasicBlockRef CompileAssignment( return LLVMGetLastBasicBlock(function); } - LLVMBuildStore(builder, result, identifier); + if (assignmentStatement->assignmentStatement.right->syntaxKind == + StructInit) + { + LLVMBuildMemCpy( + builder, + identifier, + LLVMGetAlignment(identifier), + result, + LLVMGetAlignment(result), + LLVMSizeOf(LLVMTypeOf(result))); + } + else + { + LLVMBuildStore(builder, result, identifier); + } return LLVMGetLastBasicBlock(function); } static LLVMBasicBlockRef CompileIfStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; - LLVMValueRef conditional = - CompileExpression(module, builder, ifStatement->ifStatement.expression); + LLVMValueRef conditional = CompileExpression( + structTypeDeclaration, + selfParam, + builder, + ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -1483,7 +1836,8 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, + selfParam, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1497,14 +1851,16 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( - module, + structTypeDeclaration, + selfParam, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1521,7 +1877,8 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, + selfParam, builder, function, 
ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1540,7 +1897,8 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( - module, + structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1550,7 +1908,8 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( - module, + structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1563,7 +1922,8 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1623,7 +1983,8 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( - module, + structTypeDeclaration, + selfParam, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1652,7 +2013,8 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( - LLVMModuleRef module, + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, /* can be NULL for statics */ LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1660,33 +2022,66 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(module, builder, function, statement); + return CompileAssignment( + structTypeDeclaration, + selfParam, + builder, + function, + statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(module, builder, function, statement); + return CompileForLoopStatement( + structTypeDeclaration, + selfParam, + builder, + function, + statement); case FunctionCallExpression: - CompileFunctionCallExpression(module, 
builder, statement); + CompileFunctionCallExpression( + structTypeDeclaration, + selfParam, + builder, + statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(module, builder, function, statement); + return CompileIfStatement( + structTypeDeclaration, + selfParam, + builder, + function, + statement); case IfElseStatement: - return CompileIfElseStatement(module, builder, function, statement); + return CompileIfElseStatement( + structTypeDeclaration, + selfParam, + builder, + function, + statement); case Return: - return CompileReturn(module, builder, function, statement); + return CompileReturn( + structTypeDeclaration, + selfParam, + builder, + function, + statement); case ReturnVoid: return CompileReturnVoid(builder, function); case SystemCall: - CompileSystemCallExpression(module, builder, statement); + CompileSystemCallExpression( + structTypeDeclaration, + selfParam, + builder, + statement); return LLVMGetLastBasicBlock(function); } @@ -1695,9 +2090,7 @@ static LLVMBasicBlockRef CompileStatement( } static void CompileFunction( - LLVMModuleRef module, - char *parentStructName, - LLVMTypeRef wStructPointerType, + StructTypeDeclaration *structTypeDeclaration, Node *functionDeclaration) { uint32_t i; @@ -1710,6 +2103,7 @@ static void CompileFunction( ->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; + LLVMValueRef wStructPointer = NULL; if (functionSignature->functionSignature.modifiers->functionModifiers .count > 0) @@ -1728,9 +2122,9 @@ static void CompileFunction( } } - char *functionName = strdup(parentStructName); - strcat(functionName, "_"); - strcat( + char *functionName = strdup(structTypeDeclaration->name); + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); @@ -1741,7 +2135,7 @@ static void CompileFunction( if (!isStatic) { - paramTypes[paramIndex] = 
wStructPointerType; + paramTypes[paramIndex] = structTypeDeclaration->structPointerType; paramIndex += 1; } @@ -1761,11 +2155,13 @@ static void CompileFunction( LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - LLVMValueRef function = - LLVMAddFunction(module, functionName, functionType); + LLVMValueRef function = LLVMAddFunction( + structTypeDeclaration->module, + functionName, + functionType); DeclareStructFunction( - wStructPointerType, + structTypeDeclaration, function, returnType, isStatic, @@ -1777,8 +2173,11 @@ static void CompileFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); + wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope( + structTypeDeclaration, + builder, + wStructPointer); } for (i = 0; i < functionSignature->functionSignature.arguments @@ -1789,7 +2188,7 @@ static void CompileFunction( strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -1807,7 +2206,8 @@ static void CompileFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( - module, + structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); @@ -1832,10 +2232,9 @@ static void CompileFunction( else { DeclareGenericStructFunction( - wStructPointerType, + structTypeDeclaration, functionDeclaration, isStatic, - parentStructName, functionSignature->functionSignature.identifier->identifier.name); } @@ -1854,7 +2253,6 @@ static void CompileStruct( uint8_t packed = 1; LLVMTypeRef types[declarationCount]; Node *currentDeclarationNode; - Node 
*fieldDeclarations[declarationCount]; char *structName = node->structDeclaration.identifier->identifier.name; PushScopeFrame(scope); @@ -1867,6 +2265,12 @@ static void CompileStruct( wStructType, 0); /* FIXME: is this address space correct? */ + StructTypeDeclaration *structTypeDeclaration = AddStructDeclaration( + module, + wStructType, + wStructPointerType, + structName); + /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) { @@ -1876,22 +2280,19 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { - case Declaration: /* this is badly named */ + case Declaration: /* FIXME: this is badly named */ types[fieldCount] = ResolveType( currentDeclarationNode->declaration.identifier->typeTag); - fieldDeclarations[fieldCount] = currentDeclarationNode; + AddFieldToStructDeclaration( + structTypeDeclaration, + currentDeclarationNode->declaration.identifier->identifier + .name); fieldCount += 1; break; } } LLVMStructSetBody(wStructType, types, fieldCount, packed); - AddStructDeclaration( - wStructType, - wStructPointerType, - structName, - fieldDeclarations, - fieldCount); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) @@ -1903,18 +2304,14 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case FunctionDeclaration: - CompileFunction( - module, - structName, - wStructPointerType, - currentDeclarationNode); + CompileFunction(structTypeDeclaration, currentDeclarationNode); break; } } } else { - AddGenericStructDeclaration(node); + AddGenericStructDeclaration(module, node); } PopScopeFrame(scope); @@ -1954,7 +2351,8 @@ static void RegisterLibraryFunctions( { LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console"); LLVMTypeRef structPointerType = LLVMPointerType(structType, 0); - AddStructDeclaration(structType, structPointerType, "Console", NULL, 0); + StructTypeDeclaration *structTypeDeclaration = + AddStructDeclaration(module, structType, 
structPointerType, "Console"); LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0); LLVMTypeRef printfFunctionType = @@ -1989,7 +2387,7 @@ static void RegisterLibraryFunctions( LLVMBuildAnd(builder, stringPrint, newlinePrint, "and")); DeclareStructFunction( - structPointerType, + structTypeDeclaration, printLineFunction, LLVMInt8Type(), 1, @@ -2013,6 +2411,41 @@ static void RegisterLibraryFunctions( AddSystemFunction("free", freeFunctionType, freeFunction); + LLVMTypeRef memcpyParams[3]; + memcpyParams[0] = LLVMInt64Type(); + memcpyParams[1] = LLVMInt64Type(); + memcpyParams[2] = LLVMInt64Type(); + LLVMTypeRef memcpyFunctionType = + LLVMFunctionType(LLVMVoidType(), memcpyParams, 3, 0); + LLVMValueRef memcpyFunction = + LLVMAddFunction(module, "memcopy", memcpyFunctionType); + + LLVMBasicBlockRef memcpyEntry = + LLVMAppendBasicBlock(memcpyFunction, "entry"); + LLVMPositionBuilderAtEnd(builder, memcpyEntry); + LLVMValueRef dest = LLVMBuildIntToPtr( + builder, + LLVMGetParam(memcpyFunction, 0), + LLVMPointerType(LLVMInt64Type(), 0), + "dest"); + LLVMValueRef src = LLVMBuildIntToPtr( + builder, + LLVMGetParam(memcpyFunction, 1), + LLVMPointerType(LLVMInt64Type(), 0), + "src"); + + LLVMBuildMemCpy( + builder, + dest, + LLVMGetAlignment(dest), + src, + LLVMGetAlignment(src), + LLVMGetParam(memcpyFunction, 2)); + + LLVMBuildRetVoid(builder); + + AddSystemFunction("memcpy", memcpyFunctionType, memcpyFunction); + LLVMDisposeBuilder(builder); } diff --git a/src/util.c b/src/util.c index 42911e7..491bb81 100644 --- a/src/util.c +++ b/src/util.c @@ -5,7 +5,7 @@ char *strdup(const char *s) { size_t slen = strlen(s); - char *result = (char *)malloc(slen + 1); + char *result = (char *)malloc(sizeof(char) * (slen + 1)); if (result == NULL) { return NULL; @@ -15,6 +15,22 @@ char *strdup(const char *s) return result; } +/* Appends s2 to the heap-allocated string s, growing the buffer with realloc. + * Returns the grown string, or NULL (matching strdup) when reallocation + * fails; in that case s is untouched and still owned by the caller. */ +char *w_strcat(char *s, char *s2) +{ + size_t slen = strlen(s); + size_t slen2 = strlen(s2); + char *result = (char *)realloc(s, sizeof(char) * (slen + slen2 + 1)); + if (result == NULL) + { + return NULL; + } + strcat(result,
s2); + return result; +} + uint64_t str_hash(char *str) { uint64_t hash = 5381; diff --git a/src/util.h b/src/util.h index 108211b..94a64e2 100644 --- a/src/util.h +++ b/src/util.h @@ -5,6 +5,7 @@ #include char *strdup(const char *s); +char *w_strcat(char *s, char *s2); uint64_t str_hash(char *str); #endif /* WRAITH_UTIL_H */