add struct initializer

pull/11/head
cosmonaut 2021-06-06 14:35:54 -07:00
parent 01da2dc377
commit e2332349b7
5 changed files with 294 additions and 12 deletions

View File

@ -161,6 +161,30 @@ Number : NUMBER
$$ = MakeNumberNode(yytext);
}
FieldInit : Identifier COLON Expression
{
$$ = MakeFieldInitNode($1, $3);
}
StructInitFields : FieldInit
{
$$ = StartStructInitFieldsNode($1);
}
| StructInitFields COMMA FieldInit
{
$$ = AddFieldInitNode($1, $3);
}
|
{
$$ = MakeEmptyFieldInitNode();
}
;
StructInitExpression : Type LEFT_BRACE StructInitFields RIGHT_BRACE
{
$$ = MakeStructInitExpressionNode($1, $3);
}
PrimaryExpression : Number
| STRING_LITERAL
{
@ -172,6 +196,7 @@ PrimaryExpression : Number
}
| FunctionCallExpression
| AccessExpression
| StructInitExpression
;
UnaryExpression : BANG Expression

View File

@ -39,15 +39,16 @@ struct Program {
static Main(): int {
x: int = 4;
y: int = Foo.Func(x);
block: MemoryBlock<int>;
block.capacity = y;
block.start = @malloc(y * @sizeof<int>());
z: MemoryAddress = block.AddressOf(2);
Console.PrintLine("%p", block.start);
block: MemoryBlock<int> = MemoryBlock<int>
{
capacity: y,
start: @malloc(y * @sizeof<int>())
};
block.Set(0, 5);
block.Set(1, 3);
block.Set(2, 9);
block.Set(3, 100);
Console.PrintLine("%p", block.start);
Console.PrintLine("%i", block.Get(0));
Console.PrintLine("%i", block.Get(1));
Console.PrintLine("%i", block.Get(2));

141
src/ast.c
View File

@ -29,6 +29,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "ForLoop";
case DeclarationSequence:
return "DeclarationSequence";
case FieldInit:
return "FieldInit";
case FunctionArgumentSequence:
return "FunctionArgumentSequence";
case FunctionCallExpression:
@ -73,6 +75,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "StringLiteral";
case StructDeclaration:
return "StructDeclaration";
case StructInit:
return "StructInit";
case StructInitFields:
return "StructInitFields";
case SystemCall:
return "SystemCall";
case Type:
@ -557,6 +563,55 @@ Node *MakeForLoopNode(
return node;
}
Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = FieldInit;
node->fieldInit.identifier = identifierNode;
node->fieldInit.expression = expressionNode;
return node;
}
Node *StartStructInitFieldsNode(Node *fieldInitNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInitFields;
node->structInitFields.fieldInits = (Node **)malloc(sizeof(Node *));
node->structInitFields.fieldInits[0] = fieldInitNode;
node->structInitFields.count = 1;
return node;
}
Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode)
{
structInitFieldsNode->structInitFields.fieldInits = realloc(
structInitFieldsNode->structInitFields.fieldInits,
sizeof(Node *) * (structInitFieldsNode->structInitFields.count + 1));
structInitFieldsNode->structInitFields
.fieldInits[structInitFieldsNode->structInitFields.count] =
fieldInitNode;
structInitFieldsNode->structInitFields.count += 1;
return structInitFieldsNode;
}
Node *MakeEmptyFieldInitNode()
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInitFields;
node->structInitFields.fieldInits = NULL;
node->structInitFields.count = 0;
return node;
}
Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInit;
node->structInit.type = typeNode;
node->structInit.initFields = structInitFieldsNode;
return node;
}
static const char *PrimitiveTypeToString(PrimitiveType type)
{
switch (type)
@ -662,6 +717,12 @@ void PrintNode(Node *node, uint32_t tabCount)
}
return;
case FieldInit:
printf("\n");
PrintNode(node->fieldInit.identifier, tabCount + 1);
PrintNode(node->fieldInit.expression, tabCount + 1);
return;
case ForLoop:
printf("\n");
PrintNode(node->forLoop.declaration, tabCount + 1);
@ -817,6 +878,20 @@ void PrintNode(Node *node, uint32_t tabCount)
PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
return;
case StructInit:
printf("\n");
PrintNode(node->structInit.type, tabCount + 1);
PrintNode(node->structInit.initFields, tabCount + 1);
return;
case StructInitFields:
printf("\n");
for (i = 0; i < node->structInitFields.count; i += 1)
{
PrintNode(node->structInitFields.fieldInits[i], tabCount + 1);
}
return;
case SystemCall:
printf("\n");
PrintNode(node->systemCall.identifier, tabCount + 1);
@ -882,6 +957,11 @@ void Recurse(Node *node, void (*func)(Node *))
}
return;
case FieldInit:
func(node->fieldInit.identifier);
func(node->fieldInit.expression);
return;
case ForLoop:
func(node->forLoop.declaration);
func(node->forLoop.startNumber);
@ -1003,6 +1083,18 @@ void Recurse(Node *node, void (*func)(Node *))
func(node->structDeclaration.declarationSequence);
return;
case StructInit:
func(node->structInit.type);
func(node->structInit.initFields);
return;
case StructInitFields:
for (i = 0; i < node->structInitFields.count; i += 1)
{
func(node->structInitFields.fieldInits[i]);
}
return;
case SystemCall:
func(node->systemCall.identifier);
func(node->systemCall.argumentSequence);
@ -1193,6 +1285,38 @@ char *TypeTagToString(TypeTag *tag)
}
}
uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB)
{
if (typeTagA->type != typeTagB->type)
{
return 0;
}
switch (typeTagA->type)
{
case Primitive:
return typeTagA->value.primitiveType == typeTagB->value.primitiveType;
case Reference:
return TypeTagEqual(
typeTagA->value.referenceType,
typeTagB->value.referenceType);
case Custom:
return strcmp(typeTagA->value.customType, typeTagB->value.customType) ==
0;
case Generic:
return strcmp(
typeTagA->value.genericType,
typeTagB->value.genericType) == 0;
default:
fprintf(stderr, "Invalid type comparison!");
return 0;
}
}
void LinkParentPointers(Node *node, Node *prev)
{
if (node == NULL)
@ -1240,6 +1364,11 @@ void LinkParentPointers(Node *node, Node *prev)
}
return;
case FieldInit:
LinkParentPointers(node->fieldInit.identifier, node);
LinkParentPointers(node->fieldInit.expression, node);
return;
case ForLoop:
LinkParentPointers(node->forLoop.declaration, node);
LinkParentPointers(node->forLoop.startNumber, node);
@ -1364,6 +1493,18 @@ void LinkParentPointers(Node *node, Node *prev)
LinkParentPointers(node->structDeclaration.declarationSequence, node);
return;
case StructInit:
LinkParentPointers(node->structInit.type, node);
LinkParentPointers(node->structInit.initFields, node);
return;
case StructInitFields:
for (i = 0; i < node->structInitFields.count; i += 1)
{
LinkParentPointers(node->structInitFields.fieldInits[i], node);
}
return;
case SystemCall:
LinkParentPointers(node->systemCall.identifier, node);
LinkParentPointers(node->systemCall.argumentSequence, node);

