From e2332349b7ccbd7451319e6e293e5a5807b817af Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Sun, 6 Jun 2021 14:35:54 -0700 Subject: [PATCH] add struct initializer --- generators/wraith.y | 25 ++++++++ generic.w | 11 ++-- src/ast.c | 141 ++++++++++++++++++++++++++++++++++++++++++++ src/ast.h | 27 +++++++++ src/codegen.c | 102 +++++++++++++++++++++++++++++--- 5 files changed, 294 insertions(+), 12 deletions(-) diff --git a/generators/wraith.y b/generators/wraith.y index a3cd8c5..d9a637f 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -161,6 +161,30 @@ Number : NUMBER $$ = MakeNumberNode(yytext); } +FieldInit : Identifier COLON Expression + { + $$ = MakeFieldInitNode($1, $3); + } + +StructInitFields : FieldInit + { + $$ = StartStructInitFieldsNode($1); + } + | StructInitFields COMMA FieldInit + { + $$ = AddFieldInitNode($1, $3); + } + | + { + $$ = MakeEmptyFieldInitNode(); + } + ; + +StructInitExpression : Type LEFT_BRACE StructInitFields RIGHT_BRACE + { + $$ = MakeStructInitExpressionNode($1, $3); + } + PrimaryExpression : Number | STRING_LITERAL { @@ -172,6 +196,7 @@ PrimaryExpression : Number } | FunctionCallExpression | AccessExpression + | StructInitExpression ; UnaryExpression : BANG Expression diff --git a/generic.w b/generic.w index ba8d511..b85d7e8 100644 --- a/generic.w +++ b/generic.w @@ -39,15 +39,16 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - block: MemoryBlock; - block.capacity = y; - block.start = @malloc(y * @sizeof()); - z: MemoryAddress = block.AddressOf(2); - Console.PrintLine("%p", block.start); + block: MemoryBlock = MemoryBlock + { + capacity: y, + start: @malloc(y * @sizeof()) + }; block.Set(0, 5); block.Set(1, 3); block.Set(2, 9); block.Set(3, 100); + Console.PrintLine("%p", block.start); Console.PrintLine("%i", block.Get(0)); Console.PrintLine("%i", block.Get(1)); Console.PrintLine("%i", block.Get(2)); diff --git a/src/ast.c b/src/ast.c index 313a51d..a4cd091 100644 --- a/src/ast.c +++ b/src/ast.c @@ -29,6 +29,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "ForLoop"; case DeclarationSequence: return "DeclarationSequence"; + case FieldInit: + return "FieldInit"; case FunctionArgumentSequence: return "FunctionArgumentSequence"; case FunctionCallExpression: @@ -73,6 +75,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind) return "StringLiteral"; case StructDeclaration: return "StructDeclaration"; + case StructInit: + return "StructInit"; + case StructInitFields: + return "StructInitFields"; case SystemCall: return "SystemCall"; case Type: @@ -557,6 +563,55 @@ Node *MakeForLoopNode( return node; } +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = FieldInit; + node->fieldInit.identifier = identifierNode; + node->fieldInit.expression = expressionNode; + return node; +} + +Node *StartStructInitFieldsNode(Node *fieldInitNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = (Node **)malloc(sizeof(Node *)); + node->structInitFields.fieldInits[0] = fieldInitNode; + node->structInitFields.count = 1; + return node; +} + +Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode) +{ + structInitFieldsNode->structInitFields.fieldInits = realloc( + structInitFieldsNode->structInitFields.fieldInits, + sizeof(Node *) * (structInitFieldsNode->structInitFields.count + 1)); + structInitFieldsNode->structInitFields + .fieldInits[structInitFieldsNode->structInitFields.count] = + fieldInitNode; + structInitFieldsNode->structInitFields.count += 1; + return structInitFieldsNode; +} + +Node *MakeEmptyFieldInitNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInitFields; + node->structInitFields.fieldInits = NULL; + node->structInitFields.count = 0; + return node; +} + +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = StructInit; + node->structInit.type = typeNode; + node->structInit.initFields = structInitFieldsNode; + return node; +} + static const char *PrimitiveTypeToString(PrimitiveType type) { switch (type) @@ -662,6 +717,12 @@ void PrintNode(Node *node, uint32_t tabCount) } return; + case FieldInit: + printf("\n"); + PrintNode(node->fieldInit.identifier, tabCount + 1); + PrintNode(node->fieldInit.expression, tabCount + 1); + return; + case ForLoop: printf("\n"); PrintNode(node->forLoop.declaration, tabCount + 1); @@ -817,6 +878,20 @@ void PrintNode(Node *node, uint32_t tabCount) PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); return; + case StructInit: + printf("\n"); + PrintNode(node->structInit.type, tabCount + 1); + PrintNode(node->structInit.initFields, tabCount + 1); + return; + + case StructInitFields: + printf("\n"); + for (i = 0; i < node->structInitFields.count; i += 1) + { + PrintNode(node->structInitFields.fieldInits[i], tabCount + 1); + } + return; + case SystemCall: printf("\n"); PrintNode(node->systemCall.identifier, tabCount + 1); @@ -882,6 +957,11 @@ void Recurse(Node *node, void (*func)(Node *)) } return; + case FieldInit: + func(node->fieldInit.identifier); + func(node->fieldInit.expression); + return; + case ForLoop: func(node->forLoop.declaration); func(node->forLoop.startNumber); @@ -1003,6 +1083,18 @@ void Recurse(Node *node, void (*func)(Node *)) func(node->structDeclaration.declarationSequence); return; + case StructInit: + func(node->structInit.type); + func(node->structInit.initFields); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + func(node->structInitFields.fieldInits[i]); + } + return; + case SystemCall: func(node->systemCall.identifier); func(node->systemCall.argumentSequence); @@ -1193,6 +1285,38 @@ char *TypeTagToString(TypeTag *tag) } } +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB) +{ + if (typeTagA->type != typeTagB->type) + { + return 0; + } + + switch (typeTagA->type) + { + case Primitive: + return typeTagA->value.primitiveType == typeTagB->value.primitiveType; + + case Reference: + return TypeTagEqual( + typeTagA->value.referenceType, + typeTagB->value.referenceType); + + case Custom: + return strcmp(typeTagA->value.customType, typeTagB->value.customType) == + 0; + + case Generic: + return strcmp( + typeTagA->value.genericType, + typeTagB->value.genericType) == 0; + + default: + fprintf(stderr, "Invalid type comparison!"); + return 0; + } +} + void LinkParentPointers(Node *node, Node *prev) { if (node == NULL) @@ -1240,6 +1364,11 @@ void LinkParentPointers(Node *node, Node *prev) } return; + case FieldInit: + LinkParentPointers(node->fieldInit.identifier, node); + LinkParentPointers(node->fieldInit.expression, node); + return; + case ForLoop: LinkParentPointers(node->forLoop.declaration, node); LinkParentPointers(node->forLoop.startNumber, node); @@ -1364,6 +1493,18 @@ void LinkParentPointers(Node *node, Node *prev) LinkParentPointers(node->structDeclaration.declarationSequence, node); return; + case StructInit: + LinkParentPointers(node->structInit.type, node); + LinkParentPointers(node->structInit.initFields, node); + return; + + case StructInitFields: + for (i = 0; i < node->structInitFields.count; i += 1) + { + LinkParentPointers(node->structInitFields.fieldInits[i], node); + } + return; + case SystemCall: LinkParentPointers(node->systemCall.identifier, node); LinkParentPointers(node->systemCall.argumentSequence, node); diff --git a/src/ast.h b/src/ast.h index 531064b..3fdedf9 100644 --- a/src/ast.h +++ b/src/ast.h @@ -23,6 +23,7 @@ typedef enum CustomTypeNode, Declaration, DeclarationSequence, + FieldInit, ForLoop, FunctionArgumentSequence, FunctionCallExpression, @@ -47,6 +48,8 @@ typedef enum StaticModifier, StringLiteral, StructDeclaration, + StructInit, + StructInitFields, SystemCall, Type, UnaryExpression @@ -182,6 +185,12 @@ struct Node uint32_t count; } declarationSequence; + struct + { + Node *identifier; + Node *expression; + } fieldInit; + struct { Node *declaration; @@ -323,6 +332,18 @@ struct Node Node *genericDeclarations; } structDeclaration; + struct + { + Node *type; + Node *initFields; + } structInit; + + struct + { + Node **fieldInits; + uint32_t count; + } structInitFields; + struct { Node *identifier; @@ -419,6 +440,11 @@ Node *MakeForLoopNode( Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode); +Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode); +Node *StartStructInitFieldsNode(Node *fieldInitNode); +Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode); +Node *MakeEmptyFieldInitNode(); +Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode); void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); @@ -434,6 +460,7 @@ void LinkParentPointers(Node *node, Node *prev); TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); +uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB); Node *LookupIdNode(Node *current, Node *prev, char *target); diff --git a/src/codegen.c b/src/codegen.c index f9cbf7b..b2421f9 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -567,7 +567,9 @@ static StructTypeDeclaration *LookupGenericStructType( for (k = 0; k < hashArray->elements[j].typeCount; k += 1) { - if (hashArray->elements[j].types[k] != genericTypeTags[k]) + if (!TypeTagEqual( + hashArray->elements[j].types[k], + genericTypeTags[k])) { match = 0; break; @@ -679,7 +681,8 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression) return NULL; } -static LLVMTypeRef FindStructType(char *name) +/* FIXME: this is awkward, should just resolve the type */ +static LLVMTypeRef LookupStructTypeByName(char *name) { uint32_t i; @@ -694,6 +697,22 @@ static LLVMTypeRef FindStructType(char *name) return NULL; } +static StructTypeDeclaration *LookupStructDeclaration(LLVMTypeRef structType) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structType == structType) + { + return &structTypeDeclarations[i]; + } + } + + fprintf(stderr, "Struct type not found!"); + return NULL; +} + static LLVMValueRef FindStructFieldPointer( LLVMBuilderRef builder, LLVMValueRef structPointer, @@ -1095,8 +1114,9 @@ static LLVMValueRef LookupGenericFunction( for (j = 0; j < hashArray->elements[i].typeCount; j += 1) { - if (hashArray->elements[i].types[j] != - resolvedGenericArgumentTypes[j]) + if (!TypeTagEqual( + hashArray->elements[i].types[j], + resolvedGenericArgumentTypes[j])) { match = 0; break; @@ -1350,7 +1370,7 @@ static LLVMValueRef CompileFunctionCallExpression( if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { - LLVMTypeRef typeReference = FindStructType( + LLVMTypeRef typeReference = LookupStructTypeByName( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); @@ -1582,6 +1602,47 @@ static LLVMValueRef CompileAllocExpression( return LLVMBuildMalloc(builder, type, "allocation"); } +static LLVMValueRef CompileStructInitExpression( + StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, + LLVMBuilderRef builder, + Node *structInitExpression) +{ + uint32_t i = 0; + + LLVMTypeRef structType = ResolveType( + ConcretizeType(structInitExpression->structInit.type->typeTag)); + + LLVMValueRef structPointer = + LLVMBuildAlloca(builder, structType, "structInit"); + + for (i = 0; + i < + structInitExpression->structInit.initFields->structInitFields.count; + i += 1) + { + LLVMValueRef structFieldPointer = FindStructFieldPointer( + builder, + structPointer, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.identifier->identifier.name); + + LLVMBuildStore( + builder, + CompileExpression( + structTypeDeclaration, + selfParam, + builder, + structInitExpression->structInit.initFields->structInitFields + .fieldInits[i] + ->fieldInit.expression), + structFieldPointer); + } + + return structPointer; +} + static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, LLVMValueRef selfParam, @@ -1619,6 +1680,13 @@ static LLVMValueRef CompileExpression( case StringLiteral: return CompileString(builder, expression); + case StructInit: + return CompileStructInitExpression( + structTypeDeclaration, + selfParam, + builder, + expression); + case SystemCall: return CompileSystemCallExpression( structTypeDeclaration, @@ -1722,7 +1790,21 @@ static LLVMBasicBlockRef CompileAssignment( return LLVMGetLastBasicBlock(function); } - LLVMBuildStore(builder, result, identifier); + if (assignmentStatement->assignmentStatement.right->syntaxKind == + StructInit) + { + LLVMBuildMemCpy( + builder, + identifier, + LLVMGetAlignment(identifier), + result, + LLVMGetAlignment(result), + LLVMSizeOf(LLVMTypeOf(result))); + } + else + { + LLVMBuildStore(builder, result, identifier); + } return LLVMGetLastBasicBlock(function); } @@ -2352,7 +2434,13 @@ static void RegisterLibraryFunctions( LLVMPointerType(LLVMInt64Type(), 0), "src"); - LLVMBuildMemCpy(builder, dest, 8, src, 8, LLVMGetParam(memcpyFunction, 2)); + LLVMBuildMemCpy( + builder, + dest, + LLVMGetAlignment(dest), + src, + LLVMGetAlignment(src), + LLVMGetParam(memcpyFunction, 2)); LLVMBuildRetVoid(builder);