Generic Structs (#11)

Reviewed-on: cosmonaut/wraith-lang#11
Co-authored-by: cosmonaut <evan@moonside.games>
Co-committed-by: cosmonaut <evan@moonside.games>
main
cosmonaut 2021-06-07 18:51:33 +00:00
parent a571edcf6d
commit 9adfaed54c
7 changed files with 1105 additions and 337 deletions

View File

@ -113,9 +113,13 @@ BaseType : VOID
{ {
$$ = MakePrimitiveTypeNode(MemoryAddress); $$ = MakePrimitiveTypeNode(MemoryAddress);
} }
| Identifier GenericArgumentClauseNonEmpty
{
$$ = MakeConcreteGenericTypeNode($1, $2);
}
| Identifier | Identifier
{ {
$$ = MakeCustomTypeNode(yytext); $$ = MakeCustomTypeNode($1);
} }
| REFERENCE LESS_THAN Type GREATER_THAN | REFERENCE LESS_THAN Type GREATER_THAN
{ {
@ -157,6 +161,30 @@ Number : NUMBER
$$ = MakeNumberNode(yytext); $$ = MakeNumberNode(yytext);
} }
FieldInit : Identifier COLON Expression
{
$$ = MakeFieldInitNode($1, $3);
}
StructInitFields : FieldInit
{
$$ = StartStructInitFieldsNode($1);
}
| StructInitFields COMMA FieldInit
{
$$ = AddFieldInitNode($1, $3);
}
|
{
$$ = MakeEmptyFieldInitNode();
}
;
StructInitExpression : Type LEFT_BRACE StructInitFields RIGHT_BRACE
{
$$ = MakeStructInitExpressionNode($1, $3);
}
PrimaryExpression : Number PrimaryExpression : Number
| STRING_LITERAL | STRING_LITERAL
{ {
@ -168,6 +196,7 @@ PrimaryExpression : Number
} }
| FunctionCallExpression | FunctionCallExpression
| AccessExpression | AccessExpression
| StructInitExpression
; ;
UnaryExpression : BANG Expression UnaryExpression : BANG Expression
@ -290,11 +319,11 @@ Statements : Statement
$$ = AddStatement($1, $2); $$ = AddStatement($1, $2);
} }
Arguments : PrimaryExpression Arguments : Expression
{ {
$$ = StartFunctionArgumentSequenceNode($1); $$ = StartFunctionArgumentSequenceNode($1);
} }
| Arguments COMMA PrimaryExpression | Arguments COMMA Expression
{ {
$$ = AddFunctionArgumentNode($1, $3); $$ = AddFunctionArgumentNode($1, $3);
} }
@ -359,11 +388,13 @@ GenericArguments : GenericArgument
$$ = AddGenericArgument($1, $3); $$ = AddGenericArgument($1, $3);
} }
GenericArgumentClauseNonEmpty : LESS_THAN GenericArguments GREATER_THAN
{
$$ = $2;
}
;
GenericArgumentClause : LESS_THAN GenericArguments GREATER_THAN GenericArgumentClause : GenericArgumentClauseNonEmpty
{
$$ = $2;
}
| |
{ {
$$ = MakeEmptyGenericArgumentsNode(); $$ = MakeEmptyGenericArgumentsNode();

View File

@ -14,9 +14,24 @@ struct MemoryBlock<T>
start: MemoryAddress; start: MemoryAddress;
capacity: uint; capacity: uint;
AddressOf(count: uint): MemoryAddress AddressOf(index: uint): MemoryAddress
{ {
return start + (count * @sizeof<T>()); return start + (index * @sizeof<T>());
}
Get(index: uint): T
{
return @dereference<T>(AddressOf(index));
}
Set(index: uint, value: T): void
{
@memcpy(AddressOf(index), @addr(value), @sizeof<T>());
}
Free(): void
{
@free(start);
} }
} }
@ -24,8 +39,21 @@ struct Program {
static Main(): int { static Main(): int {
x: int = 4; x: int = 4;
y: int = Foo.Func(x); y: int = Foo.Func(x);
addr: MemoryAddress = @malloc(y); block: MemoryBlock<int> = MemoryBlock<int>
@free(addr); {
return x; capacity: y,
start: @malloc(y * @sizeof<int>())
};
block.Set(0, 5);
block.Set(1, 3);
block.Set(2, 9);
block.Set(3, 100);
Console.PrintLine("%p", block.start);
Console.PrintLine("%i", block.Get(0));
Console.PrintLine("%i", block.Get(1));
Console.PrintLine("%i", block.Get(2));
Console.PrintLine("%i", block.Get(3));
block.Free();
return 0;
} }
} }