View File

@ -23,6 +23,7 @@ typedef enum
CustomTypeNode,
Declaration,
DeclarationSequence,
FieldInit,
ForLoop,
FunctionArgumentSequence,
FunctionCallExpression,
@ -47,6 +48,8 @@ typedef enum
StaticModifier,
StringLiteral,
StructDeclaration,
StructInit,
StructInitFields,
SystemCall,
Type,
UnaryExpression
@ -182,6 +185,12 @@ struct Node
uint32_t count;
} declarationSequence;
struct
{
Node *identifier;
Node *expression;
} fieldInit;
struct
{
Node *declaration;
@ -323,6 +332,18 @@ struct Node
Node *genericDeclarations;
} structDeclaration;
struct
{
Node *type;
Node *initFields;
} structInit;
struct
{
Node **fieldInits;
uint32_t count;
} structInitFields;
struct
{
Node *identifier;
@ -419,6 +440,11 @@ Node *MakeForLoopNode(
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode);
Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode);
Node *StartStructInitFieldsNode(Node *fieldInitNode);
Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode);
Node *MakeEmptyFieldInitNode();
Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode);
void PrintNode(Node *node, uint32_t tabCount);
const char *SyntaxKindString(SyntaxKind syntaxKind);
@ -434,6 +460,7 @@ void LinkParentPointers(Node *node, Node *prev);
TypeTag *MakeTypeTag(Node *node);
char *TypeTagToString(TypeTag *tag);
uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB);
Node *LookupIdNode(Node *current, Node *prev, char *target);

View File

