From abc82f381e7397485b8195be36f3ccae63a8fc87 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Sat, 15 May 2021 15:34:15 -0700 Subject: [PATCH 1/2] refactor AST to use nameless union instead of child array --- src/ast.c | 368 ++++++++++++++++++++++++++++++-------------------- src/ast.h | 201 ++++++++++++++++++++++++--- src/codegen.c | 172 ++++++++++++----------- src/main.c | 2 +- src/parser.c | 2 +- 5 files changed, 494 insertions(+), 251 deletions(-) diff --git a/src/ast.c b/src/ast.c index 273739a..942ff74 100644 --- a/src/ast.c +++ b/src/ast.c @@ -16,7 +16,6 @@ const char* SyntaxKindString(SyntaxKind syntaxKind) case Comment: return "Comment"; case CustomTypeNode: return "CustomTypeNode"; case Declaration: return "Declaration"; - case Expression: return "Expression"; case ForLoop: return "ForLoop"; case DeclarationSequence: return "DeclarationSequence"; case FunctionArgumentSequence: return "FunctionArgumentSequence"; @@ -45,7 +44,7 @@ const char* SyntaxKindString(SyntaxKind syntaxKind) uint8_t IsPrimitiveType( Node *typeNode ) { - return typeNode->children[0]->syntaxKind == PrimitiveTypeNode; + return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode; } Node* MakePrimitiveTypeNode( @@ -53,8 +52,7 @@ Node* MakePrimitiveTypeNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = PrimitiveTypeNode; - node->primitiveType = type; - node->childCount = 0; + node->primitiveType.type = type; return node; } @@ -63,8 +61,7 @@ Node* MakeCustomTypeNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = CustomTypeNode; - node->value.string = strdup(name); - node->childCount = 0; + node->customType.name = strdup(name); return node; } @@ -73,9 +70,7 @@ Node* MakeReferenceTypeNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = ReferenceTypeNode; - node->childCount = 1; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = typeNode; + node->referenceType.type = typeNode; return node; } @@ -84,9 +79,7 @@ 
Node* MakeTypeNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Type; - node->childCount = 1; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = typeNode; + node->type.typeNode = typeNode; return node; } @@ -95,8 +88,7 @@ Node* MakeIdentifierNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Identifier; - node->value.string = strdup(id); - node->childCount = 0; + node->identifier.name = strdup(id); node->typeTag = NULL; return node; } @@ -107,8 +99,7 @@ Node* MakeNumberNode( char *ptr; Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Number; - node->value.number = strtoul(numberString, &ptr, 10); - node->childCount = 0; + node->number.value = strtoul(numberString, &ptr, 10); return node; } @@ -118,8 +109,7 @@ Node* MakeStringNode( size_t slen = strlen(string); Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = StringLiteral; - node->value.string = strndup(string + 1, slen - 2); - node->childCount = 0; + node->stringLiteral.string = strndup(string + 1, slen - 2); return node; } @@ -127,10 +117,10 @@ Node* MakeStaticNode() { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = StaticModifier; - node->childCount = 0; return node; } +/* FIXME: this sucks */ Node* MakeFunctionModifiersNode( Node **pModifierNodes, uint32_t modifierCount @@ -138,13 +128,14 @@ Node* MakeFunctionModifiersNode( uint32_t i; Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionModifiers; - node->childCount = modifierCount; + node->functionModifiers.count = modifierCount; + node->functionModifiers.sequence = NULL; if (modifierCount > 0) { - node->children = malloc(sizeof(Node*) * node->childCount); + node->functionModifiers.sequence = malloc(sizeof(Node*) * node->functionModifiers.count); for (i = 0; i < modifierCount; i += 1) { - node->children[i] = pModifierNodes[i]; + node->functionModifiers.sequence[i] = pModifierNodes[i]; } } @@ -157,10 +148,8 @@ Node* MakeUnaryNode( ) { Node* 
node = (Node*) malloc(sizeof(Node)); node->syntaxKind = UnaryExpression; - node->operator.unaryOperator = operator; - node->children = malloc(sizeof(Node*)); - node->children[0] = child; - node->childCount = 1; + node->unaryExpression.operator = operator; + node->unaryExpression.child = child; return node; } @@ -171,11 +160,9 @@ Node* MakeBinaryNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = BinaryExpression; - node->operator.binaryOperator = operator; - node->children = malloc(sizeof(Node*) * 2); - node->children[0] = left; - node->children[1] = right; - node->childCount = 2; + node->binaryExpression.left = left; + node->binaryExpression.right = right; + node->binaryExpression.operator = operator; return node; } @@ -185,10 +172,8 @@ Node* MakeDeclarationNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Declaration; - node->children = (Node**) malloc(sizeof(Node*) * 2); - node->childCount = 2; - node->children[0] = typeNode; - node->children[1] = identifierNode; + node->declaration.type = typeNode; + node->declaration.identifier = identifierNode; return node; } @@ -198,10 +183,8 @@ Node* MakeAssignmentNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Assignment; - node->childCount = 2; - node->children = malloc(sizeof(Node*) * 2); - node->children[0] = left; - node->children[1] = right; + node->assignmentStatement.left = left; + node->assignmentStatement.right = right; return node; } @@ -210,9 +193,9 @@ Node* StartStatementSequenceNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = StatementSequence; - node->children = (Node**) malloc(sizeof(Node*)); - node->childCount = 1; - node->children[0] = statementNode; + node->statementSequence.sequence = (Node**) malloc(sizeof(Node*)); + node->statementSequence.sequence[0] = statementNode; + node->statementSequence.count = 1; return node; } @@ -220,9 +203,9 @@ Node* AddStatement( Node* statementSequenceNode, Node *statementNode ) { - 
statementSequenceNode->children = realloc(statementSequenceNode->children, sizeof(Node*) * (statementSequenceNode->childCount + 1)); - statementSequenceNode->children[statementSequenceNode->childCount] = statementNode; - statementSequenceNode->childCount += 1; + statementSequenceNode->statementSequence.sequence = realloc(statementSequenceNode->statementSequence.sequence, sizeof(Node*) * (statementSequenceNode->statementSequence.count + 1)); + statementSequenceNode->statementSequence.sequence[statementSequenceNode->statementSequence.count] = statementNode; + statementSequenceNode->statementSequence.count += 1; return statementSequenceNode; } @@ -231,9 +214,7 @@ Node* MakeReturnStatementNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = Return; - node->children = (Node**) malloc(sizeof(Node*)); - node->childCount = 1; - node->children[0] = expressionNode; + node->returnStatement.expression = expressionNode; return node; } @@ -241,8 +222,6 @@ Node* MakeReturnVoidStatementNode() { Node *node = (Node*) malloc(sizeof(Node)); node->syntaxKind = ReturnVoid; - node->childCount = 0; - node->children = NULL; return node; } @@ -251,9 +230,9 @@ Node *StartFunctionSignatureArgumentsNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionSignatureArguments; - node->childCount = 1; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = argumentNode; + node->functionSignatureArguments.sequence = (Node**) malloc(sizeof(Node*)); + node->functionSignatureArguments.sequence[0] = argumentNode; + node->functionSignatureArguments.count = 1; return node; } @@ -261,9 +240,9 @@ Node* AddFunctionSignatureArgumentNode( Node *argumentsNode, Node *argumentNode ) { - argumentsNode->children = realloc(argumentsNode->children, sizeof(Node*) * (argumentsNode->childCount + 1)); - argumentsNode->children[argumentsNode->childCount] = argumentNode; - argumentsNode->childCount += 1; + argumentsNode->functionSignatureArguments.sequence = 
realloc(argumentsNode->functionSignatureArguments.sequence, sizeof(Node*) * (argumentsNode->functionSignatureArguments.count + 1)); + argumentsNode->functionSignatureArguments.sequence[argumentsNode->functionSignatureArguments.count] = argumentNode; + argumentsNode->functionSignatureArguments.count += 1; return argumentsNode; } @@ -271,8 +250,8 @@ Node *MakeEmptyFunctionSignatureArgumentsNode() { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionSignatureArguments; - node->childCount = 0; - node->children = NULL; + node->functionSignatureArguments.sequence = NULL; + node->functionSignatureArguments.count = 0; return node; } @@ -284,12 +263,10 @@ Node* MakeFunctionSignatureNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionSignature; - node->childCount = 4; - node->children = (Node**) malloc(sizeof(Node*) * (node->childCount)); - node->children[0] = identifierNode; - node->children[1] = typeNode; - node->children[2] = arguments; - node->children[3] = modifiersNode; + node->functionSignature.identifier = identifierNode; + node->functionSignature.type = typeNode; + node->functionSignature.arguments = arguments; + node->functionSignature.modifiers = modifiersNode; return node; } @@ -299,10 +276,8 @@ Node* MakeFunctionDeclarationNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionDeclaration; - node->childCount = 2; - node->children = (Node**) malloc(sizeof(Node*) * 2); - node->children[0] = functionSignatureNode; - node->children[1] = functionBodyNode; + node->functionDeclaration.functionSignature = functionSignatureNode; + node->functionDeclaration.functionBody = functionBodyNode; return node; } @@ -312,10 +287,8 @@ Node* MakeStructDeclarationNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = StructDeclaration; - node->childCount = 2; - node->children = (Node**) malloc(sizeof(Node*) * 2); - node->children[0] = identifierNode; - node->children[1] = declarationSequenceNode; 
+ node->structDeclaration.identifier = identifierNode; + node->structDeclaration.declarationSequence = declarationSequenceNode; return node; } @@ -324,9 +297,9 @@ Node* StartDeclarationSequenceNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = DeclarationSequence; - node->children = (Node**) malloc(sizeof(Node*)); - node->childCount = 1; - node->children[0] = declarationNode; + node->declarationSequence.sequence = (Node**) malloc(sizeof(Node*)); + node->declarationSequence.sequence[0] = declarationNode; + node->declarationSequence.count = 1; return node; } @@ -334,9 +307,9 @@ Node* AddDeclarationNode( Node *declarationSequenceNode, Node *declarationNode ) { - declarationSequenceNode->children = realloc(declarationSequenceNode->children, sizeof(Node*) * (declarationSequenceNode->childCount + 1)); - declarationSequenceNode->children[declarationSequenceNode->childCount] = declarationNode; - declarationSequenceNode->childCount += 1; + declarationSequenceNode->declarationSequence.sequence = (Node**) realloc(declarationSequenceNode->declarationSequence.sequence, sizeof(Node*) * (declarationSequenceNode->declarationSequence.count + 1)); + declarationSequenceNode->declarationSequence.sequence[declarationSequenceNode->declarationSequence.count] = declarationNode; + declarationSequenceNode->declarationSequence.count += 1; return declarationSequenceNode; } @@ -345,9 +318,9 @@ Node* StartFunctionArgumentSequenceNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionArgumentSequence; - node->childCount = 1; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = argumentNode; + node->functionArgumentSequence.sequence = (Node**) malloc(sizeof(Node*)); + node->functionArgumentSequence.sequence[0] = argumentNode; + node->functionArgumentSequence.count = 1; return node; } @@ -355,9 +328,9 @@ Node* AddFunctionArgumentNode( Node *argumentSequenceNode, Node *argumentNode ) { - argumentSequenceNode->children = 
realloc(argumentSequenceNode->children, sizeof(Node*) * (argumentSequenceNode->childCount + 1)); - argumentSequenceNode->children[argumentSequenceNode->childCount] = argumentNode; - argumentSequenceNode->childCount += 1; + argumentSequenceNode->functionArgumentSequence.sequence = (Node**) realloc(argumentSequenceNode->functionArgumentSequence.sequence, sizeof(Node*) * (argumentSequenceNode->functionArgumentSequence.count + 1)); + argumentSequenceNode->functionArgumentSequence.sequence[argumentSequenceNode->functionArgumentSequence.count] = argumentNode; + argumentSequenceNode->functionArgumentSequence.count += 1; return argumentSequenceNode; } @@ -365,8 +338,8 @@ Node *MakeEmptyFunctionArgumentSequenceNode() { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionArgumentSequence; - node->childCount = 0; - node->children = NULL; + node->functionArgumentSequence.count = 0; + node->functionArgumentSequence.sequence = NULL; return node; } @@ -376,10 +349,8 @@ Node* MakeFunctionCallExpressionNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionCallExpression; - node->children = (Node**) malloc(sizeof(Node*) * 2); - node->childCount = 2; - node->children[0] = identifierNode; - node->children[1] = argumentSequenceNode; + node->functionCallExpression.identifier = identifierNode; + node->functionCallExpression.argumentSequence = argumentSequenceNode; return node; } @@ -389,10 +360,8 @@ Node* MakeAccessExpressionNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = AccessExpression; - node->children = (Node**) malloc(sizeof(Node*) * 2); - node->childCount = 2; - node->children[0] = accessee; - node->children[1] = accessor; + node->accessExpression.accessee = accessee; + node->accessExpression.accessor = accessor; return node; } @@ -400,9 +369,7 @@ Node* MakeAllocNode(Node *typeNode) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = AllocExpression; - node->childCount = 1; - node->children = (Node**) 
malloc(sizeof(Node*)); - node->children[0] = typeNode; + node->allocExpression.type = typeNode; return node; } @@ -412,40 +379,34 @@ Node* MakeIfNode( ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = IfStatement; - node->childCount = 2; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = expressionNode; - node->children[1] = statementSequenceNode; + node->ifStatement.expression = expressionNode; + node->ifStatement.statementSequence = statementSequenceNode; return node; } Node* MakeIfElseNode( Node *ifNode, - Node *statementSequenceNode + Node *elseNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = IfElseStatement; - node->childCount = 2; - node->children = (Node**) malloc(sizeof(Node*)); - node->children[0] = ifNode; - node->children[1] = statementSequenceNode; + node->ifElseStatement.ifStatement = ifNode; + node->ifElseStatement.elseStatement = elseNode; return node; } Node* MakeForLoopNode( - Node *identifierNode, + Node *declarationNode, Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = ForLoop; - node->childCount = 4; - node->children = (Node**) malloc(sizeof(Node*) * 4); - node->children[0] = identifierNode; - node->children[1] = startNumberNode; - node->children[2] = endNumberNode; - node->children[3] = statementSequenceNode; + node->forLoop.declaration = declarationNode; + node->forLoop.startNumber = startNumberNode; + node->forLoop.endNumber = endNumberNode; + node->forLoop.statementSequence = statementSequenceNode; return node; } @@ -462,9 +423,19 @@ static const char* PrimitiveTypeToString(PrimitiveType type) return "Unknown"; } -static void PrintBinaryOperator(BinaryOperator expression) +static void PrintUnaryOperator(UnaryOperator operator) { - switch (expression) + switch (operator) + { + case Negate: + printf("!"); + break; + } +} + +static void PrintBinaryOperator(BinaryOperator operator) +{ + switch 
(operator) { case Add: printf("+"); @@ -480,7 +451,7 @@ static void PrintBinaryOperator(BinaryOperator expression) } } -static void PrintNode(Node *node, int tabCount) +void PrintNode(Node *node, uint32_t tabCount) { uint32_t i; for (i = 0; i < tabCount; i += 1) @@ -491,48 +462,157 @@ static void PrintNode(Node *node, int tabCount) printf("%s: ", SyntaxKindString(node->syntaxKind)); switch (node->syntaxKind) { - case BinaryExpression: - PrintBinaryOperator(node->operator.binaryOperator); + case AccessExpression: + PrintNode(node->accessExpression.accessee, tabCount + 1); + PrintNode(node->accessExpression.accessor, tabCount + 1); break; - case Declaration: + case AllocExpression: + PrintNode(node->allocExpression.type, tabCount + 1); + break; + + case Assignment: + PrintNode(node->assignmentStatement.left, tabCount + 1); + PrintNode(node->assignmentStatement.right, tabCount + 1); + break; + + case BinaryExpression: + PrintNode(node->binaryExpression.left, tabCount + 1); + PrintBinaryOperator(node->binaryExpression.operator); + PrintNode(node->binaryExpression.right, tabCount + 1); break; case CustomTypeNode: - printf("%s", node->value.string); + printf("%s", node->customType.name); break; - case PrimitiveTypeNode: - printf("%s", PrimitiveTypeToString(node->primitiveType)); + case Declaration: + PrintNode(node->declaration.identifier, tabCount + 1); + PrintNode(node->declaration.type, tabCount + 1); + break; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) + { + PrintNode(node->declarationSequence.sequence[i], tabCount + 1); + } + break; + + case ForLoop: + PrintNode(node->forLoop.declaration, tabCount + 1); + PrintNode(node->forLoop.startNumber, tabCount + 1); + PrintNode(node->forLoop.endNumber, tabCount + 1); + PrintNode(node->forLoop.statementSequence, tabCount + 1); + break; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) + { + 
PrintNode(node->functionArgumentSequence.sequence[i], tabCount + 1); + } + break; + + case FunctionCallExpression: + PrintNode(node->functionCallExpression.identifier, tabCount + 1); + PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1); + break; + + case FunctionDeclaration: + PrintNode(node->functionDeclaration.functionSignature, tabCount + 1); + PrintNode(node->functionDeclaration.functionBody, tabCount + 1); + break; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) + { + PrintNode(node->functionModifiers.sequence[i], tabCount + 1); + } + break; + + case FunctionSignature: + PrintNode(node->functionSignature.identifier, tabCount + 1); + PrintNode(node->functionSignature.arguments, tabCount + 1); + PrintNode(node->functionSignature.type, tabCount + 1); + PrintNode(node->functionSignature.modifiers, tabCount + 1); + break; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) + { + PrintNode(node->functionSignatureArguments.sequence[i], tabCount + 1); + } break; case Identifier: if (node->typeTag == NULL) { - printf("%s", node->value.string); + printf("%s", node->identifier.name); } else { char *type = TypeTagToString(node->typeTag); - printf("%s<%s>", node->value.string, type); + printf("%s<%s>", node->identifier.name, type); } break; + case IfStatement: + PrintNode(node->ifStatement.expression, tabCount + 1); + PrintNode(node->ifStatement.statementSequence, tabCount + 1); + break; + + case IfElseStatement: + PrintNode(node->ifElseStatement.ifStatement, tabCount + 1); + PrintNode(node->ifElseStatement.elseStatement, tabCount + 1); + break; + case Number: - printf("%lu", node->value.number); + printf("%lu", node->number.value); + break; + + case PrimitiveTypeNode: + printf("%s", PrimitiveTypeToString(node->primitiveType.type)); + break; + + case ReferenceTypeNode: + PrintNode(node->referenceType.type, tabCount + 1); + break; + + case Return: + 
PrintNode(node->returnStatement.expression, tabCount + 1); + break; + + case ReturnVoid: + break; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) + { + PrintNode(node->statementSequence.sequence[i], tabCount + 1); + } + break; + + case StaticModifier: + break; + + case StringLiteral: + printf("%s", node->stringLiteral.string); + break; + + case StructDeclaration: + PrintNode(node->structDeclaration.identifier, tabCount + 1); + PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); + break; + + case Type: + PrintNode(node->type.typeNode, tabCount + 1); + break; + + case UnaryExpression: + PrintUnaryOperator(node->unaryExpression.operator); + PrintNode(node->unaryExpression.child, tabCount + 1); break; } printf("\n"); } -void PrintTree(Node *node, uint32_t tabCount) -{ - uint32_t i; - PrintNode(node, tabCount); - for (i = 0; i < node->childCount; i += 1) - { - PrintTree(node->children[i], tabCount + 1); - } -} - TypeTag* MakeTypeTag(Node *node) { if (node == NULL) { fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n"); @@ -542,40 +622,40 @@ TypeTag* MakeTypeTag(Node *node) { TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag)); switch (node->syntaxKind) { case Type: - tag = MakeTypeTag(node->children[0]); + tag = MakeTypeTag(node->type.typeNode); break; case PrimitiveTypeNode: tag->type = Primitive; - tag->value.primitiveType = node->primitiveType; + tag->value.primitiveType = node->primitiveType.type; break; case ReferenceTypeNode: tag->type = Reference; - tag->value.referenceType = MakeTypeTag(node->children[0]); + tag->value.referenceType = MakeTypeTag(node->referenceType.type); break; case CustomTypeNode: tag->type = Custom; - tag->value.customType = strdup(node->value.string); + tag->value.customType = strdup(node->customType.name); break; case Declaration: - tag = MakeTypeTag(node->children[0]); + tag = MakeTypeTag(node->declaration.type); break; case StructDeclaration: tag->type = Custom; - 
tag->value.customType = strdup(node->children[0]->value.string); + tag->value.customType = strdup(node->structDeclaration.identifier->identifier.name); printf("Struct tag: %s\n", TypeTagToString(tag)); break; case FunctionDeclaration: - tag = MakeTypeTag(node->children[0]->children[1]); + tag = MakeTypeTag(node->functionDeclaration.functionSignature->functionSignature.type); break; default: - fprintf(stderr, + fprintf(stderr, "wraith: Attempted to call MakeTypeTag on" " node with unsupported SyntaxKind: %s\n", SyntaxKindString(node->syntaxKind)); @@ -605,4 +685,4 @@ char* TypeTagToString(TypeTag *tag) { case Custom: return tag->value.customType; } -} \ No newline at end of file +} diff --git a/src/ast.h b/src/ast.h index 9c1312f..ff97f3e 100644 --- a/src/ast.h +++ b/src/ast.h @@ -4,6 +4,15 @@ #include #include "identcheck.h" +/* -Wpedantic nameless union/struct silencing */ +#ifndef WRAITHNAMELESS +#ifdef __GNUC__ +#define WRAITHNAMELESS __extension__ +#else +#define WRAITHNAMELESS +#endif /* __GNUC__ */ +#endif /* WRAITHNAMELESS */ + typedef enum { AccessExpression, @@ -14,7 +23,6 @@ typedef enum CustomTypeNode, Declaration, DeclarationSequence, - Expression, ForLoop, FunctionArgumentSequence, FunctionCallExpression, @@ -92,25 +100,184 @@ typedef struct TypeTag } value; } TypeTag; -typedef struct Node +typedef struct Node Node; + +struct Node { + Node *parent; SyntaxKind syntaxKind; - struct Node **children; - uint32_t childCount; - union + WRAITHNAMELESS union { - UnaryOperator unaryOperator; - BinaryOperator binaryOperator; - } operator; - union - { - char *string; - uint64_t number; - } value; - PrimitiveType primitiveType; + struct + { + Node *accessee; + Node *accessor; + } accessExpression; + + struct + { + Node *type; + } allocExpression; + + struct + { + Node *left; + Node *right; + } assignmentStatement; + + struct + { + Node *left; + Node *right; + BinaryOperator operator; + } binaryExpression; + + struct + { + + } comment; + + struct + { + char *name; + 
} customType; + + struct + { + Node *type; + Node *identifier; + } declaration; + + struct + { + Node **sequence; + uint32_t count; + } declarationSequence; + + struct + { + Node *declaration; + Node *startNumber; + Node *endNumber; + Node *statementSequence; + } forLoop; + + struct + { + Node **sequence; + uint32_t count; + } functionArgumentSequence; + + struct + { + Node *identifier; /* FIXME: need better name */ + Node *argumentSequence; + } functionCallExpression; + + struct + { + Node *functionSignature; + Node *functionBody; + } functionDeclaration; + + struct + { + Node **sequence; + uint32_t count; + } functionModifiers; + + struct + { + Node *identifier; + Node *type; + Node *arguments; + Node *modifiers; + } functionSignature; + + struct + { + Node **sequence; + uint32_t count; + } functionSignatureArguments; + + struct + { + char *name; + } identifier; + + struct + { + Node *expression; + Node *statementSequence; + } ifStatement; + + struct + { + Node *ifStatement; + Node *elseStatement; + } ifElseStatement; + + struct + { + uint64_t value; + } number; + + struct + { + PrimitiveType type; + } primitiveType; + + struct + { + Node *type; + } referenceType; + + struct + { + Node *expression; + } returnStatement; + + struct + { + + } returnVoidStatement; + + struct + { + Node **sequence; + uint32_t count; + } statementSequence; + + struct + { + + } staticModifier; /* FIXME: modifiers should just be an enum */ + + struct + { + char *string; + } stringLiteral; + + struct + { + Node *identifier; + Node *declarationSequence; + } structDeclaration; + + struct + { + Node *typeNode; + } type; /* FIXME: this needs a refactor */ + + struct + { + Node *child; + UnaryOperator operator; + } unaryExpression; + }; TypeTag *typeTag; IdNode *idLink; -} Node; +}; const char* SyntaxKindString(SyntaxKind syntaxKind); @@ -223,7 +390,7 @@ Node* MakeIfNode( ); Node* MakeIfElseNode( Node *ifNode, - Node *statementSequenceNode + Node *elseNode /* can be a conditional or a 
statement sequence */ ); Node* MakeForLoopNode( Node *identifierNode, @@ -232,7 +399,7 @@ Node* MakeForLoopNode( Node *statementSequenceNode ); -void PrintTree(Node *node, uint32_t tabCount); +void PrintNode(Node *node, uint32_t tabCount); const char* SyntaxKindString(SyntaxKind syntaxKind); TypeTag* MakeTypeTag(Node *node); diff --git a/src/codegen.c b/src/codegen.c index 3e75d1b..87f9850 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -263,7 +263,7 @@ static void AddStructDeclaration( for (i = 0; i < fieldDeclarationCount; i += 1) { structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1)); - structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string); + structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->declaration.identifier->identifier.name); structTypeDeclarations[index].fields[i].index = i; structTypeDeclarations[index].fieldCount += 1; } @@ -319,16 +319,15 @@ static LLVMTypeRef ResolveType(Node* typeNode) { if (IsPrimitiveType(typeNode)) { - return WraithTypeToLLVMType(typeNode->children[0]->primitiveType); + return WraithTypeToLLVMType(typeNode->type.typeNode->primitiveType.type); } - else if (typeNode->children[0]->syntaxKind == CustomTypeNode) + else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode) { - char *typeName = typeNode->children[0]->value.string; - return LookupCustomType(typeName); + return LookupCustomType(typeNode->type.typeNode->customType.name); } - else if (typeNode->children[0]->syntaxKind == ReferenceTypeNode) + else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode) { - return LLVMPointerType(ResolveType(typeNode->children[0]->children[0]), 0); + return LLVMPointerType(ResolveType(typeNode->type.typeNode->referenceType.type), 0); } else { @@ -443,24 +442,24 @@ static LLVMValueRef CompileExpression( static LLVMValueRef CompileNumber( Node 
*numberExpression ) { - return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); + return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0); } static LLVMValueRef CompileString( LLVMBuilderRef builder, Node *stringExpression ) { - return LLVMBuildGlobalStringPtr(builder, stringExpression->value.string, "stringConstant"); + return LLVMBuildGlobalStringPtr(builder, stringExpression->stringLiteral.string, "stringConstant"); } static LLVMValueRef CompileBinaryExpression( LLVMBuilderRef builder, Node *binaryExpression ) { - LLVMValueRef left = CompileExpression(builder, binaryExpression->children[0]); - LLVMValueRef right = CompileExpression(builder, binaryExpression->children[1]); + LLVMValueRef left = CompileExpression(builder, binaryExpression->binaryExpression.left); + LLVMValueRef right = CompileExpression(builder, binaryExpression->binaryExpression.right); - switch (binaryExpression->operator.binaryOperator) + switch (binaryExpression->binaryExpression.operator) { case Add: return LLVMBuildAdd(builder, left, right, "addResult"); @@ -494,11 +493,11 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( LLVMBuilderRef builder, - Node *expression + Node *functionCallExpression ) { uint32_t i; uint32_t argumentCount = 0; - LLVMValueRef args[expression->children[1]->childCount + 1]; + LLVMValueRef args[functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count + 1]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; @@ -506,25 +505,26 @@ static LLVMValueRef CompileFunctionCallExpression( char *returnName = ""; /* FIXME: this needs to be recursive on access chains */ - if (expression->children[0]->syntaxKind == AccessExpression) + /* FIXME: this needs to be able to call same-struct functions implicitly */ + if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression) { 
LLVMTypeRef typeReference = FindStructType( - expression->children[0]->children[0]->value.string + functionCallExpression->functionCallExpression.identifier->accessExpression.accessee->identifier.name /* identifier is an AccessExpression in this branch: the name lives on its accessee */ ); if (typeReference != NULL) { function = LookupFunctionByType( typeReference, - expression->children[0]->children[1]->value.string, + functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name, &functionReturnType, &isStatic ); } else { - structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); - function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); + structInstance = FindVariablePointer(functionCallExpression->functionCallExpression.identifier->accessExpression.accessee->identifier.name); + function = LookupFunctionByInstance(structInstance, functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name, &functionReturnType, &isStatic); } } else @@ -539,9 +539,9 @@ static LLVMValueRef CompileFunctionCallExpression( argumentCount += 1; } - for (i = 0; i < expression->children[1]->childCount; i += 1) + for (i = 0; i < functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count; i += 1) { - args[argumentCount] = CompileExpression(builder, expression->children[1]->children[i]); + args[argumentCount] = CompileExpression(builder, functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.sequence[i]); argumentCount += 1; } @@ -555,30 +555,26 @@ static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileAccessExpressionForStore( LLVMBuilderRef builder, - Node *expression + Node *accessExpression ) { - Node *accessee = expression->children[0]; - Node *accessor = expression->children[1]; - LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string); - return FindStructFieldPointer(builder, accesseeValue, 
accessor->value.string); + LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name); + return FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name); } static LLVMValueRef CompileAccessExpression( LLVMBuilderRef builder, - Node *expression + Node *accessExpression ) { - Node *accessee = expression->children[0]; - Node *accessor = expression->children[1]; - LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string); - LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessor->value.string); - return LLVMBuildLoad(builder, access, accessor->value.string); + LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name); + LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name); + return LLVMBuildLoad(builder, access, accessExpression->accessExpression.accessor->identifier.name); } static LLVMValueRef CompileAllocExpression( LLVMBuilderRef builder, - Node *expression + Node *allocExpression ) { - LLVMTypeRef type = ResolveType(expression->children[0]); + LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type); return LLVMBuildMalloc(builder, type, "allocation"); } @@ -601,7 +597,7 @@ static LLVMValueRef CompileExpression( return CompileFunctionCallExpression(builder, expression); case Identifier: - return FindVariableValue(builder, expression->value.string); + return FindVariableValue(builder, expression->identifier.name); case Number: return CompileNumber(expression); @@ -619,7 +615,7 @@ static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef f static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { - LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); + LLVMValueRef expression = 
CompileExpression(builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); return LLVMGetLastBasicBlock(function); } @@ -634,11 +630,11 @@ static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration) { LLVMValueRef variable; - char *variableName = variableDeclaration->children[1]->value.string; + char *variableName = variableDeclaration->declaration.identifier->identifier.name; char *ptrName = strdup(variableName); strcat(ptrName, "_ptr"); - variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->children[0]), ptrName); + variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->declaration.type), ptrName); free(ptrName); @@ -649,19 +645,19 @@ static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, L static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { - LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]); + LLVMValueRef result = CompileExpression(builder, assignmentStatement->assignmentStatement.right); LLVMValueRef identifier; - if (assignmentStatement->children[0]->syntaxKind == AccessExpression) + if (assignmentStatement->assignmentStatement.left->syntaxKind == AccessExpression) { - identifier = CompileAccessExpressionForStore(builder, assignmentStatement->children[0]); + identifier = CompileAccessExpressionForStore(builder, assignmentStatement->assignmentStatement.left); } - else if (assignmentStatement->children[0]->syntaxKind == Identifier) + else if (assignmentStatement->assignmentStatement.left->syntaxKind == Identifier) { - identifier = FindVariablePointer(assignmentStatement->children[0]->value.string); + identifier = FindVariablePointer(assignmentStatement->assignmentStatement.left->identifier.name); } - else if 
(assignmentStatement->children[0]->syntaxKind == Declaration) + else if (assignmentStatement->assignmentStatement.left->syntaxKind == Declaration) { - identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->children[0]); + identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->assignmentStatement.left); } else { @@ -677,7 +673,7 @@ static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; - LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]); + LLVMValueRef conditional = CompileExpression(builder, ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -686,9 +682,9 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef LLVMPositionBuilderAtEnd(builder, block); - for (i = 0; i < ifStatement->children[1]->childCount; i += 1) + for (i = 0; i < ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1) { - CompileStatement(builder, function, ifStatement->children[1]->children[i]); + CompileStatement(builder, function, ifStatement->ifStatement.statementSequence->statementSequence.sequence[i]); } LLVMBuildBr(builder, afterCond); @@ -700,7 +696,7 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; - LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]); + LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); LLVMBasicBlockRef ifBlock = LLVMAppendBasicBlock(function, "ifBlock"); 
LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlock(function, "elseBlock"); @@ -710,25 +706,25 @@ static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValu LLVMPositionBuilderAtEnd(builder, ifBlock); - for (i = 0; i < ifElseStatement->children[0]->children[1]->childCount; i += 1) + for (i = 0; i < ifElseStatement->ifElseStatement.ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1) { - CompileStatement(builder, function, ifElseStatement->children[0]->children[1]->children[i]); + CompileStatement(builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement.statementSequence->statementSequence.sequence[i]); /* if-branch body lives on the nested ifStatement node */ } LLVMBuildBr(builder, afterCond); LLVMPositionBuilderAtEnd(builder, elseBlock); - if (ifElseStatement->children[1]->syntaxKind == StatementSequence) + if (ifElseStatement->ifElseStatement.elseStatement->syntaxKind == StatementSequence) { - for (i = 0; i < ifElseStatement->children[1]->childCount; i += 1) + for (i = 0; i < ifElseStatement->ifElseStatement.elseStatement->statementSequence.count; i += 1) { - CompileStatement(builder, function, ifElseStatement->children[1]->children[i]); + CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement->statementSequence.sequence[i]); } } else { - CompileStatement(builder, function, ifElseStatement->children[1]); + CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement); } LLVMBuildBr(builder, afterCond); @@ -744,8 +740,8 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck"); LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody"); LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop"); - char *iteratorVariableName = forLoopStatement->children[0]->children[1]->value.string; - LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->children[0]->children[0]); + char *iteratorVariableName 
= forLoopStatement->forLoop.declaration->declaration.identifier->identifier.name; + LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->forLoop.declaration->declaration.type); PushScopeFrame(scope); @@ -762,13 +758,13 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal LLVMValueRef nextValue = LLVMBuildAdd( builder, iteratorValue, - LLVMConstInt(iteratorVariableType, forLoopStatement->children[1]->value.number, 0), + LLVMConstInt(iteratorVariableType, 1, 0), /* FIXME: add custom increment value */ "next" ); LLVMPositionBuilderAtEnd(builder, checkBlock); - LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]); + LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->forLoop.endNumber); LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare"); LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock); @@ -776,9 +772,9 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal LLVMPositionBuilderAtEnd(builder, bodyBlock); LLVMBasicBlockRef lastBlock; - for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1) + for (i = 0; i < forLoopStatement->forLoop.statementSequence->statementSequence.count; i += 1) { - lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]); + lastBlock = CompileStatement(builder, function, forLoopStatement->forLoop.statementSequence->statementSequence.sequence[i]); } LLVMBuildBr(builder, checkBlock); @@ -786,7 +782,7 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock)); LLVMValueRef incomingValues[2]; - incomingValues[0] = CompileNumber(forLoopStatement->children[1]); + incomingValues[0] = CompileNumber(forLoopStatement->forLoop.startNumber); incomingValues[1] = nextValue; LLVMBasicBlockRef incomingBlocks[2]; @@ -848,17 +844,17 @@ static 
void CompileFunction( uint32_t i; uint8_t hasReturn = 0; uint8_t isStatic = 0; - Node *functionSignature = functionDeclaration->children[0]; - Node *functionBody = functionDeclaration->children[1]; - uint32_t argumentCount = functionSignature->children[2]->childCount; + Node *functionSignature = functionDeclaration->functionDeclaration.functionSignature; + Node *functionBody = functionDeclaration->functionDeclaration.functionBody; + uint32_t argumentCount = functionSignature->functionSignature.arguments->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; - if (functionSignature->children[3]->childCount > 0) + if (functionSignature->functionSignature.modifiers->functionModifiers.count > 0) { - for (i = 0; i < functionSignature->children[3]->childCount; i += 1) + for (i = 0; i < functionSignature->functionSignature.modifiers->functionModifiers.count; i += 1) { - if (functionSignature->children[3]->children[i]->syntaxKind == StaticModifier) + if (functionSignature->functionSignature.modifiers->functionModifiers.sequence[i]->syntaxKind == StaticModifier) { isStatic = 1; break; @@ -875,22 +871,22 @@ static void CompileFunction( PushScopeFrame(scope); /* FIXME: should work for non-primitive types */ - for (i = 0; i < functionSignature->children[2]->childCount; i += 1) + for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1) { - paramTypes[paramIndex] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->children[0]->primitiveType); + paramTypes[paramIndex] = ResolveType(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->declaration.type); paramIndex += 1; } - LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->children[0]->primitiveType); + LLVMTypeRef returnType = ResolveType(functionSignature->functionSignature.type); LLVMTypeRef functionType = LLVMFunctionType(returnType, 
paramTypes, paramIndex, 0); char *functionName = strdup(parentStructName); strcat(functionName, "_"); - strcat(functionName, functionSignature->children[0]->value.string); + strcat(functionName, functionSignature->functionSignature.identifier->identifier.name); LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); free(functionName); - DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->children[0]->value.string); + DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->functionSignature.identifier->identifier.name); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBuilderRef builder = LLVMCreateBuilder(); @@ -902,20 +898,20 @@ static void CompileFunction( AddStructVariablesToScope(builder, wStructPointer); } - for (i = 0; i < functionSignature->children[2]->childCount; i += 1) + for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1) { - char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string); + char *ptrName = strdup(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->declaration.identifier->identifier.name); /* each argument is a Declaration node; its name is on the declaration's identifier */ strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMBuildStore(builder, argument, argumentCopy); free(ptrName); - AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string); + AddLocalVariable(scope, argumentCopy, NULL, functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->declaration.identifier->identifier.name); } - for (i = 0; i < functionBody->childCount; i += 1) + for (i = 0; i < functionBody->statementSequence.count; i += 1) { - CompileStatement(builder, function, functionBody->children[i]); + CompileStatement(builder, function, 
functionBody->statementSequence.sequence[i]); } hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; @@ -938,12 +934,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no { uint32_t i; uint32_t fieldCount = 0; - uint32_t declarationCount = node->children[1]->childCount; + uint32_t declarationCount = node->structDeclaration.declarationSequence->declarationSequence.count; uint8_t packed = 1; LLVMTypeRef types[declarationCount]; Node *currentDeclarationNode; Node *fieldDeclarations[declarationCount]; - char *structName = node->children[0]->value.string; + char *structName = node->structDeclaration.identifier->identifier.name; PushScopeFrame(scope); @@ -953,12 +949,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) { - currentDeclarationNode = node->children[1]->children[i]; + currentDeclarationNode = node->structDeclaration.declarationSequence->declarationSequence.sequence[i]; switch (currentDeclarationNode->syntaxKind) { case Declaration: /* this is badly named */ - types[fieldCount] = ResolveType(currentDeclarationNode->children[0]); + types[fieldCount] = ResolveType(currentDeclarationNode->declaration.type); fieldDeclarations[fieldCount] = currentDeclarationNode; fieldCount += 1; break; @@ -966,12 +962,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no } LLVMStructSetBody(wStructType, types, fieldCount, packed); - AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount); + AddStructDeclaration(wStructType, wStructPointerType, structName, fieldDeclarations, fieldCount); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) { - currentDeclarationNode = node->children[1]->children[i]; + currentDeclarationNode = 
node->structDeclaration.declarationSequence->declarationSequence.sequence[i]; switch (currentDeclarationNode->syntaxKind) { @@ -984,15 +980,15 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no PopScopeFrame(scope); } -static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) +static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *declarationSequenceNode) { uint32_t i; - for (i = 0; i < node->childCount; i += 1) + for (i = 0; i < declarationSequenceNode->declarationSequence.count; i += 1) { - if (node->children[i]->syntaxKind == StructDeclaration) + if (declarationSequenceNode->declarationSequence.sequence[i]->syntaxKind == StructDeclaration) { - CompileStruct(module, context, node->children[i]); + CompileStruct(module, context, declarationSequenceNode->declarationSequence.sequence[i]); } else { diff --git a/src/main.c b/src/main.c index f58f618..9ffb210 100644 --- a/src/main.c +++ b/src/main.c @@ -69,7 +69,7 @@ int main(int argc, char *argv[]) IdNode *idTree = MakeIdTree(rootNode, NULL); PrintIdTree(idTree, /*tabCount=*/0); printf("\n"); - PrintTree(rootNode, /*tabCount=*/0); + PrintNode(rootNode, /*tabCount=*/0); } exitCode = Codegen(rootNode, optimizationLevel); } diff --git a/src/parser.c b/src/parser.c index 434ed7a..c7a5f34 100644 --- a/src/parser.c +++ b/src/parser.c @@ -31,7 +31,7 @@ int Parse(char *inputFilename, Node **pRootNode, uint8_t parseVerbose) { if (parseVerbose) { - PrintTree(*pRootNode, 0); + PrintNode(*pRootNode, 0); } } else if (result == 1) -- 2.25.1 From 459a1dd3b74fafb633e8af35f788c56c6f73266d Mon Sep 17 00:00:00 2001 From: venko Date: Sat, 15 May 2021 19:00:46 -0700 Subject: [PATCH 2/2] Refactors identcheck for new AST. Fixes newline bugs in PrintNode. 
--- src/ast.c | 98 ++++++++++++--------- src/identcheck.c | 215 +++++++++++++++++++++++++++++------------------ 2 files changed, 191 insertions(+), 122 deletions(-) diff --git a/src/ast.c b/src/ast.c index 942ff74..b3517d7 100644 --- a/src/ast.c +++ b/src/ast.c @@ -438,15 +438,15 @@ static void PrintBinaryOperator(BinaryOperator operator) switch (operator) { case Add: - printf("+"); + printf("(+)"); break; case Subtract: - printf("-"); + printf("(-)"); break; case Multiply: - printf("*"); + printf("(*)"); break; } } @@ -463,154 +463,173 @@ void PrintNode(Node *node, uint32_t tabCount) switch (node->syntaxKind) { case AccessExpression: + printf("\n"); PrintNode(node->accessExpression.accessee, tabCount + 1); PrintNode(node->accessExpression.accessor, tabCount + 1); - break; + return; case AllocExpression: + printf("\n"); PrintNode(node->allocExpression.type, tabCount + 1); - break; + return; case Assignment: + printf("\n"); PrintNode(node->assignmentStatement.left, tabCount + 1); PrintNode(node->assignmentStatement.right, tabCount + 1); - break; + return; case BinaryExpression: - PrintNode(node->binaryExpression.left, tabCount + 1); PrintBinaryOperator(node->binaryExpression.operator); + printf("\n"); + PrintNode(node->binaryExpression.left, tabCount + 1); PrintNode(node->binaryExpression.right, tabCount + 1); - break; + return; case CustomTypeNode: - printf("%s", node->customType.name); - break; + printf("%s\n", node->customType.name); + return; case Declaration: + printf("\n"); PrintNode(node->declaration.identifier, tabCount + 1); PrintNode(node->declaration.type, tabCount + 1); - break; + return; case DeclarationSequence: + printf("\n"); for (i = 0; i < node->declarationSequence.count; i += 1) { PrintNode(node->declarationSequence.sequence[i], tabCount + 1); } - break; + return; case ForLoop: + printf("\n"); PrintNode(node->forLoop.declaration, tabCount + 1); PrintNode(node->forLoop.startNumber, tabCount + 1); PrintNode(node->forLoop.endNumber, tabCount + 1); 
PrintNode(node->forLoop.statementSequence, tabCount + 1); - break; + return; case FunctionArgumentSequence: + printf("\n"); for (i = 0; i < node->functionArgumentSequence.count; i += 1) { PrintNode(node->functionArgumentSequence.sequence[i], tabCount + 1); } - break; + return; case FunctionCallExpression: + printf("\n"); PrintNode(node->functionCallExpression.identifier, tabCount + 1); PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1); - break; + return; case FunctionDeclaration: + printf("\n"); PrintNode(node->functionDeclaration.functionSignature, tabCount + 1); PrintNode(node->functionDeclaration.functionBody, tabCount + 1); - break; + return; case FunctionModifiers: + printf("\n"); for (i = 0; i < node->functionModifiers.count; i += 1) { PrintNode(node->functionModifiers.sequence[i], tabCount + 1); } - break; + return; case FunctionSignature: + printf("\n"); PrintNode(node->functionSignature.identifier, tabCount + 1); PrintNode(node->functionSignature.arguments, tabCount + 1); PrintNode(node->functionSignature.type, tabCount + 1); PrintNode(node->functionSignature.modifiers, tabCount + 1); - break; + return; case FunctionSignatureArguments: + printf("\n"); for (i = 0; i < node->functionSignatureArguments.count; i += 1) { PrintNode(node->functionSignatureArguments.sequence[i], tabCount + 1); } - break; + return; case Identifier: if (node->typeTag == NULL) { - printf("%s", node->identifier.name); + printf("%s\n", node->identifier.name); } else { char *type = TypeTagToString(node->typeTag); - printf("%s<%s>", node->identifier.name, type); + printf("%s<%s>\n", node->identifier.name, type); } - break; + return; case IfStatement: + printf("\n"); PrintNode(node->ifStatement.expression, tabCount + 1); PrintNode(node->ifStatement.statementSequence, tabCount + 1); - break; + return; case IfElseStatement: + printf("\n"); PrintNode(node->ifElseStatement.ifStatement, tabCount + 1); PrintNode(node->ifElseStatement.elseStatement, tabCount + 1); - break; + 
return; case Number: - printf("%lu", node->number.value); - break; + printf("%lu\n", node->number.value); + return; case PrimitiveTypeNode: - printf("%s", PrimitiveTypeToString(node->primitiveType.type)); - break; + printf("%s\n", PrimitiveTypeToString(node->primitiveType.type)); + return; case ReferenceTypeNode: + printf("\n"); PrintNode(node->referenceType.type, tabCount + 1); - break; + return; case Return: + printf("\n"); PrintNode(node->returnStatement.expression, tabCount + 1); - break; + return; case ReturnVoid: - break; + return; case StatementSequence: + printf("\n"); for (i = 0; i < node->statementSequence.count; i += 1) { PrintNode(node->statementSequence.sequence[i], tabCount + 1); } - break; + return; case StaticModifier: - break; + printf("\n"); + return; case StringLiteral: printf("%s", node->stringLiteral.string); - break; + return; case StructDeclaration: + printf("\n"); PrintNode(node->structDeclaration.identifier, tabCount + 1); PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); - break; + return; case Type: + printf("\n"); PrintNode(node->type.typeNode, tabCount + 1); - break; + return; case UnaryExpression: PrintUnaryOperator(node->unaryExpression.operator); PrintNode(node->unaryExpression.child, tabCount + 1); - break; + return; } - - printf("\n"); } TypeTag* MakeTypeTag(Node *node) { @@ -647,7 +666,6 @@ TypeTag* MakeTypeTag(Node *node) { case StructDeclaration: tag->type = Custom; tag->value.customType = strdup(node->structDeclaration.identifier->identifier.name); - printf("Struct tag: %s\n", TypeTagToString(tag)); break; case FunctionDeclaration: diff --git a/src/identcheck.c b/src/identcheck.c index 6fe0ec4..ab50c44 100644 --- a/src/identcheck.c +++ b/src/identcheck.c @@ -38,104 +38,95 @@ IdNode* MakeIdTree(Node *astNode, IdNode *parent) { uint32_t i; IdNode *mainNode; switch (astNode->syntaxKind) { + case AccessExpression: + AddChildToNode(parent, MakeIdTree(astNode->accessExpression.accessee, parent)); + 
AddChildToNode(parent, MakeIdTree(astNode->accessExpression.accessor, parent)); + return NULL; + + case AllocExpression: + AddChildToNode(parent, MakeIdTree(astNode->allocExpression.type, parent)); + return NULL; + case Assignment: { - if (astNode->children[0]->syntaxKind == Declaration) { - return MakeIdTree(astNode->children[0], parent); + if (astNode->assignmentStatement.left->syntaxKind == Declaration) { + return MakeIdTree(astNode->assignmentStatement.left, parent); } else { - for (i = 0; i < astNode->childCount; i++) { - AddChildToNode(parent, MakeIdTree(astNode->children[i], parent)); - } + AddChildToNode(parent, MakeIdTree(astNode->assignmentStatement.left, parent)); + AddChildToNode(parent, MakeIdTree(astNode->assignmentStatement.right, parent)); return NULL; } } - case IfStatement: { - mainNode = MakeIdNode(OrderedScope, "if", parent); - Node *clause = astNode->children[0]; - Node *stmtSeq = astNode->children[1]; - for (i = 0; i < clause->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(clause->children[i], mainNode)); - } - for (i = 0; i < stmtSeq->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(stmtSeq->children[i], mainNode)); - } - break; - } - - case IfElseStatement: { - Node *ifNode = astNode->children[0]; - Node *elseStmts = astNode->children[1]; - mainNode = MakeIdNode(OrderedScope, "if-else", parent); - IdNode *ifBranch = MakeIdTree(ifNode, mainNode); - IdNode *elseBranch = MakeIdNode(OrderedScope, "else", mainNode); - - AddChildToNode(mainNode, ifBranch); - for (i = 0; i < elseStmts->childCount; i++) { - AddChildToNode(elseBranch, MakeIdTree(elseStmts->children[i], elseBranch)); - } - AddChildToNode(mainNode, elseBranch); - break; - } - - case ForLoop: { - Node *loopDecl = astNode->children[0]; - Node *loopBody = astNode->children[3]; - mainNode = MakeIdNode(OrderedScope, "for-loop", parent); - AddChildToNode(mainNode, MakeIdTree(loopDecl, mainNode)); - for (i = 0; i < loopBody->childCount; i++) { - AddChildToNode(mainNode, 
MakeIdTree(loopBody->children[i], mainNode)); - } - break; - } + case BinaryExpression: + AddChildToNode(parent, MakeIdTree(astNode->binaryExpression.left, parent)); + AddChildToNode(parent, MakeIdTree(astNode->binaryExpression.right, parent)); + return NULL; case Declaration: { - mainNode = MakeIdNode(Variable, astNode->children[1]->value.string, parent); + Node *idNode = astNode->declaration.identifier; + mainNode = MakeIdNode(Variable, idNode->identifier.name, parent); mainNode->typeTag = MakeTypeTag(astNode); - astNode->children[1]->typeTag = mainNode->typeTag; - break; - } - - case StructDeclaration: { - Node *idNode = astNode->children[0]; - Node *declsNode = astNode->children[1]; - mainNode = MakeIdNode(Struct, idNode->value.string, parent); - mainNode->typeTag = MakeTypeTag(astNode); - for (i = 0; i < declsNode->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(declsNode->children[i], mainNode)); - } - break; - } - - case FunctionDeclaration: { - Node *sigNode = astNode->children[0]; - Node *funcNameNode = sigNode->children[0]; - Node *funcArgsNode = sigNode->children[2]; - Node *bodyStatementsNode = astNode->children[1]; - mainNode = MakeIdNode(Function, funcNameNode->value.string, parent); - mainNode->typeTag = MakeTypeTag(astNode); - astNode->children[0]->children[0]->typeTag = mainNode->typeTag; - for (i = 0; i < funcArgsNode->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(funcArgsNode->children[i], mainNode)); - } - for (i = 0; i < bodyStatementsNode->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(bodyStatementsNode->children[i], mainNode)); - } + idNode->typeTag = mainNode->typeTag; break; } case DeclarationSequence: { mainNode = MakeIdNode(UnorderedScope, "", parent); - for (i = 0; i < astNode->childCount; i++) { - AddChildToNode(mainNode, MakeIdTree(astNode->children[i], mainNode)); + for (i = 0; i < astNode->declarationSequence.count; i++) { + AddChildToNode( + mainNode, 
MakeIdTree(astNode->declarationSequence.sequence[i], mainNode)); } break; } + case ForLoop: { + Node *loopDecl = astNode->forLoop.declaration; + Node *loopBody = astNode->forLoop.statementSequence; + mainNode = MakeIdNode(OrderedScope, "for-loop", parent); + AddChildToNode(mainNode, MakeIdTree(loopDecl, mainNode)); + AddChildToNode(mainNode, MakeIdTree(loopBody, mainNode)); + break; + } + + case FunctionArgumentSequence: + for (i = 0; i < astNode->functionArgumentSequence.count; i++) { + AddChildToNode( + parent, MakeIdTree(astNode->functionArgumentSequence.sequence[i], parent)); + } + return NULL; + + case FunctionCallExpression: + AddChildToNode(parent, MakeIdTree(astNode->functionCallExpression.identifier, parent)); + AddChildToNode( + parent, MakeIdTree(astNode->functionCallExpression.argumentSequence, parent)); + return NULL; + + case FunctionDeclaration: { + Node *sigNode = astNode->functionDeclaration.functionSignature; + Node *idNode = sigNode->functionSignature.identifier; + char *funcName = idNode->identifier.name; + mainNode = MakeIdNode(Function, funcName, parent); + mainNode->typeTag = MakeTypeTag(astNode); + idNode->typeTag = mainNode->typeTag; + MakeIdTree(sigNode->functionSignature.arguments, mainNode); + MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); + break; + } + + case FunctionSignatureArguments: { + for (i = 0; i < astNode->functionSignatureArguments.count; i++) { + Node *argNode = astNode->functionSignatureArguments.sequence[i]; + AddChildToNode(parent, MakeIdTree(argNode, parent)); + } + return NULL; + } + case Identifier: { - mainNode = MakeIdNode(Placeholder, astNode->value.string, parent); - IdNode *lookupNode = LookupId(mainNode, NULL, astNode->value.string); + char *name = astNode->identifier.name; + mainNode = MakeIdNode(Placeholder, name, parent); + IdNode *lookupNode = LookupId(mainNode, NULL, name); if (lookupNode == NULL) { - fprintf(stderr, "wraith: Could not find IdNode for id %s\n", astNode->value.string); + 
fprintf(stderr, "wraith: Could not find IdNode for id %s\n", name); TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag)); tag->type = Unknown; astNode->typeTag = tag; @@ -145,12 +136,73 @@ IdNode* MakeIdTree(Node *astNode, IdNode *parent) { break; } - default: { - for (i = 0; i < astNode->childCount; i++) { - AddChildToNode(parent, MakeIdTree(astNode->children[i], parent)); + case IfStatement: { + Node *clause = astNode->ifStatement.expression; + Node *stmtSeq = astNode->ifStatement.statementSequence; + mainNode = MakeIdNode(OrderedScope, "if", parent); + MakeIdTree(clause, mainNode); + MakeIdTree(stmtSeq, mainNode); + break; + } + + case IfElseStatement: { + Node *ifNode = astNode->ifElseStatement.ifStatement; + Node *elseStmts = astNode->ifElseStatement.elseStatement; + mainNode = MakeIdNode(OrderedScope, "if-else", parent); + IdNode *ifBranch = MakeIdTree(ifNode, mainNode); + AddChildToNode(mainNode, ifBranch); + IdNode *elseScope = MakeIdNode(OrderedScope, "else", mainNode); + MakeIdTree(elseStmts, elseScope); + AddChildToNode(mainNode, elseScope); + break; + } + + case ReferenceTypeNode: + AddChildToNode(parent, MakeIdTree(astNode->referenceType.type, parent)); + return NULL; + + case Return: + AddChildToNode(parent, MakeIdTree(astNode->returnStatement.expression, parent)); + return NULL; + + case StatementSequence: { + for (i = 0; i < astNode->statementSequence.count; i++) { + Node *argNode = astNode->statementSequence.sequence[i]; + AddChildToNode(parent, MakeIdTree(argNode, parent)); } return NULL; } + + case StructDeclaration: { + Node *idNode = astNode->structDeclaration.identifier; + Node *declsNode = astNode->structDeclaration.declarationSequence; + mainNode = MakeIdNode(Struct, idNode->identifier.name, parent); + mainNode->typeTag = MakeTypeTag(astNode); + for (i = 0; i < declsNode->declarationSequence.count; i++) { + Node *decl = declsNode->declarationSequence.sequence[i]; + AddChildToNode(mainNode, MakeIdTree(decl, mainNode)); + } + break; + } + + case 
Type: + AddChildToNode(parent, MakeIdTree(astNode->type.typeNode, parent)); + return NULL; + + case UnaryExpression: + AddChildToNode(parent, MakeIdTree(astNode->unaryExpression.child, parent)); + return NULL; + + case Comment: + case CustomTypeNode: + case FunctionModifiers: + case FunctionSignature: + case Number: + case PrimitiveTypeNode: + case ReturnVoid: + case StaticModifier: + case StringLiteral: + return NULL; } astNode->idLink = mainNode; @@ -203,7 +255,6 @@ void PrintIdTree(IdNode *tree, uint32_t tabCount) { } } - int PrintAncestors(IdNode *node) { if (node == NULL) return -1; -- 2.25.1