221
src/ast.c
View File

@ -19,6 +19,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "BinaryExpression"; return "BinaryExpression";
case Comment: case Comment:
return "Comment"; return "Comment";
case ConcreteGenericTypeNode:
return "ConcreteGenericTypeNode";
case CustomTypeNode: case CustomTypeNode:
return "CustomTypeNode"; return "CustomTypeNode";
case Declaration: case Declaration:
@ -27,6 +29,8 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "ForLoop"; return "ForLoop";
case DeclarationSequence: case DeclarationSequence:
return "DeclarationSequence"; return "DeclarationSequence";
case FieldInit:
return "FieldInit";
case FunctionArgumentSequence: case FunctionArgumentSequence:
return "FunctionArgumentSequence"; return "FunctionArgumentSequence";
case FunctionCallExpression: case FunctionCallExpression:
@ -71,6 +75,10 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "StringLiteral"; return "StringLiteral";
case StructDeclaration: case StructDeclaration:
return "StructDeclaration"; return "StructDeclaration";
case StructInit:
return "StructInit";
case StructInitFields:
return "StructInitFields";
case SystemCall: case SystemCall:
return "SystemCall"; return "SystemCall";
case Type: case Type:
@ -95,11 +103,12 @@ Node *MakePrimitiveTypeNode(PrimitiveType type)
return node; return node;
} }
Node *MakeCustomTypeNode(char *name) Node *MakeCustomTypeNode(Node *identifierNode)
{ {
Node *node = (Node *)malloc(sizeof(Node)); Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = CustomTypeNode; node->syntaxKind = CustomTypeNode;
node->customType.name = strdup(name); node->customType.name = strdup(identifierNode->identifier.name);
free(identifierNode);
return node; return node;
} }
@ -111,6 +120,18 @@ Node *MakeReferenceTypeNode(Node *typeNode)
return node; return node;
} }
Node *MakeConcreteGenericTypeNode(
Node *identifierNode,
Node *genericArgumentsNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = ConcreteGenericTypeNode;
node->concreteGenericType.name = strdup(identifierNode->identifier.name);
node->concreteGenericType.genericArguments = genericArgumentsNode;
free(identifierNode);
return node;
}
Node *MakeTypeNode(Node *typeNode) Node *MakeTypeNode(Node *typeNode)
{ {
Node *node = (Node *)malloc(sizeof(Node)); Node *node = (Node *)malloc(sizeof(Node));
@ -542,6 +563,55 @@ Node *MakeForLoopNode(
return node; return node;
} }
Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = FieldInit;
node->fieldInit.identifier = identifierNode;
node->fieldInit.expression = expressionNode;
return node;
}
Node *StartStructInitFieldsNode(Node *fieldInitNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInitFields;
node->structInitFields.fieldInits = (Node **)malloc(sizeof(Node *));
node->structInitFields.fieldInits[0] = fieldInitNode;
node->structInitFields.count = 1;
return node;
}
Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode)
{
structInitFieldsNode->structInitFields.fieldInits = realloc(
structInitFieldsNode->structInitFields.fieldInits,
sizeof(Node *) * (structInitFieldsNode->structInitFields.count + 1));
structInitFieldsNode->structInitFields
.fieldInits[structInitFieldsNode->structInitFields.count] =
fieldInitNode;
structInitFieldsNode->structInitFields.count += 1;
return structInitFieldsNode;
}
Node *MakeEmptyFieldInitNode()
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInitFields;
node->structInitFields.fieldInits = NULL;
node->structInitFields.count = 0;
return node;
}
Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = StructInit;
node->structInit.type = typeNode;
node->structInit.initFields = structInitFieldsNode;
return node;
}
static const char *PrimitiveTypeToString(PrimitiveType type) static const char *PrimitiveTypeToString(PrimitiveType type)
{ {
switch (type) switch (type)
@ -624,6 +694,11 @@ void PrintNode(Node *node, uint32_t tabCount)
PrintNode(node->binaryExpression.right, tabCount + 1); PrintNode(node->binaryExpression.right, tabCount + 1);
return; return;
case ConcreteGenericTypeNode:
printf("%s\n", node->concreteGenericType.name);
PrintNode(node->concreteGenericType.genericArguments, tabCount + 1);
return;
case CustomTypeNode: case CustomTypeNode:
printf("%s\n", node->customType.name); printf("%s\n", node->customType.name);
return; return;
@ -642,6 +717,12 @@ void PrintNode(Node *node, uint32_t tabCount)
} }
return; return;
case FieldInit:
printf("\n");
PrintNode(node->fieldInit.identifier, tabCount + 1);
PrintNode(node->fieldInit.expression, tabCount + 1);
return;
case ForLoop: case ForLoop:
printf("\n"); printf("\n");
PrintNode(node->forLoop.declaration, tabCount + 1); PrintNode(node->forLoop.declaration, tabCount + 1);
@ -797,6 +878,20 @@ void PrintNode(Node *node, uint32_t tabCount)
PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
return; return;
case StructInit:
printf("\n");
PrintNode(node->structInit.type, tabCount + 1);
PrintNode(node->structInit.initFields, tabCount + 1);
return;
case StructInitFields:
printf("\n");
for (i = 0; i < node->structInitFields.count; i += 1)
{
PrintNode(node->structInitFields.fieldInits[i], tabCount + 1);
}
return;
case SystemCall: case SystemCall:
printf("\n"); printf("\n");
PrintNode(node->systemCall.identifier, tabCount + 1); PrintNode(node->systemCall.identifier, tabCount + 1);
@ -843,6 +938,10 @@ void Recurse(Node *node, void (*func)(Node *))
case Comment: case Comment:
return; return;
case ConcreteGenericTypeNode:
func(node->concreteGenericType.genericArguments);
return;
case CustomTypeNode: case CustomTypeNode:
return; return;
@ -858,6 +957,11 @@ void Recurse(Node *node, void (*func)(Node *))
} }
return; return;
case FieldInit:
func(node->fieldInit.identifier);
func(node->fieldInit.expression);
return;
case ForLoop: case ForLoop:
func(node->forLoop.declaration); func(node->forLoop.declaration);
func(node->forLoop.startNumber); func(node->forLoop.startNumber);
@ -979,6 +1083,18 @@ void Recurse(Node *node, void (*func)(Node *))
func(node->structDeclaration.declarationSequence); func(node->structDeclaration.declarationSequence);
return; return;
case StructInit:
func(node->structInit.type);
func(node->structInit.initFields);
return;
case StructInitFields:
for (i = 0; i < node->structInitFields.count; i += 1)
{
func(node->structInitFields.fieldInits[i]);
}
return;
case SystemCall: case SystemCall:
func(node->systemCall.identifier); func(node->systemCall.identifier);
func(node->systemCall.argumentSequence); func(node->systemCall.argumentSequence);
@ -1004,6 +1120,8 @@ void Recurse(Node *node, void (*func)(Node *))
TypeTag *MakeTypeTag(Node *node) TypeTag *MakeTypeTag(Node *node)
{ {
uint32_t i;
if (node == NULL) if (node == NULL)
{ {
fprintf( fprintf(
@ -1034,6 +1152,28 @@ TypeTag *MakeTypeTag(Node *node)
tag->value.customType = strdup(node->customType.name); tag->value.customType = strdup(node->customType.name);
break; break;
case ConcreteGenericTypeNode:
tag->type = ConcreteGeneric;
tag->value.concreteGenericType.name =
strdup(node->concreteGenericType.name);
tag->value.concreteGenericType.genericArgumentCount =
node->concreteGenericType.genericArguments->genericArguments.count;
tag->value.concreteGenericType.genericArguments = malloc(
sizeof(TypeTag *) *
tag->value.concreteGenericType.genericArgumentCount);
for (i = 0;
i <
node->concreteGenericType.genericArguments->genericArguments.count;
i += 1)
{
tag->value.concreteGenericType.genericArguments[i] = MakeTypeTag(
node->concreteGenericType.genericArguments->genericArguments
.arguments[i]
->genericArgument.type);
}
break;
case Declaration: case Declaration:
tag = MakeTypeTag(node->declaration.type); tag = MakeTypeTag(node->declaration.type);
break; break;
@ -1078,6 +1218,8 @@ TypeTag *MakeTypeTag(Node *node)
char *TypeTagToString(TypeTag *tag) char *TypeTagToString(TypeTag *tag)
{ {
uint32_t i;
if (tag == NULL) if (tag == NULL)
{ {
fprintf( fprintf(
@ -1114,6 +1256,64 @@ char *TypeTagToString(TypeTag *tag)
sprintf(result, "Generic<%s>", tag->value.genericType); sprintf(result, "Generic<%s>", tag->value.genericType);
return result; return result;
} }
case ConcreteGeneric:
{
char *result = strdup(tag->value.concreteGenericType.name);
uint32_t len = strlen(result);
len += 2;
result = realloc(result, sizeof(char) * len);
strcat(result, "<");
for (i = 0; i < tag->value.concreteGenericType.genericArgumentCount;
i += 1)
{
char *inner = TypeTagToString(
tag->value.concreteGenericType.genericArguments[i]);
len += strlen(inner);
result = realloc(result, sizeof(char) * (len + 3));
if (i != tag->value.concreteGenericType.genericArgumentCount - 1)
{
strcat(result, ", ");
}
strcat(result, inner);
}
result = realloc(result, sizeof(char) * (len + 1));
strcat(result, ">");
return result;
}
}
}
uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB)
{
if (typeTagA->type != typeTagB->type)
{
return 0;
}
switch (typeTagA->type)
{
case Primitive:
return typeTagA->value.primitiveType == typeTagB->value.primitiveType;
case Reference:
return TypeTagEqual(
typeTagA->value.referenceType,
typeTagB->value.referenceType);
case Custom:
return strcmp(typeTagA->value.customType, typeTagB->value.customType) ==
0;
case Generic:
return strcmp(
typeTagA->value.genericType,
typeTagB->value.genericType) == 0;
default:
fprintf(stderr, "Invalid type comparison!");
return 0;
} }
} }
@ -1164,6 +1364,11 @@ void LinkParentPointers(Node *node, Node *prev)
} }
return; return;
case FieldInit:
LinkParentPointers(node->fieldInit.identifier, node);
LinkParentPointers(node->fieldInit.expression, node);
return;
case ForLoop: case ForLoop:
LinkParentPointers(node->forLoop.declaration, node); LinkParentPointers(node->forLoop.declaration, node);
LinkParentPointers(node->forLoop.startNumber, node); LinkParentPointers(node->forLoop.startNumber, node);
@ -1288,6 +1493,18 @@ void LinkParentPointers(Node *node, Node *prev)
LinkParentPointers(node->structDeclaration.declarationSequence, node); LinkParentPointers(node->structDeclaration.declarationSequence, node);
return; return;
case StructInit:
LinkParentPointers(node->structInit.type, node);
LinkParentPointers(node->structInit.initFields, node);
return;
case StructInitFields:
for (i = 0; i < node->structInitFields.count; i += 1)
{
LinkParentPointers(node->structInitFields.fieldInits[i], node);
}
return;
case SystemCall: case SystemCall:
LinkParentPointers(node->systemCall.identifier, node); LinkParentPointers(node->systemCall.identifier, node);
LinkParentPointers(node->systemCall.argumentSequence, node); LinkParentPointers(node->systemCall.argumentSequence, node);

View File

@ -19,9 +19,11 @@ typedef enum
Assignment, Assignment,
BinaryExpression, BinaryExpression,
Comment, Comment,
ConcreteGenericTypeNode,
CustomTypeNode, CustomTypeNode,
Declaration, Declaration,
DeclarationSequence, DeclarationSequence,
FieldInit,
ForLoop, ForLoop,
FunctionArgumentSequence, FunctionArgumentSequence,
FunctionCallExpression, FunctionCallExpression,
@ -46,6 +48,8 @@ typedef enum
StaticModifier, StaticModifier,
StringLiteral, StringLiteral,
StructDeclaration, StructDeclaration,
StructInit,
StructInitFields,
SystemCall, SystemCall,
Type, Type,
UnaryExpression UnaryExpression
@ -86,7 +90,16 @@ typedef union
BinaryOperator binaryOperator; BinaryOperator binaryOperator;
} Operator; } Operator;
typedef struct TypeTag typedef struct TypeTag TypeTag;
typedef struct ConcreteGenericTypeTag
{
char *name;
TypeTag **genericArguments;
uint32_t genericArgumentCount;
} ConcreteGenericTypeTag;
struct TypeTag
{ {
enum Type enum Type
{ {
@ -94,7 +107,8 @@ typedef struct TypeTag
Primitive, Primitive,
Reference, Reference,
Custom, Custom,
Generic Generic,
ConcreteGeneric
} type; } type;
union union
{ {
@ -106,8 +120,10 @@ typedef struct TypeTag
char *customType; char *customType;
/* Valid when type = Generic. */ /* Valid when type = Generic. */
char *genericType; char *genericType;
/* Valid when type = ConcreteGeneric */
ConcreteGenericTypeTag concreteGenericType;
} value; } value;
} TypeTag; };
typedef struct Node Node; typedef struct Node Node;
@ -146,6 +162,12 @@ struct Node
} comment; } comment;
struct
{
char *name;
Node *genericArguments;
} concreteGenericType;
struct struct
{ {
char *name; char *name;
@ -163,6 +185,12 @@ struct Node
uint32_t count; uint32_t count;
} declarationSequence; } declarationSequence;
struct
{
Node *identifier;
Node *expression;
} fieldInit;
struct struct
{ {
Node *declaration; Node *declaration;
@ -304,6 +332,18 @@ struct Node
Node *genericDeclarations; Node *genericDeclarations;
} structDeclaration; } structDeclaration;
struct
{
Node *type;
Node *initFields;
} structInit;
struct
{
Node **fieldInits;
uint32_t count;
} structInitFields;
struct struct
{ {
Node *identifier; Node *identifier;
@ -329,8 +369,11 @@ const char *SyntaxKindString(SyntaxKind syntaxKind);
uint8_t IsPrimitiveType(Node *typeNode); uint8_t IsPrimitiveType(Node *typeNode);
Node *MakePrimitiveTypeNode(PrimitiveType type); Node *MakePrimitiveTypeNode(PrimitiveType type);
Node *MakeCustomTypeNode(char *string); Node *MakeCustomTypeNode(Node *identifierNode);
Node *MakeReferenceTypeNode(Node *typeNode); Node *MakeReferenceTypeNode(Node *typeNode);
Node *MakeConcreteGenericTypeNode(
Node *identifierNode,
Node *genericArgumentsNode);
Node *MakeTypeNode(Node *typeNode); Node *MakeTypeNode(Node *typeNode);
Node *MakeIdentifierNode(const char *id); Node *MakeIdentifierNode(const char *id);
Node *MakeNumberNode(const char *numberString); Node *MakeNumberNode(const char *numberString);
@ -397,6 +440,11 @@ Node *MakeForLoopNode(
Node *startNumberNode, Node *startNumberNode,
Node *endNumberNode, Node *endNumberNode,
Node *statementSequenceNode); Node *statementSequenceNode);
Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode);
Node *StartStructInitFieldsNode(Node *fieldInitNode);
Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode);
Node *MakeEmptyFieldInitNode();
Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode);
void PrintNode(Node *node, uint32_t tabCount); void PrintNode(Node *node, uint32_t tabCount);
const char *SyntaxKindString(SyntaxKind syntaxKind); const char *SyntaxKindString(SyntaxKind syntaxKind);
@ -412,6 +460,7 @@ void LinkParentPointers(Node *node, Node *prev);
TypeTag *MakeTypeTag(Node *node); TypeTag *MakeTypeTag(Node *node);
char *TypeTagToString(TypeTag *tag); char *TypeTagToString(TypeTag *tag);
uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB);
Node *LookupIdNode(Node *current, Node *prev, char *target); Node *LookupIdNode(Node *current, Node *prev, char *target);

File diff suppressed because it is too large Load Diff

View File

@ -5,7 +5,7 @@
char *strdup(const char *s) char *strdup(const char *s)
{ {
size_t slen = strlen(s); size_t slen = strlen(s);
char *result = (char *)malloc(slen + 1); char *result = (char *)malloc(sizeof(char) * (slen + 1));
if (result == NULL) if (result == NULL)
{ {
return NULL; return NULL;
@ -15,6 +15,15 @@ char *strdup(const char *s)
return result; return result;
} }
char *w_strcat(char *s, char *s2)
{
size_t slen = strlen(s);
size_t slen2 = strlen(s2);
s = realloc(s, sizeof(char) * (slen + slen2 + 1));
strcat(s, s2);
return s;
}
uint64_t str_hash(char *str) uint64_t str_hash(char *str)
{ {
uint64_t hash = 5381; uint64_t hash = 5381;

View File

@ -5,6 +5,7 @@
#include <string.h> #include <string.h>
char *strdup(const char *s); char *strdup(const char *s);
char *w_strcat(char *s, char *s2);
uint64_t str_hash(char *str); uint64_t str_hash(char *str);
#endif /* WRAITH_UTIL_H */ #endif /* WRAITH_UTIL_H */