@ -567,7 +567,9 @@ static StructTypeDeclaration *LookupGenericStructType(
for (k = 0; k < hashArray->elements[j].typeCount; k += 1)
{
if (hashArray->elements[j].types[k] != genericTypeTags[k])
if (!TypeTagEqual(
hashArray->elements[j].types[k],
genericTypeTags[k]))
{
match = 0;
break;
@ -679,7 +681,8 @@ static SystemFunction *LookupSystemFunction(Node *systemCallExpression)
return NULL;
}
static LLVMTypeRef FindStructType(char *name)
/* FIXME: this is awkward, should just resolve the type */
static LLVMTypeRef LookupStructTypeByName(char *name)
{
uint32_t i;
@ -694,6 +697,22 @@ static LLVMTypeRef FindStructType(char *name)
return NULL;
}
static StructTypeDeclaration *LookupStructDeclaration(LLVMTypeRef structType)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structType == structType)
{
return &structTypeDeclarations[i];
}
}
fprintf(stderr, "Struct type not found!");
return NULL;
}
static LLVMValueRef FindStructFieldPointer(
LLVMBuilderRef builder,
LLVMValueRef structPointer,
@ -1095,8 +1114,9 @@ static LLVMValueRef LookupGenericFunction(
for (j = 0; j < hashArray->elements[i].typeCount; j += 1)
{
if (hashArray->elements[i].types[j] !=
resolvedGenericArgumentTypes[j])
if (!TypeTagEqual(
hashArray->elements[i].types[j],
resolvedGenericArgumentTypes[j]))
{
match = 0;
break;
@ -1350,7 +1370,7 @@ static LLVMValueRef CompileFunctionCallExpression(
if (functionCallExpression->functionCallExpression.identifier->syntaxKind ==
AccessExpression)
{
LLVMTypeRef typeReference = FindStructType(
LLVMTypeRef typeReference = LookupStructTypeByName(
functionCallExpression->functionCallExpression.identifier
->accessExpression.accessee->identifier.name);
@ -1582,6 +1602,47 @@ static LLVMValueRef CompileAllocExpression(
return LLVMBuildMalloc(builder, type, "allocation");
}
static LLVMValueRef CompileStructInitExpression(
StructTypeDeclaration *structTypeDeclaration,
LLVMValueRef selfParam,
LLVMBuilderRef builder,
Node *structInitExpression)
{
uint32_t i = 0;
LLVMTypeRef structType = ResolveType(
ConcretizeType(structInitExpression->structInit.type->typeTag));
LLVMValueRef structPointer =
LLVMBuildAlloca(builder, structType, "structInit");
for (i = 0;
i <
structInitExpression->structInit.initFields->structInitFields.count;
i += 1)
{
LLVMValueRef structFieldPointer = FindStructFieldPointer(
builder,
structPointer,
structInitExpression->structInit.initFields->structInitFields
.fieldInits[i]
->fieldInit.identifier->identifier.name);
LLVMBuildStore(
builder,
CompileExpression(
structTypeDeclaration,
selfParam,
builder,
structInitExpression->structInit.initFields->structInitFields
.fieldInits[i]
->fieldInit.expression),
structFieldPointer);
}
return structPointer;
}
static LLVMValueRef CompileExpression(
StructTypeDeclaration *structTypeDeclaration,
LLVMValueRef selfParam,
@ -1619,6 +1680,13 @@ static LLVMValueRef CompileExpression(
case StringLiteral:
return CompileString(builder, expression);
case StructInit:
return CompileStructInitExpression(
structTypeDeclaration,
selfParam,
builder,
expression);
case SystemCall:
return CompileSystemCallExpression(
structTypeDeclaration,
@ -1722,7 +1790,21 @@ static LLVMBasicBlockRef CompileAssignment(
return LLVMGetLastBasicBlock(function);
}
LLVMBuildStore(builder, result, identifier);
if (assignmentStatement->assignmentStatement.right->syntaxKind ==
StructInit)
{
LLVMBuildMemCpy(
builder,
identifier,
LLVMGetAlignment(identifier),
result,
LLVMGetAlignment(result),
LLVMSizeOf(LLVMTypeOf(result)));
}
else
{
LLVMBuildStore(builder, result, identifier);
}
return LLVMGetLastBasicBlock(function);
}
@ -2352,7 +2434,13 @@ static void RegisterLibraryFunctions(
LLVMPointerType(LLVMInt64Type(), 0),
"src");
LLVMBuildMemCpy(builder, dest, 8, src, 8, LLVMGetParam(memcpyFunction, 2));
LLVMBuildMemCpy(
builder,
dest,
LLVMGetAlignment(dest),
src,
LLVMGetAlignment(src),
LLVMGetParam(memcpyFunction, 2));
LLVMBuildRetVoid(builder);