refactor AST to use nameless union instead of child array

ast_refactor
cosmonaut 2021-05-15 15:34:15 -07:00
parent 41bf2bece8
commit abc82f381e
5 changed files with 494 additions and 251 deletions

364
src/ast.c
View File

@ -16,7 +16,6 @@ const char* SyntaxKindString(SyntaxKind syntaxKind)
case Comment: return "Comment";
case CustomTypeNode: return "CustomTypeNode";
case Declaration: return "Declaration";
case Expression: return "Expression";
case ForLoop: return "ForLoop";
case DeclarationSequence: return "DeclarationSequence";
case FunctionArgumentSequence: return "FunctionArgumentSequence";
@ -45,7 +44,7 @@ const char* SyntaxKindString(SyntaxKind syntaxKind)
uint8_t IsPrimitiveType(
Node *typeNode
) {
return typeNode->children[0]->syntaxKind == PrimitiveTypeNode;
return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode;
}
Node* MakePrimitiveTypeNode(
@ -53,8 +52,7 @@ Node* MakePrimitiveTypeNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = PrimitiveTypeNode;
node->primitiveType = type;
node->childCount = 0;
node->primitiveType.type = type;
return node;
}
@ -63,8 +61,7 @@ Node* MakeCustomTypeNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = CustomTypeNode;
node->value.string = strdup(name);
node->childCount = 0;
node->customType.name = strdup(name);
return node;
}
@ -73,9 +70,7 @@ Node* MakeReferenceTypeNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ReferenceTypeNode;
node->childCount = 1;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
node->referenceType.type = typeNode;
return node;
}
@ -84,9 +79,7 @@ Node* MakeTypeNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Type;
node->childCount = 1;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
node->type.typeNode = typeNode;
return node;
}
@ -95,8 +88,7 @@ Node* MakeIdentifierNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Identifier;
node->value.string = strdup(id);
node->childCount = 0;
node->identifier.name = strdup(id);
node->typeTag = NULL;
return node;
}
@ -107,8 +99,7 @@ Node* MakeNumberNode(
char *ptr;
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Number;
node->value.number = strtoul(numberString, &ptr, 10);
node->childCount = 0;
node->number.value = strtoul(numberString, &ptr, 10);
return node;
}
@ -118,8 +109,7 @@ Node* MakeStringNode(
size_t slen = strlen(string);
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StringLiteral;
node->value.string = strndup(string + 1, slen - 2);
node->childCount = 0;
node->stringLiteral.string = strndup(string + 1, slen - 2);
return node;
}
@ -127,10 +117,10 @@ Node* MakeStaticNode()
{
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StaticModifier;
node->childCount = 0;
return node;
}
/* FIXME: this sucks */
Node* MakeFunctionModifiersNode(
Node **pModifierNodes,
uint32_t modifierCount
@ -138,13 +128,14 @@ Node* MakeFunctionModifiersNode(
uint32_t i;
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionModifiers;
node->childCount = modifierCount;
node->functionModifiers.count = modifierCount;
node->functionModifiers.sequence = NULL;
if (modifierCount > 0)
{
node->children = malloc(sizeof(Node*) * node->childCount);
node->functionModifiers.sequence = malloc(sizeof(Node*) * node->functionModifiers.count);
for (i = 0; i < modifierCount; i += 1)
{
node->children[i] = pModifierNodes[i];
node->functionModifiers.sequence[i] = pModifierNodes[i];
}
}
@ -157,10 +148,8 @@ Node* MakeUnaryNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = UnaryExpression;
node->operator.unaryOperator = operator;
node->children = malloc(sizeof(Node*));
node->children[0] = child;
node->childCount = 1;
node->unaryExpression.operator = operator;
node->unaryExpression.child = child;
return node;
}
@ -171,11 +160,9 @@ Node* MakeBinaryNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = BinaryExpression;
node->operator.binaryOperator = operator;
node->children = malloc(sizeof(Node*) * 2);
node->children[0] = left;
node->children[1] = right;
node->childCount = 2;
node->binaryExpression.left = left;
node->binaryExpression.right = right;
node->binaryExpression.operator = operator;
return node;
}
@ -185,10 +172,8 @@ Node* MakeDeclarationNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Declaration;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->childCount = 2;
node->children[0] = typeNode;
node->children[1] = identifierNode;
node->declaration.type = typeNode;
node->declaration.identifier = identifierNode;
return node;
}
@ -198,10 +183,8 @@ Node* MakeAssignmentNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Assignment;
node->childCount = 2;
node->children = malloc(sizeof(Node*) * 2);
node->children[0] = left;
node->children[1] = right;
node->assignmentStatement.left = left;
node->assignmentStatement.right = right;
return node;
}
@ -210,9 +193,9 @@ Node* StartStatementSequenceNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StatementSequence;
node->children = (Node**) malloc(sizeof(Node*));
node->childCount = 1;
node->children[0] = statementNode;
node->statementSequence.sequence = (Node**) malloc(sizeof(Node*));
node->statementSequence.sequence[0] = statementNode;
node->statementSequence.count = 1;
return node;
}
@ -220,9 +203,9 @@ Node* AddStatement(
Node* statementSequenceNode,
Node *statementNode
) {
statementSequenceNode->children = realloc(statementSequenceNode->children, sizeof(Node*) * (statementSequenceNode->childCount + 1));
statementSequenceNode->children[statementSequenceNode->childCount] = statementNode;
statementSequenceNode->childCount += 1;
statementSequenceNode->statementSequence.sequence = realloc(statementSequenceNode->statementSequence.sequence, sizeof(Node*) * (statementSequenceNode->statementSequence.count + 1));
statementSequenceNode->statementSequence.sequence[statementSequenceNode->statementSequence.count] = statementNode;
statementSequenceNode->statementSequence.count += 1;
return statementSequenceNode;
}
@ -231,9 +214,7 @@ Node* MakeReturnStatementNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Return;
node->children = (Node**) malloc(sizeof(Node*));
node->childCount = 1;
node->children[0] = expressionNode;
node->returnStatement.expression = expressionNode;
return node;
}
@ -241,8 +222,6 @@ Node* MakeReturnVoidStatementNode()
{
Node *node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ReturnVoid;
node->childCount = 0;
node->children = NULL;
return node;
}
@ -251,9 +230,9 @@ Node *StartFunctionSignatureArgumentsNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignatureArguments;
node->childCount = 1;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = argumentNode;
node->functionSignatureArguments.sequence = (Node**) malloc(sizeof(Node*));
node->functionSignatureArguments.sequence[0] = argumentNode;
node->functionSignatureArguments.count = 1;
return node;
}
@ -261,9 +240,9 @@ Node* AddFunctionSignatureArgumentNode(
Node *argumentsNode,
Node *argumentNode
) {
argumentsNode->children = realloc(argumentsNode->children, sizeof(Node*) * (argumentsNode->childCount + 1));
argumentsNode->children[argumentsNode->childCount] = argumentNode;
argumentsNode->childCount += 1;
argumentsNode->functionSignatureArguments.sequence = realloc(argumentsNode->functionSignatureArguments.sequence, sizeof(Node*) * (argumentsNode->functionSignatureArguments.count + 1));
argumentsNode->functionSignatureArguments.sequence[argumentsNode->functionSignatureArguments.count] = argumentNode;
argumentsNode->functionSignatureArguments.count += 1;
return argumentsNode;
}
@ -271,8 +250,8 @@ Node *MakeEmptyFunctionSignatureArgumentsNode()
{
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignatureArguments;
node->childCount = 0;
node->children = NULL;
node->functionSignatureArguments.sequence = NULL;
node->functionSignatureArguments.count = 0;
return node;
}
@ -284,12 +263,10 @@ Node* MakeFunctionSignatureNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignature;
node->childCount = 4;
node->children = (Node**) malloc(sizeof(Node*) * (node->childCount));
node->children[0] = identifierNode;
node->children[1] = typeNode;
node->children[2] = arguments;
node->children[3] = modifiersNode;
node->functionSignature.identifier = identifierNode;
node->functionSignature.type = typeNode;
node->functionSignature.arguments = arguments;
node->functionSignature.modifiers = modifiersNode;
return node;
}
@ -299,10 +276,8 @@ Node* MakeFunctionDeclarationNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionDeclaration;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->children[0] = functionSignatureNode;
node->children[1] = functionBodyNode;
node->functionDeclaration.functionSignature = functionSignatureNode;
node->functionDeclaration.functionBody = functionBodyNode;
return node;
}
@ -312,10 +287,8 @@ Node* MakeStructDeclarationNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StructDeclaration;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->children[0] = identifierNode;
node->children[1] = declarationSequenceNode;
node->structDeclaration.identifier = identifierNode;
node->structDeclaration.declarationSequence = declarationSequenceNode;
return node;
}
@ -324,9 +297,9 @@ Node* StartDeclarationSequenceNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = DeclarationSequence;
node->children = (Node**) malloc(sizeof(Node*));
node->childCount = 1;
node->children[0] = declarationNode;
node->declarationSequence.sequence = (Node**) malloc(sizeof(Node*));
node->declarationSequence.sequence[0] = declarationNode;
node->declarationSequence.count = 1;
return node;
}
@ -334,9 +307,9 @@ Node* AddDeclarationNode(
Node *declarationSequenceNode,
Node *declarationNode
) {
declarationSequenceNode->children = realloc(declarationSequenceNode->children, sizeof(Node*) * (declarationSequenceNode->childCount + 1));
declarationSequenceNode->children[declarationSequenceNode->childCount] = declarationNode;
declarationSequenceNode->childCount += 1;
declarationSequenceNode->declarationSequence.sequence = (Node**) realloc(declarationSequenceNode->declarationSequence.sequence, sizeof(Node*) * (declarationSequenceNode->declarationSequence.count + 1));
declarationSequenceNode->declarationSequence.sequence[declarationSequenceNode->declarationSequence.count] = declarationNode;
declarationSequenceNode->declarationSequence.count += 1;
return declarationSequenceNode;
}
@ -345,9 +318,9 @@ Node* StartFunctionArgumentSequenceNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionArgumentSequence;
node->childCount = 1;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = argumentNode;
node->functionArgumentSequence.sequence = (Node**) malloc(sizeof(Node*));
node->functionArgumentSequence.sequence[0] = argumentNode;
node->functionArgumentSequence.count = 1;
return node;
}
@ -355,9 +328,9 @@ Node* AddFunctionArgumentNode(
Node *argumentSequenceNode,
Node *argumentNode
) {
argumentSequenceNode->children = realloc(argumentSequenceNode->children, sizeof(Node*) * (argumentSequenceNode->childCount + 1));
argumentSequenceNode->children[argumentSequenceNode->childCount] = argumentNode;
argumentSequenceNode->childCount += 1;
argumentSequenceNode->functionArgumentSequence.sequence = (Node**) realloc(argumentSequenceNode->functionArgumentSequence.sequence, sizeof(Node*) * (argumentSequenceNode->functionArgumentSequence.count + 1));
argumentSequenceNode->functionArgumentSequence.sequence[argumentSequenceNode->functionArgumentSequence.count] = argumentNode;
argumentSequenceNode->functionArgumentSequence.count += 1;
return argumentSequenceNode;
}
@ -365,8 +338,8 @@ Node *MakeEmptyFunctionArgumentSequenceNode()
{
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionArgumentSequence;
node->childCount = 0;
node->children = NULL;
node->functionArgumentSequence.count = 0;
node->functionArgumentSequence.sequence = NULL;
return node;
}
@ -376,10 +349,8 @@ Node* MakeFunctionCallExpressionNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionCallExpression;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->childCount = 2;
node->children[0] = identifierNode;
node->children[1] = argumentSequenceNode;
node->functionCallExpression.identifier = identifierNode;
node->functionCallExpression.argumentSequence = argumentSequenceNode;
return node;
}
@ -389,10 +360,8 @@ Node* MakeAccessExpressionNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = AccessExpression;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->childCount = 2;
node->children[0] = accessee;
node->children[1] = accessor;
node->accessExpression.accessee = accessee;
node->accessExpression.accessor = accessor;
return node;
}
@ -400,9 +369,7 @@ Node* MakeAllocNode(Node *typeNode)
{
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = AllocExpression;
node->childCount = 1;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
node->allocExpression.type = typeNode;
return node;
}
@ -412,40 +379,34 @@ Node* MakeIfNode(
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfStatement;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = expressionNode;
node->children[1] = statementSequenceNode;
node->ifStatement.expression = expressionNode;
node->ifStatement.statementSequence = statementSequenceNode;
return node;
}
Node* MakeIfElseNode(
Node *ifNode,
Node *statementSequenceNode
Node *elseNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfElseStatement;
node->childCount = 2;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = ifNode;
node->children[1] = statementSequenceNode;
node->ifElseStatement.ifStatement = ifNode;
node->ifElseStatement.elseStatement = elseNode;
return node;
}
Node* MakeForLoopNode(
Node *identifierNode,
Node *declarationNode,
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ForLoop;
node->childCount = 4;
node->children = (Node**) malloc(sizeof(Node*) * 4);
node->children[0] = identifierNode;
node->children[1] = startNumberNode;
node->children[2] = endNumberNode;
node->children[3] = statementSequenceNode;
node->forLoop.declaration = declarationNode;
node->forLoop.startNumber = startNumberNode;
node->forLoop.endNumber = endNumberNode;
node->forLoop.statementSequence = statementSequenceNode;
return node;
}
@ -462,9 +423,19 @@ static const char* PrimitiveTypeToString(PrimitiveType type)
return "Unknown";
}
static void PrintBinaryOperator(BinaryOperator expression)
static void PrintUnaryOperator(UnaryOperator operator)
{
switch (expression)
switch (operator)
{
case Negate:
printf("!");
break;
}
}
static void PrintBinaryOperator(BinaryOperator operator)
{
switch (operator)
{
case Add:
printf("+");
@ -480,7 +451,7 @@ static void PrintBinaryOperator(BinaryOperator expression)
}
}
static void PrintNode(Node *node, int tabCount)
void PrintNode(Node *node, uint32_t tabCount)
{
uint32_t i;
for (i = 0; i < tabCount; i += 1)
@ -491,48 +462,157 @@ static void PrintNode(Node *node, int tabCount)
printf("%s: ", SyntaxKindString(node->syntaxKind));
switch (node->syntaxKind)
{
case BinaryExpression:
PrintBinaryOperator(node->operator.binaryOperator);
case AccessExpression:
PrintNode(node->accessExpression.accessee, tabCount + 1);
PrintNode(node->accessExpression.accessor, tabCount + 1);
break;
case Declaration:
case AllocExpression:
PrintNode(node->allocExpression.type, tabCount + 1);
break;
case Assignment:
PrintNode(node->assignmentStatement.left, tabCount + 1);
PrintNode(node->assignmentStatement.right, tabCount + 1);
break;
case BinaryExpression:
PrintNode(node->binaryExpression.left, tabCount + 1);
PrintBinaryOperator(node->binaryExpression.operator);
PrintNode(node->binaryExpression.right, tabCount + 1);
break;
case CustomTypeNode:
printf("%s", node->value.string);
printf("%s", node->customType.name);
break;
case PrimitiveTypeNode:
printf("%s", PrimitiveTypeToString(node->primitiveType));
case Declaration:
PrintNode(node->declaration.identifier, tabCount + 1);
PrintNode(node->declaration.type, tabCount + 1);
break;
case DeclarationSequence:
for (i = 0; i < node->declarationSequence.count; i += 1)
{
PrintNode(node->declarationSequence.sequence[i], tabCount + 1);
}
break;
case ForLoop:
PrintNode(node->forLoop.declaration, tabCount + 1);
PrintNode(node->forLoop.startNumber, tabCount + 1);
PrintNode(node->forLoop.endNumber, tabCount + 1);
PrintNode(node->forLoop.statementSequence, tabCount + 1);
break;
case FunctionArgumentSequence:
for (i = 0; i < node->functionArgumentSequence.count; i += 1)
{
PrintNode(node->functionArgumentSequence.sequence[i], tabCount + 1);
}
break;
case FunctionCallExpression:
PrintNode(node->functionCallExpression.identifier, tabCount + 1);
PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1);
break;
case FunctionDeclaration:
PrintNode(node->functionDeclaration.functionSignature, tabCount + 1);
PrintNode(node->functionDeclaration.functionBody, tabCount + 1);
break;
case FunctionModifiers:
for (i = 0; i < node->functionModifiers.count; i += 1)
{
PrintNode(node->functionModifiers.sequence[i], tabCount + 1);
}
break;
case FunctionSignature:
PrintNode(node->functionSignature.identifier, tabCount + 1);
PrintNode(node->functionSignature.arguments, tabCount + 1);
PrintNode(node->functionSignature.type, tabCount + 1);
PrintNode(node->functionSignature.modifiers, tabCount + 1);
break;
case FunctionSignatureArguments:
for (i = 0; i < node->functionSignatureArguments.count; i += 1)
{
PrintNode(node->functionSignatureArguments.sequence[i], tabCount + 1);
}
break;
case Identifier:
if (node->typeTag == NULL) {
printf("%s", node->value.string);
printf("%s", node->identifier.name);
} else {
char *type = TypeTagToString(node->typeTag);
printf("%s<%s>", node->value.string, type);
printf("%s<%s>", node->identifier.name, type);
}
break;
case IfStatement:
PrintNode(node->ifStatement.expression, tabCount + 1);
PrintNode(node->ifStatement.statementSequence, tabCount + 1);
break;
case IfElseStatement:
PrintNode(node->ifElseStatement.ifStatement, tabCount + 1);
PrintNode(node->ifElseStatement.elseStatement, tabCount + 1);
break;
case Number:
printf("%lu", node->value.number);
printf("%lu", node->number.value);
break;
case PrimitiveTypeNode:
printf("%s", PrimitiveTypeToString(node->primitiveType.type));
break;
case ReferenceTypeNode:
PrintNode(node->referenceType.type, tabCount + 1);
break;
case Return:
PrintNode(node->returnStatement.expression, tabCount + 1);
break;
case ReturnVoid:
break;
case StatementSequence:
for (i = 0; i < node->statementSequence.count; i += 1)
{
PrintNode(node->statementSequence.sequence[i], tabCount + 1);
}
break;
case StaticModifier:
break;
case StringLiteral:
printf("%s", node->stringLiteral.string);
break;
case StructDeclaration:
PrintNode(node->structDeclaration.identifier, tabCount + 1);
PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
break;
case Type:
PrintNode(node->type.typeNode, tabCount + 1);
break;
case UnaryExpression:
PrintUnaryOperator(node->unaryExpression.operator);
PrintNode(node->unaryExpression.child, tabCount + 1);
break;
}
printf("\n");
}
void PrintTree(Node *node, uint32_t tabCount)
{
uint32_t i;
PrintNode(node, tabCount);
for (i = 0; i < node->childCount; i += 1)
{
PrintTree(node->children[i], tabCount + 1);
}
}
TypeTag* MakeTypeTag(Node *node) {
if (node == NULL) {
fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n");
@ -542,36 +622,36 @@ TypeTag* MakeTypeTag(Node *node) {
TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag));
switch (node->syntaxKind) {
case Type:
tag = MakeTypeTag(node->children[0]);
tag = MakeTypeTag(node->type.typeNode);
break;
case PrimitiveTypeNode:
tag->type = Primitive;
tag->value.primitiveType = node->primitiveType;
tag->value.primitiveType = node->primitiveType.type;
break;
case ReferenceTypeNode:
tag->type = Reference;
tag->value.referenceType = MakeTypeTag(node->children[0]);
tag->value.referenceType = MakeTypeTag(node->referenceType.type);
break;
case CustomTypeNode:
tag->type = Custom;
tag->value.customType = strdup(node->value.string);
tag->value.customType = strdup(node->customType.name);
break;
case Declaration:
tag = MakeTypeTag(node->children[0]);
tag = MakeTypeTag(node->declaration.type);
break;
case StructDeclaration:
tag->type = Custom;
tag->value.customType = strdup(node->children[0]->value.string);
tag->value.customType = strdup(node->structDeclaration.identifier->identifier.name);
printf("Struct tag: %s\n", TypeTagToString(tag));
break;
case FunctionDeclaration:
tag = MakeTypeTag(node->children[0]->children[1]);
tag = MakeTypeTag(node->functionDeclaration.functionSignature->functionSignature.type);
break;
default:

197
src/ast.h
View File

@ -4,6 +4,15 @@
#include <stdint.h>
#include "identcheck.h"
/* -Wpedantic nameless union/struct silencing */
#ifndef WRAITHNAMELESS
#ifdef __GNUC__
#define WRAITHNAMELESS __extension__
#else
#define WRAITHNAMELESS
#endif /* __GNUC__ */
#endif /* WRAITHNAMELESS */
typedef enum
{
AccessExpression,
@ -14,7 +23,6 @@ typedef enum
CustomTypeNode,
Declaration,
DeclarationSequence,
Expression,
ForLoop,
FunctionArgumentSequence,
FunctionCallExpression,
@ -92,25 +100,184 @@ typedef struct TypeTag
} value;
} TypeTag;
typedef struct Node
typedef struct Node Node;
struct Node
{
Node *parent;
SyntaxKind syntaxKind;
struct Node **children;
uint32_t childCount;
union
WRAITHNAMELESS union
{
UnaryOperator unaryOperator;
BinaryOperator binaryOperator;
} operator;
union
struct
{
Node *accessee;
Node *accessor;
} accessExpression;
struct
{
Node *type;
} allocExpression;
struct
{
Node *left;
Node *right;
} assignmentStatement;
struct
{
Node *left;
Node *right;
BinaryOperator operator;
} binaryExpression;
struct
{
} comment;
struct
{
char *name;
} customType;
struct
{
Node *type;
Node *identifier;
} declaration;
struct
{
Node **sequence;
uint32_t count;
} declarationSequence;
struct
{
Node *declaration;
Node *startNumber;
Node *endNumber;
Node *statementSequence;
} forLoop;
struct
{
Node **sequence;
uint32_t count;
} functionArgumentSequence;
struct
{
Node *identifier; /* FIXME: need better name */
Node *argumentSequence;
} functionCallExpression;
struct
{
Node *functionSignature;
Node *functionBody;
} functionDeclaration;
struct
{
Node **sequence;
uint32_t count;
} functionModifiers;
struct
{
Node *identifier;
Node *type;
Node *arguments;
Node *modifiers;
} functionSignature;
struct
{
Node **sequence;
uint32_t count;
} functionSignatureArguments;
struct
{
char *name;
} identifier;
struct
{
Node *expression;
Node *statementSequence;
} ifStatement;
struct
{
Node *ifStatement;
Node *elseStatement;
} ifElseStatement;
struct
{
uint64_t value;
} number;
struct
{
PrimitiveType type;
} primitiveType;
struct
{
Node *type;
} referenceType;
struct
{
Node *expression;
} returnStatement;
struct
{
} returnVoidStatement;
struct
{
Node **sequence;
uint32_t count;
} statementSequence;
struct
{
} staticModifier; /* FIXME: modifiers should just be an enum */
struct
{
char *string;
uint64_t number;
} value;
PrimitiveType primitiveType;
} stringLiteral;
struct
{
Node *identifier;
Node *declarationSequence;
} structDeclaration;
struct
{
Node *typeNode;
} type; /* FIXME: this needs a refactor */
struct
{
Node *child;
UnaryOperator operator;
} unaryExpression;
};
TypeTag *typeTag;
IdNode *idLink;
} Node;
};
const char* SyntaxKindString(SyntaxKind syntaxKind);
@ -223,7 +390,7 @@ Node* MakeIfNode(
);
Node* MakeIfElseNode(
Node *ifNode,
Node *statementSequenceNode
Node *elseNode /* can be a conditional or a statement sequence */
);
Node* MakeForLoopNode(
Node *identifierNode,
@ -232,7 +399,7 @@ Node* MakeForLoopNode(
Node *statementSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount);
void PrintNode(Node *node, uint32_t tabCount);
const char* SyntaxKindString(SyntaxKind syntaxKind);
TypeTag* MakeTypeTag(Node *node);

View File

@ -263,7 +263,7 @@ static void AddStructDeclaration(
for (i = 0; i < fieldDeclarationCount; i += 1)
{
structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1));
structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string);
structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->declaration.identifier->identifier.name);
structTypeDeclarations[index].fields[i].index = i;
structTypeDeclarations[index].fieldCount += 1;
}
@ -319,16 +319,15 @@ static LLVMTypeRef ResolveType(Node* typeNode)
{
if (IsPrimitiveType(typeNode))
{
return WraithTypeToLLVMType(typeNode->children[0]->primitiveType);
return WraithTypeToLLVMType(typeNode->type.typeNode->primitiveType.type);
}
else if (typeNode->children[0]->syntaxKind == CustomTypeNode)
else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode)
{
char *typeName = typeNode->children[0]->value.string;
return LookupCustomType(typeName);
return LookupCustomType(typeNode->type.typeNode->customType.name);
}
else if (typeNode->children[0]->syntaxKind == ReferenceTypeNode)
else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode)
{
return LLVMPointerType(ResolveType(typeNode->children[0]->children[0]), 0);
return LLVMPointerType(ResolveType(typeNode->type.typeNode->referenceType.type), 0);
}
else
{
@ -443,24 +442,24 @@ static LLVMValueRef CompileExpression(
static LLVMValueRef CompileNumber(
Node *numberExpression
) {
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0);
}
static LLVMValueRef CompileString(
LLVMBuilderRef builder,
Node *stringExpression
) {
return LLVMBuildGlobalStringPtr(builder, stringExpression->value.string, "stringConstant");
return LLVMBuildGlobalStringPtr(builder, stringExpression->stringLiteral.string, "stringConstant");
}
static LLVMValueRef CompileBinaryExpression(
LLVMBuilderRef builder,
Node *binaryExpression
) {
LLVMValueRef left = CompileExpression(builder, binaryExpression->children[0]);
LLVMValueRef right = CompileExpression(builder, binaryExpression->children[1]);
LLVMValueRef left = CompileExpression(builder, binaryExpression->binaryExpression.left);
LLVMValueRef right = CompileExpression(builder, binaryExpression->binaryExpression.right);
switch (binaryExpression->operator.binaryOperator)
switch (binaryExpression->binaryExpression.operator)
{
case Add:
return LLVMBuildAdd(builder, left, right, "addResult");
@ -494,11 +493,11 @@ static LLVMValueRef CompileBinaryExpression(
/* FIXME THIS IS ALL BROKEN */
static LLVMValueRef CompileFunctionCallExpression(
LLVMBuilderRef builder,
Node *expression
Node *functionCallExpression
) {
uint32_t i;
uint32_t argumentCount = 0;
LLVMValueRef args[expression->children[1]->childCount + 1];
LLVMValueRef args[functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count + 1];
LLVMValueRef function;
uint8_t isStatic;
LLVMValueRef structInstance;
@ -506,25 +505,26 @@ static LLVMValueRef CompileFunctionCallExpression(
char *returnName = "";
/* FIXME: this needs to be recursive on access chains */
if (expression->children[0]->syntaxKind == AccessExpression)
/* FIXME: this needs to be able to call same-struct functions implicitly */
if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression)
{
LLVMTypeRef typeReference = FindStructType(
expression->children[0]->children[0]->value.string
functionCallExpression->functionCallExpression.identifier->identifier.name
);
if (typeReference != NULL)
{
function = LookupFunctionByType(
typeReference,
expression->children[0]->children[1]->value.string,
functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name,
&functionReturnType,
&isStatic
);
}
else
{
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string);
function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic);
structInstance = FindVariablePointer(functionCallExpression->functionCallExpression.identifier->accessExpression.accessee->identifier.name);
function = LookupFunctionByInstance(structInstance, functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name, &functionReturnType, &isStatic);
}
}
else
@ -539,9 +539,9 @@ static LLVMValueRef CompileFunctionCallExpression(
argumentCount += 1;
}
for (i = 0; i < expression->children[1]->childCount; i += 1)
for (i = 0; i < functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count; i += 1)
{
args[argumentCount] = CompileExpression(builder, expression->children[1]->children[i]);
args[argumentCount] = CompileExpression(builder, functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.sequence[i]);
argumentCount += 1;
}
@ -555,30 +555,26 @@ static LLVMValueRef CompileFunctionCallExpression(
static LLVMValueRef CompileAccessExpressionForStore(
LLVMBuilderRef builder,
Node *expression
Node *accessExpression
) {
Node *accessee = expression->children[0];
Node *accessor = expression->children[1];
LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string);
return FindStructFieldPointer(builder, accesseeValue, accessor->value.string);
LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name);
return FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name);
}
static LLVMValueRef CompileAccessExpression(
LLVMBuilderRef builder,
Node *expression
Node *accessExpression
) {
Node *accessee = expression->children[0];
Node *accessor = expression->children[1];
LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string);
LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessor->value.string);
return LLVMBuildLoad(builder, access, accessor->value.string);
LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name);
LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name);
return LLVMBuildLoad(builder, access, accessExpression->accessExpression.accessor->identifier.name);
}
static LLVMValueRef CompileAllocExpression(
LLVMBuilderRef builder,
Node *expression
Node *allocExpression
) {
LLVMTypeRef type = ResolveType(expression->children[0]);
LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type);
return LLVMBuildMalloc(builder, type, "allocation");
}
@ -601,7 +597,7 @@ static LLVMValueRef CompileExpression(
return CompileFunctionCallExpression(builder, expression);
case Identifier:
return FindVariableValue(builder, expression->value.string);
return FindVariableValue(builder, expression->identifier.name);
case Number:
return CompileNumber(expression);
@ -619,7 +615,7 @@ static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef f
static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
LLVMValueRef expression = CompileExpression(builder, returnStatemement->returnStatement.expression);
LLVMBuildRet(builder, expression);
return LLVMGetLastBasicBlock(function);
}
@ -634,11 +630,11 @@ static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef
static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
{
LLVMValueRef variable;
char *variableName = variableDeclaration->children[1]->value.string;
char *variableName = variableDeclaration->declaration.identifier->identifier.name;
char *ptrName = strdup(variableName);
strcat(ptrName, "_ptr");
variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->children[0]), ptrName);
variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->declaration.type), ptrName);
free(ptrName);
@ -649,19 +645,19 @@ static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, L
static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
LLVMValueRef result = CompileExpression(builder, assignmentStatement->assignmentStatement.right);
LLVMValueRef identifier;
if (assignmentStatement->children[0]->syntaxKind == AccessExpression)
if (assignmentStatement->assignmentStatement.left->syntaxKind == AccessExpression)
{
identifier = CompileAccessExpressionForStore(builder, assignmentStatement->children[0]);
identifier = CompileAccessExpressionForStore(builder, assignmentStatement->assignmentStatement.left);
}
else if (assignmentStatement->children[0]->syntaxKind == Identifier)
else if (assignmentStatement->assignmentStatement.left->syntaxKind == Identifier)
{
identifier = FindVariablePointer(assignmentStatement->children[0]->value.string);
identifier = FindVariablePointer(assignmentStatement->assignmentStatement.left->identifier.name);
}
else if (assignmentStatement->children[0]->syntaxKind == Declaration)
else if (assignmentStatement->assignmentStatement.left->syntaxKind == Declaration)
{
identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->children[0]);
identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->assignmentStatement.left);
}
else
{
@ -677,7 +673,7 @@ static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef
static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
LLVMValueRef conditional = CompileExpression(builder, ifStatement->ifStatement.expression);
LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");
@ -686,9 +682,9 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef
LLVMPositionBuilderAtEnd(builder, block);
for (i = 0; i < ifStatement->children[1]->childCount; i += 1)
for (i = 0; i < ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1)
{
CompileStatement(builder, function, ifStatement->children[1]->children[i]);
CompileStatement(builder, function, ifStatement->ifStatement.statementSequence->statementSequence.sequence[i]);
}
LLVMBuildBr(builder, afterCond);
@ -700,7 +696,7 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef
static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]);
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression);
LLVMBasicBlockRef ifBlock = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlock(function, "elseBlock");
@ -710,25 +706,25 @@ static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValu
LLVMPositionBuilderAtEnd(builder, ifBlock);
for (i = 0; i < ifElseStatement->children[0]->children[1]->childCount; i += 1)
for (i = 0; i < ifElseStatement->ifElseStatement.ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1)
{
CompileStatement(builder, function, ifElseStatement->children[0]->children[1]->children[i]);
CompileStatement(builder, function, ifElseStatement->ifStatement.statementSequence->statementSequence.sequence[i]);
}
LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, elseBlock);
if (ifElseStatement->children[1]->syntaxKind == StatementSequence)
if (ifElseStatement->ifElseStatement.elseStatement->syntaxKind == StatementSequence)
{
for (i = 0; i < ifElseStatement->children[1]->childCount; i += 1)
for (i = 0; i < ifElseStatement->ifElseStatement.elseStatement->statementSequence.count; i += 1)
{
CompileStatement(builder, function, ifElseStatement->children[1]->children[i]);
CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement->statementSequence.sequence[i]);
}
}
else
{
CompileStatement(builder, function, ifElseStatement->children[1]);
CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement);
}
LLVMBuildBr(builder, afterCond);
@ -744,8 +740,8 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
char *iteratorVariableName = forLoopStatement->children[0]->children[1]->value.string;
LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->children[0]->children[0]);
char *iteratorVariableName = forLoopStatement->forLoop.declaration->declaration.identifier->identifier.name;
LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->forLoop.declaration->declaration.type);
PushScopeFrame(scope);
@ -762,13 +758,13 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMValueRef nextValue = LLVMBuildAdd(
builder,
iteratorValue,
LLVMConstInt(iteratorVariableType, forLoopStatement->children[1]->value.number, 0),
LLVMConstInt(iteratorVariableType, 1, 0), /* FIXME: add custom increment value */
"next"
);
LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->forLoop.endNumber);
LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
@ -776,9 +772,9 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMBasicBlockRef lastBlock;
for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1)
for (i = 0; i < forLoopStatement->forLoop.statementSequence->statementSequence.count; i += 1)
{
lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]);
lastBlock = CompileStatement(builder, function, forLoopStatement->forLoop.statementSequence->statementSequence.sequence[i]);
}
LLVMBuildBr(builder, checkBlock);
@ -786,7 +782,7 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
LLVMValueRef incomingValues[2];
incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
incomingValues[0] = CompileNumber(forLoopStatement->forLoop.startNumber);
incomingValues[1] = nextValue;
LLVMBasicBlockRef incomingBlocks[2];
@ -848,17 +844,17 @@ static void CompileFunction(
uint32_t i;
uint8_t hasReturn = 0;
uint8_t isStatic = 0;
Node *functionSignature = functionDeclaration->children[0];
Node *functionBody = functionDeclaration->children[1];
uint32_t argumentCount = functionSignature->children[2]->childCount;
Node *functionSignature = functionDeclaration->functionDeclaration.functionSignature;
Node *functionBody = functionDeclaration->functionDeclaration.functionBody;
uint32_t argumentCount = functionSignature->functionSignature.arguments->functionSignatureArguments.count;
LLVMTypeRef paramTypes[argumentCount + 1];
uint32_t paramIndex = 0;
if (functionSignature->children[3]->childCount > 0)
if (functionSignature->functionSignature.modifiers->functionModifiers.count > 0)
{
for (i = 0; i < functionSignature->children[3]->childCount; i += 1)
for (i = 0; i < functionSignature->functionSignature.modifiers->functionModifiers.count; i += 1)
{
if (functionSignature->children[3]->children[i]->syntaxKind == StaticModifier)
if (functionSignature->functionSignature.modifiers->functionModifiers.sequence[i]->syntaxKind == StaticModifier)
{
isStatic = 1;
break;
@ -875,22 +871,22 @@ static void CompileFunction(
PushScopeFrame(scope);
/* FIXME: should work for non-primitive types */
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1)
{
paramTypes[paramIndex] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->children[0]->primitiveType);
paramTypes[paramIndex] = ResolveType(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->declaration.type);
paramIndex += 1;
}
LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->children[0]->primitiveType);
LLVMTypeRef returnType = ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
char *functionName = strdup(parentStructName);
strcat(functionName, "_");
strcat(functionName, functionSignature->children[0]->value.string);
strcat(functionName, functionSignature->functionSignature.identifier->identifier.name);
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
free(functionName);
DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->children[0]->value.string);
DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
@ -902,20 +898,20 @@ static void CompileFunction(
AddStructVariablesToScope(builder, wStructPointer);
}
for (i = 0; i < functionSignature->children[2]->childCount; i += 1)
for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1)
{
char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string);
char *ptrName = strdup(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string);
AddLocalVariable(scope, argumentCopy, NULL, functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->identifier.name);
}
for (i = 0; i < functionBody->childCount; i += 1)
for (i = 0; i < functionBody->statementSequence.count; i += 1)
{
CompileStatement(builder, function, functionBody->children[i]);
CompileStatement(builder, function, functionBody->statementSequence.sequence[i]);
}
hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
@ -938,12 +934,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
{
uint32_t i;
uint32_t fieldCount = 0;
uint32_t declarationCount = node->children[1]->childCount;
uint32_t declarationCount = node->structDeclaration.declarationSequence->declarationSequence.count;
uint8_t packed = 1;
LLVMTypeRef types[declarationCount];
Node *currentDeclarationNode;
Node *fieldDeclarations[declarationCount];
char *structName = node->children[0]->value.string;
char *structName = node->structDeclaration.identifier->identifier.name;
PushScopeFrame(scope);
@ -953,12 +949,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
currentDeclarationNode = node->structDeclaration.declarationSequence->declarationSequence.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
case Declaration: /* this is badly named */
types[fieldCount] = ResolveType(currentDeclarationNode->children[0]);
types[fieldCount] = ResolveType(currentDeclarationNode->declaration.type);
fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1;
break;
@ -966,12 +962,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
}
LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount);
AddStructDeclaration(wStructType, wStructPointerType, structName, fieldDeclarations, fieldCount);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->children[1]->children[i];
currentDeclarationNode = node->structDeclaration.declarationSequence->declarationSequence.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
@ -984,15 +980,15 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
PopScopeFrame(scope);
}
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *declarationSequenceNode)
{
uint32_t i;
for (i = 0; i < node->childCount; i += 1)
for (i = 0; i < declarationSequenceNode->declarationSequence.count; i += 1)
{
if (node->children[i]->syntaxKind == StructDeclaration)
if (declarationSequenceNode->declarationSequence.sequence[i]->syntaxKind == StructDeclaration)
{
CompileStruct(module, context, node->children[i]);
CompileStruct(module, context, declarationSequenceNode->declarationSequence.sequence[i]);
}
else
{

View File

@ -69,7 +69,7 @@ int main(int argc, char *argv[])
IdNode *idTree = MakeIdTree(rootNode, NULL);
PrintIdTree(idTree, /*tabCount=*/0);
printf("\n");
PrintTree(rootNode, /*tabCount=*/0);
PrintNode(rootNode, /*tabCount=*/0);
}
exitCode = Codegen(rootNode, optimizationLevel);
}

View File

@ -31,7 +31,7 @@ int Parse(char *inputFilename, Node **pRootNode, uint8_t parseVerbose)
{
if (parseVerbose)
{
PrintTree(*pRootNode, 0);
PrintNode(*pRootNode, 0);
}
}
else if (result == 1)