refactor AST to use nameless union instead of child array

pull/2/head
cosmonaut 2021-05-15 15:34:15 -07:00
parent 41bf2bece8
commit abc82f381e
5 changed files with 494 additions and 251 deletions

368
src/ast.c
View File

@ -16,7 +16,6 @@ const char* SyntaxKindString(SyntaxKind syntaxKind)
case Comment: return "Comment"; case Comment: return "Comment";
case CustomTypeNode: return "CustomTypeNode"; case CustomTypeNode: return "CustomTypeNode";
case Declaration: return "Declaration"; case Declaration: return "Declaration";
case Expression: return "Expression";
case ForLoop: return "ForLoop"; case ForLoop: return "ForLoop";
case DeclarationSequence: return "DeclarationSequence"; case DeclarationSequence: return "DeclarationSequence";
case FunctionArgumentSequence: return "FunctionArgumentSequence"; case FunctionArgumentSequence: return "FunctionArgumentSequence";
@ -45,7 +44,7 @@ const char* SyntaxKindString(SyntaxKind syntaxKind)
uint8_t IsPrimitiveType( uint8_t IsPrimitiveType(
Node *typeNode Node *typeNode
) { ) {
return typeNode->children[0]->syntaxKind == PrimitiveTypeNode; return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode;
} }
Node* MakePrimitiveTypeNode( Node* MakePrimitiveTypeNode(
@ -53,8 +52,7 @@ Node* MakePrimitiveTypeNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = PrimitiveTypeNode; node->syntaxKind = PrimitiveTypeNode;
node->primitiveType = type; node->primitiveType.type = type;
node->childCount = 0;
return node; return node;
} }
@ -63,8 +61,7 @@ Node* MakeCustomTypeNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = CustomTypeNode; node->syntaxKind = CustomTypeNode;
node->value.string = strdup(name); node->customType.name = strdup(name);
node->childCount = 0;
return node; return node;
} }
@ -73,9 +70,7 @@ Node* MakeReferenceTypeNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ReferenceTypeNode; node->syntaxKind = ReferenceTypeNode;
node->childCount = 1; node->referenceType.type = typeNode;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
return node; return node;
} }
@ -84,9 +79,7 @@ Node* MakeTypeNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Type; node->syntaxKind = Type;
node->childCount = 1; node->type.typeNode = typeNode;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
return node; return node;
} }
@ -95,8 +88,7 @@ Node* MakeIdentifierNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Identifier; node->syntaxKind = Identifier;
node->value.string = strdup(id); node->identifier.name = strdup(id);
node->childCount = 0;
node->typeTag = NULL; node->typeTag = NULL;
return node; return node;
} }
@ -107,8 +99,7 @@ Node* MakeNumberNode(
char *ptr; char *ptr;
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Number; node->syntaxKind = Number;
node->value.number = strtoul(numberString, &ptr, 10); node->number.value = strtoul(numberString, &ptr, 10);
node->childCount = 0;
return node; return node;
} }
@ -118,8 +109,7 @@ Node* MakeStringNode(
size_t slen = strlen(string); size_t slen = strlen(string);
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StringLiteral; node->syntaxKind = StringLiteral;
node->value.string = strndup(string + 1, slen - 2); node->stringLiteral.string = strndup(string + 1, slen - 2);
node->childCount = 0;
return node; return node;
} }
@ -127,10 +117,10 @@ Node* MakeStaticNode()
{ {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StaticModifier; node->syntaxKind = StaticModifier;
node->childCount = 0;
return node; return node;
} }
/* FIXME: this sucks */
Node* MakeFunctionModifiersNode( Node* MakeFunctionModifiersNode(
Node **pModifierNodes, Node **pModifierNodes,
uint32_t modifierCount uint32_t modifierCount
@ -138,13 +128,14 @@ Node* MakeFunctionModifiersNode(
uint32_t i; uint32_t i;
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionModifiers; node->syntaxKind = FunctionModifiers;
node->childCount = modifierCount; node->functionModifiers.count = modifierCount;
node->functionModifiers.sequence = NULL;
if (modifierCount > 0) if (modifierCount > 0)
{ {
node->children = malloc(sizeof(Node*) * node->childCount); node->functionModifiers.sequence = malloc(sizeof(Node*) * node->functionModifiers.count);
for (i = 0; i < modifierCount; i += 1) for (i = 0; i < modifierCount; i += 1)
{ {
node->children[i] = pModifierNodes[i]; node->functionModifiers.sequence[i] = pModifierNodes[i];
} }
} }
@ -157,10 +148,8 @@ Node* MakeUnaryNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = UnaryExpression; node->syntaxKind = UnaryExpression;
node->operator.unaryOperator = operator; node->unaryExpression.operator = operator;
node->children = malloc(sizeof(Node*)); node->unaryExpression.child = child;
node->children[0] = child;
node->childCount = 1;
return node; return node;
} }
@ -171,11 +160,9 @@ Node* MakeBinaryNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = BinaryExpression; node->syntaxKind = BinaryExpression;
node->operator.binaryOperator = operator; node->binaryExpression.left = left;
node->children = malloc(sizeof(Node*) * 2); node->binaryExpression.right = right;
node->children[0] = left; node->binaryExpression.operator = operator;
node->children[1] = right;
node->childCount = 2;
return node; return node;
} }
@ -185,10 +172,8 @@ Node* MakeDeclarationNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Declaration; node->syntaxKind = Declaration;
node->children = (Node**) malloc(sizeof(Node*) * 2); node->declaration.type = typeNode;
node->childCount = 2; node->declaration.identifier = identifierNode;
node->children[0] = typeNode;
node->children[1] = identifierNode;
return node; return node;
} }
@ -198,10 +183,8 @@ Node* MakeAssignmentNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Assignment; node->syntaxKind = Assignment;
node->childCount = 2; node->assignmentStatement.left = left;
node->children = malloc(sizeof(Node*) * 2); node->assignmentStatement.right = right;
node->children[0] = left;
node->children[1] = right;
return node; return node;
} }
@ -210,9 +193,9 @@ Node* StartStatementSequenceNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StatementSequence; node->syntaxKind = StatementSequence;
node->children = (Node**) malloc(sizeof(Node*)); node->statementSequence.sequence = (Node**) malloc(sizeof(Node*));
node->childCount = 1; node->statementSequence.sequence[0] = statementNode;
node->children[0] = statementNode; node->statementSequence.count = 1;
return node; return node;
} }
@ -220,9 +203,9 @@ Node* AddStatement(
Node* statementSequenceNode, Node* statementSequenceNode,
Node *statementNode Node *statementNode
) { ) {
statementSequenceNode->children = realloc(statementSequenceNode->children, sizeof(Node*) * (statementSequenceNode->childCount + 1)); statementSequenceNode->statementSequence.sequence = realloc(statementSequenceNode->statementSequence.sequence, sizeof(Node*) * (statementSequenceNode->statementSequence.count + 1));
statementSequenceNode->children[statementSequenceNode->childCount] = statementNode; statementSequenceNode->statementSequence.sequence[statementSequenceNode->statementSequence.count] = statementNode;
statementSequenceNode->childCount += 1; statementSequenceNode->statementSequence.count += 1;
return statementSequenceNode; return statementSequenceNode;
} }
@ -231,9 +214,7 @@ Node* MakeReturnStatementNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Return; node->syntaxKind = Return;
node->children = (Node**) malloc(sizeof(Node*)); node->returnStatement.expression = expressionNode;
node->childCount = 1;
node->children[0] = expressionNode;
return node; return node;
} }
@ -241,8 +222,6 @@ Node* MakeReturnVoidStatementNode()
{ {
Node *node = (Node*) malloc(sizeof(Node)); Node *node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ReturnVoid; node->syntaxKind = ReturnVoid;
node->childCount = 0;
node->children = NULL;
return node; return node;
} }
@ -251,9 +230,9 @@ Node *StartFunctionSignatureArgumentsNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignatureArguments; node->syntaxKind = FunctionSignatureArguments;
node->childCount = 1; node->functionSignatureArguments.sequence = (Node**) malloc(sizeof(Node*));
node->children = (Node**) malloc(sizeof(Node*)); node->functionSignatureArguments.sequence[0] = argumentNode;
node->children[0] = argumentNode; node->functionSignatureArguments.count = 1;
return node; return node;
} }
@ -261,9 +240,9 @@ Node* AddFunctionSignatureArgumentNode(
Node *argumentsNode, Node *argumentsNode,
Node *argumentNode Node *argumentNode
) { ) {
argumentsNode->children = realloc(argumentsNode->children, sizeof(Node*) * (argumentsNode->childCount + 1)); argumentsNode->functionSignatureArguments.sequence = realloc(argumentsNode->functionSignatureArguments.sequence, sizeof(Node*) * (argumentsNode->functionSignatureArguments.count + 1));
argumentsNode->children[argumentsNode->childCount] = argumentNode; argumentsNode->functionSignatureArguments.sequence[argumentsNode->functionSignatureArguments.count] = argumentNode;
argumentsNode->childCount += 1; argumentsNode->functionSignatureArguments.count += 1;
return argumentsNode; return argumentsNode;
} }
@ -271,8 +250,8 @@ Node *MakeEmptyFunctionSignatureArgumentsNode()
{ {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignatureArguments; node->syntaxKind = FunctionSignatureArguments;
node->childCount = 0; node->functionSignatureArguments.sequence = NULL;
node->children = NULL; node->functionSignatureArguments.count = 0;
return node; return node;
} }
@ -284,12 +263,10 @@ Node* MakeFunctionSignatureNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionSignature; node->syntaxKind = FunctionSignature;
node->childCount = 4; node->functionSignature.identifier = identifierNode;
node->children = (Node**) malloc(sizeof(Node*) * (node->childCount)); node->functionSignature.type = typeNode;
node->children[0] = identifierNode; node->functionSignature.arguments = arguments;
node->children[1] = typeNode; node->functionSignature.modifiers = modifiersNode;
node->children[2] = arguments;
node->children[3] = modifiersNode;
return node; return node;
} }
@ -299,10 +276,8 @@ Node* MakeFunctionDeclarationNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionDeclaration; node->syntaxKind = FunctionDeclaration;
node->childCount = 2; node->functionDeclaration.functionSignature = functionSignatureNode;
node->children = (Node**) malloc(sizeof(Node*) * 2); node->functionDeclaration.functionBody = functionBodyNode;
node->children[0] = functionSignatureNode;
node->children[1] = functionBodyNode;
return node; return node;
} }
@ -312,10 +287,8 @@ Node* MakeStructDeclarationNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StructDeclaration; node->syntaxKind = StructDeclaration;
node->childCount = 2; node->structDeclaration.identifier = identifierNode;
node->children = (Node**) malloc(sizeof(Node*) * 2); node->structDeclaration.declarationSequence = declarationSequenceNode;
node->children[0] = identifierNode;
node->children[1] = declarationSequenceNode;
return node; return node;
} }
@ -324,9 +297,9 @@ Node* StartDeclarationSequenceNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = DeclarationSequence; node->syntaxKind = DeclarationSequence;
node->children = (Node**) malloc(sizeof(Node*)); node->declarationSequence.sequence = (Node**) malloc(sizeof(Node*));
node->childCount = 1; node->declarationSequence.sequence[0] = declarationNode;
node->children[0] = declarationNode; node->declarationSequence.count = 1;
return node; return node;
} }
@ -334,9 +307,9 @@ Node* AddDeclarationNode(
Node *declarationSequenceNode, Node *declarationSequenceNode,
Node *declarationNode Node *declarationNode
) { ) {
declarationSequenceNode->children = realloc(declarationSequenceNode->children, sizeof(Node*) * (declarationSequenceNode->childCount + 1)); declarationSequenceNode->declarationSequence.sequence = (Node**) realloc(declarationSequenceNode->declarationSequence.sequence, sizeof(Node*) * (declarationSequenceNode->declarationSequence.count + 1));
declarationSequenceNode->children[declarationSequenceNode->childCount] = declarationNode; declarationSequenceNode->declarationSequence.sequence[declarationSequenceNode->declarationSequence.count] = declarationNode;
declarationSequenceNode->childCount += 1; declarationSequenceNode->declarationSequence.count += 1;
return declarationSequenceNode; return declarationSequenceNode;
} }
@ -345,9 +318,9 @@ Node* StartFunctionArgumentSequenceNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionArgumentSequence; node->syntaxKind = FunctionArgumentSequence;
node->childCount = 1; node->functionArgumentSequence.sequence = (Node**) malloc(sizeof(Node*));
node->children = (Node**) malloc(sizeof(Node*)); node->functionArgumentSequence.sequence[0] = argumentNode;
node->children[0] = argumentNode; node->functionArgumentSequence.count = 1;
return node; return node;
} }
@ -355,9 +328,9 @@ Node* AddFunctionArgumentNode(
Node *argumentSequenceNode, Node *argumentSequenceNode,
Node *argumentNode Node *argumentNode
) { ) {
argumentSequenceNode->children = realloc(argumentSequenceNode->children, sizeof(Node*) * (argumentSequenceNode->childCount + 1)); argumentSequenceNode->functionArgumentSequence.sequence = (Node**) realloc(argumentSequenceNode->functionArgumentSequence.sequence, sizeof(Node*) * (argumentSequenceNode->functionArgumentSequence.count + 1));
argumentSequenceNode->children[argumentSequenceNode->childCount] = argumentNode; argumentSequenceNode->functionArgumentSequence.sequence[argumentSequenceNode->functionArgumentSequence.count] = argumentNode;
argumentSequenceNode->childCount += 1; argumentSequenceNode->functionArgumentSequence.count += 1;
return argumentSequenceNode; return argumentSequenceNode;
} }
@ -365,8 +338,8 @@ Node *MakeEmptyFunctionArgumentSequenceNode()
{ {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionArgumentSequence; node->syntaxKind = FunctionArgumentSequence;
node->childCount = 0; node->functionArgumentSequence.count = 0;
node->children = NULL; node->functionArgumentSequence.sequence = NULL;
return node; return node;
} }
@ -376,10 +349,8 @@ Node* MakeFunctionCallExpressionNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionCallExpression; node->syntaxKind = FunctionCallExpression;
node->children = (Node**) malloc(sizeof(Node*) * 2); node->functionCallExpression.identifier = identifierNode;
node->childCount = 2; node->functionCallExpression.argumentSequence = argumentSequenceNode;
node->children[0] = identifierNode;
node->children[1] = argumentSequenceNode;
return node; return node;
} }
@ -389,10 +360,8 @@ Node* MakeAccessExpressionNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = AccessExpression; node->syntaxKind = AccessExpression;
node->children = (Node**) malloc(sizeof(Node*) * 2); node->accessExpression.accessee = accessee;
node->childCount = 2; node->accessExpression.accessor = accessor;
node->children[0] = accessee;
node->children[1] = accessor;
return node; return node;
} }
@ -400,9 +369,7 @@ Node* MakeAllocNode(Node *typeNode)
{ {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = AllocExpression; node->syntaxKind = AllocExpression;
node->childCount = 1; node->allocExpression.type = typeNode;
node->children = (Node**) malloc(sizeof(Node*));
node->children[0] = typeNode;
return node; return node;
} }
@ -412,40 +379,34 @@ Node* MakeIfNode(
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfStatement; node->syntaxKind = IfStatement;
node->childCount = 2; node->ifStatement.expression = expressionNode;
node->children = (Node**) malloc(sizeof(Node*)); node->ifStatement.statementSequence = statementSequenceNode;
node->children[0] = expressionNode;
node->children[1] = statementSequenceNode;
return node; return node;
} }
Node* MakeIfElseNode( Node* MakeIfElseNode(
Node *ifNode, Node *ifNode,
Node *statementSequenceNode Node *elseNode
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = IfElseStatement; node->syntaxKind = IfElseStatement;
node->childCount = 2; node->ifElseStatement.ifStatement = ifNode;
node->children = (Node**) malloc(sizeof(Node*)); node->ifElseStatement.elseStatement = elseNode;
node->children[0] = ifNode;
node->children[1] = statementSequenceNode;
return node; return node;
} }
Node* MakeForLoopNode( Node* MakeForLoopNode(
Node *identifierNode, Node *declarationNode,
Node *startNumberNode, Node *startNumberNode,
Node *endNumberNode, Node *endNumberNode,
Node *statementSequenceNode Node *statementSequenceNode
) { ) {
Node* node = (Node*) malloc(sizeof(Node)); Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ForLoop; node->syntaxKind = ForLoop;
node->childCount = 4; node->forLoop.declaration = declarationNode;
node->children = (Node**) malloc(sizeof(Node*) * 4); node->forLoop.startNumber = startNumberNode;
node->children[0] = identifierNode; node->forLoop.endNumber = endNumberNode;
node->children[1] = startNumberNode; node->forLoop.statementSequence = statementSequenceNode;
node->children[2] = endNumberNode;
node->children[3] = statementSequenceNode;
return node; return node;
} }
@ -462,9 +423,19 @@ static const char* PrimitiveTypeToString(PrimitiveType type)
return "Unknown"; return "Unknown";
} }
static void PrintBinaryOperator(BinaryOperator expression) static void PrintUnaryOperator(UnaryOperator operator)
{ {
switch (expression) switch (operator)
{
case Negate:
printf("!");
break;
}
}
static void PrintBinaryOperator(BinaryOperator operator)
{
switch (operator)
{ {
case Add: case Add:
printf("+"); printf("+");
@ -480,7 +451,7 @@ static void PrintBinaryOperator(BinaryOperator expression)
} }
} }
static void PrintNode(Node *node, int tabCount) void PrintNode(Node *node, uint32_t tabCount)
{ {
uint32_t i; uint32_t i;
for (i = 0; i < tabCount; i += 1) for (i = 0; i < tabCount; i += 1)
@ -491,48 +462,157 @@ static void PrintNode(Node *node, int tabCount)
printf("%s: ", SyntaxKindString(node->syntaxKind)); printf("%s: ", SyntaxKindString(node->syntaxKind));
switch (node->syntaxKind) switch (node->syntaxKind)
{ {
case BinaryExpression: case AccessExpression:
PrintBinaryOperator(node->operator.binaryOperator); PrintNode(node->accessExpression.accessee, tabCount + 1);
PrintNode(node->accessExpression.accessor, tabCount + 1);
break; break;
case Declaration: case AllocExpression:
PrintNode(node->allocExpression.type, tabCount + 1);
break;
case Assignment:
PrintNode(node->assignmentStatement.left, tabCount + 1);
PrintNode(node->assignmentStatement.right, tabCount + 1);
break;
case BinaryExpression:
PrintNode(node->binaryExpression.left, tabCount + 1);
PrintBinaryOperator(node->binaryExpression.operator);
PrintNode(node->binaryExpression.right, tabCount + 1);
break; break;
case CustomTypeNode: case CustomTypeNode:
printf("%s", node->value.string); printf("%s", node->customType.name);
break; break;
case PrimitiveTypeNode: case Declaration:
printf("%s", PrimitiveTypeToString(node->primitiveType)); PrintNode(node->declaration.identifier, tabCount + 1);
PrintNode(node->declaration.type, tabCount + 1);
break;
case DeclarationSequence:
for (i = 0; i < node->declarationSequence.count; i += 1)
{
PrintNode(node->declarationSequence.sequence[i], tabCount + 1);
}
break;
case ForLoop:
PrintNode(node->forLoop.declaration, tabCount + 1);
PrintNode(node->forLoop.startNumber, tabCount + 1);
PrintNode(node->forLoop.endNumber, tabCount + 1);
PrintNode(node->forLoop.statementSequence, tabCount + 1);
break;
case FunctionArgumentSequence:
for (i = 0; i < node->functionArgumentSequence.count; i += 1)
{
PrintNode(node->functionArgumentSequence.sequence[i], tabCount + 1);
}
break;
case FunctionCallExpression:
PrintNode(node->functionCallExpression.identifier, tabCount + 1);
PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1);
break;
case FunctionDeclaration:
PrintNode(node->functionDeclaration.functionSignature, tabCount + 1);
PrintNode(node->functionDeclaration.functionBody, tabCount + 1);
break;
case FunctionModifiers:
for (i = 0; i < node->functionModifiers.count; i += 1)
{
PrintNode(node->functionModifiers.sequence[i], tabCount + 1);
}
break;
case FunctionSignature:
PrintNode(node->functionSignature.identifier, tabCount + 1);
PrintNode(node->functionSignature.arguments, tabCount + 1);
PrintNode(node->functionSignature.type, tabCount + 1);
PrintNode(node->functionSignature.modifiers, tabCount + 1);
break;
case FunctionSignatureArguments:
for (i = 0; i < node->functionSignatureArguments.count; i += 1)
{
PrintNode(node->functionSignatureArguments.sequence[i], tabCount + 1);
}
break; break;
case Identifier: case Identifier:
if (node->typeTag == NULL) { if (node->typeTag == NULL) {
printf("%s", node->value.string); printf("%s", node->identifier.name);
} else { } else {
char *type = TypeTagToString(node->typeTag); char *type = TypeTagToString(node->typeTag);
printf("%s<%s>", node->value.string, type); printf("%s<%s>", node->identifier.name, type);
} }
break; break;
case IfStatement:
PrintNode(node->ifStatement.expression, tabCount + 1);
PrintNode(node->ifStatement.statementSequence, tabCount + 1);
break;
case IfElseStatement:
PrintNode(node->ifElseStatement.ifStatement, tabCount + 1);
PrintNode(node->ifElseStatement.elseStatement, tabCount + 1);
break;
case Number: case Number:
printf("%lu", node->value.number); printf("%lu", node->number.value);
break;
case PrimitiveTypeNode:
printf("%s", PrimitiveTypeToString(node->primitiveType.type));
break;
case ReferenceTypeNode:
PrintNode(node->referenceType.type, tabCount + 1);
break;
case Return:
PrintNode(node->returnStatement.expression, tabCount + 1);
break;
case ReturnVoid:
break;
case StatementSequence:
for (i = 0; i < node->statementSequence.count; i += 1)
{
PrintNode(node->statementSequence.sequence[i], tabCount + 1);
}
break;
case StaticModifier:
break;
case StringLiteral:
printf("%s", node->stringLiteral.string);
break;
case StructDeclaration:
PrintNode(node->structDeclaration.identifier, tabCount + 1);
PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
break;
case Type:
PrintNode(node->type.typeNode, tabCount + 1);
break;
case UnaryExpression:
PrintUnaryOperator(node->unaryExpression.operator);
PrintNode(node->unaryExpression.child, tabCount + 1);
break; break;
} }
printf("\n"); printf("\n");
} }
void PrintTree(Node *node, uint32_t tabCount)
{
uint32_t i;
PrintNode(node, tabCount);
for (i = 0; i < node->childCount; i += 1)
{
PrintTree(node->children[i], tabCount + 1);
}
}
TypeTag* MakeTypeTag(Node *node) { TypeTag* MakeTypeTag(Node *node) {
if (node == NULL) { if (node == NULL) {
fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n"); fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n");
@ -542,40 +622,40 @@ TypeTag* MakeTypeTag(Node *node) {
TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag)); TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag));
switch (node->syntaxKind) { switch (node->syntaxKind) {
case Type: case Type:
tag = MakeTypeTag(node->children[0]); tag = MakeTypeTag(node->type.typeNode);
break; break;
case PrimitiveTypeNode: case PrimitiveTypeNode:
tag->type = Primitive; tag->type = Primitive;
tag->value.primitiveType = node->primitiveType; tag->value.primitiveType = node->primitiveType.type;
break; break;
case ReferenceTypeNode: case ReferenceTypeNode:
tag->type = Reference; tag->type = Reference;
tag->value.referenceType = MakeTypeTag(node->children[0]); tag->value.referenceType = MakeTypeTag(node->referenceType.type);
break; break;
case CustomTypeNode: case CustomTypeNode:
tag->type = Custom; tag->type = Custom;
tag->value.customType = strdup(node->value.string); tag->value.customType = strdup(node->customType.name);
break; break;
case Declaration: case Declaration:
tag = MakeTypeTag(node->children[0]); tag = MakeTypeTag(node->declaration.type);
break; break;
case StructDeclaration: case StructDeclaration:
tag->type = Custom; tag->type = Custom;
tag->value.customType = strdup(node->children[0]->value.string); tag->value.customType = strdup(node->structDeclaration.identifier->identifier.name);
printf("Struct tag: %s\n", TypeTagToString(tag)); printf("Struct tag: %s\n", TypeTagToString(tag));
break; break;
case FunctionDeclaration: case FunctionDeclaration:
tag = MakeTypeTag(node->children[0]->children[1]); tag = MakeTypeTag(node->functionDeclaration.functionSignature->functionSignature.type);
break; break;
default: default:
fprintf(stderr, fprintf(stderr,
"wraith: Attempted to call MakeTypeTag on" "wraith: Attempted to call MakeTypeTag on"
" node with unsupported SyntaxKind: %s\n", " node with unsupported SyntaxKind: %s\n",
SyntaxKindString(node->syntaxKind)); SyntaxKindString(node->syntaxKind));
@ -605,4 +685,4 @@ char* TypeTagToString(TypeTag *tag) {
case Custom: case Custom:
return tag->value.customType; return tag->value.customType;
} }
} }

201
src/ast.h
View File

@ -4,6 +4,15 @@
#include <stdint.h> #include <stdint.h>
#include "identcheck.h" #include "identcheck.h"
/* -Wpedantic nameless union/struct silencing */
#ifndef WRAITHNAMELESS
#ifdef __GNUC__
#define WRAITHNAMELESS __extension__
#else
#define WRAITHNAMELESS
#endif /* __GNUC__ */
#endif /* WRAITHNAMELESS */
typedef enum typedef enum
{ {
AccessExpression, AccessExpression,
@ -14,7 +23,6 @@ typedef enum
CustomTypeNode, CustomTypeNode,
Declaration, Declaration,
DeclarationSequence, DeclarationSequence,
Expression,
ForLoop, ForLoop,
FunctionArgumentSequence, FunctionArgumentSequence,
FunctionCallExpression, FunctionCallExpression,
@ -92,25 +100,184 @@ typedef struct TypeTag
} value; } value;
} TypeTag; } TypeTag;
typedef struct Node typedef struct Node Node;
struct Node
{ {
Node *parent;
SyntaxKind syntaxKind; SyntaxKind syntaxKind;
struct Node **children; WRAITHNAMELESS union
uint32_t childCount;
union
{ {
UnaryOperator unaryOperator; struct
BinaryOperator binaryOperator; {
} operator; Node *accessee;
union Node *accessor;
{ } accessExpression;
char *string;
uint64_t number; struct
} value; {
PrimitiveType primitiveType; Node *type;
} allocExpression;
struct
{
Node *left;
Node *right;
} assignmentStatement;
struct
{
Node *left;
Node *right;
BinaryOperator operator;
} binaryExpression;
struct
{
} comment;
struct
{
char *name;
} customType;
struct
{
Node *type;
Node *identifier;
} declaration;
struct
{
Node **sequence;
uint32_t count;
} declarationSequence;
struct
{
Node *declaration;
Node *startNumber;
Node *endNumber;
Node *statementSequence;
} forLoop;
struct
{
Node **sequence;
uint32_t count;
} functionArgumentSequence;
struct
{
Node *identifier; /* FIXME: need better name */
Node *argumentSequence;
} functionCallExpression;
struct
{
Node *functionSignature;
Node *functionBody;
} functionDeclaration;
struct
{
Node **sequence;
uint32_t count;
} functionModifiers;
struct
{
Node *identifier;
Node *type;
Node *arguments;
Node *modifiers;
} functionSignature;
struct
{
Node **sequence;
uint32_t count;
} functionSignatureArguments;
struct
{
char *name;
} identifier;
struct
{
Node *expression;
Node *statementSequence;
} ifStatement;
struct
{
Node *ifStatement;
Node *elseStatement;
} ifElseStatement;
struct
{
uint64_t value;
} number;
struct
{
PrimitiveType type;
} primitiveType;
struct
{
Node *type;
} referenceType;
struct
{
Node *expression;
} returnStatement;
struct
{
} returnVoidStatement;
struct
{
Node **sequence;
uint32_t count;
} statementSequence;
struct
{
} staticModifier; /* FIXME: modifiers should just be an enum */
struct
{
char *string;
} stringLiteral;
struct
{
Node *identifier;
Node *declarationSequence;
} structDeclaration;
struct
{
Node *typeNode;
} type; /* FIXME: this needs a refactor */
struct
{
Node *child;
UnaryOperator operator;
} unaryExpression;
};
TypeTag *typeTag; TypeTag *typeTag;
IdNode *idLink; IdNode *idLink;
} Node; };
const char* SyntaxKindString(SyntaxKind syntaxKind); const char* SyntaxKindString(SyntaxKind syntaxKind);
@ -223,7 +390,7 @@ Node* MakeIfNode(
); );
Node* MakeIfElseNode( Node* MakeIfElseNode(
Node *ifNode, Node *ifNode,
Node *statementSequenceNode Node *elseNode /* can be a conditional or a statement sequence */
); );
Node* MakeForLoopNode( Node* MakeForLoopNode(
Node *identifierNode, Node *identifierNode,
@ -232,7 +399,7 @@ Node* MakeForLoopNode(
Node *statementSequenceNode Node *statementSequenceNode
); );
void PrintTree(Node *node, uint32_t tabCount); void PrintNode(Node *node, uint32_t tabCount);
const char* SyntaxKindString(SyntaxKind syntaxKind); const char* SyntaxKindString(SyntaxKind syntaxKind);
TypeTag* MakeTypeTag(Node *node); TypeTag* MakeTypeTag(Node *node);

View File

@ -263,7 +263,7 @@ static void AddStructDeclaration(
for (i = 0; i < fieldDeclarationCount; i += 1) for (i = 0; i < fieldDeclarationCount; i += 1)
{ {
structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1)); structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1));
structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string); structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->declaration.identifier->identifier.name);
structTypeDeclarations[index].fields[i].index = i; structTypeDeclarations[index].fields[i].index = i;
structTypeDeclarations[index].fieldCount += 1; structTypeDeclarations[index].fieldCount += 1;
} }
@ -319,16 +319,15 @@ static LLVMTypeRef ResolveType(Node* typeNode)
{ {
if (IsPrimitiveType(typeNode)) if (IsPrimitiveType(typeNode))
{ {
return WraithTypeToLLVMType(typeNode->children[0]->primitiveType); return WraithTypeToLLVMType(typeNode->type.typeNode->primitiveType.type);
} }
else if (typeNode->children[0]->syntaxKind == CustomTypeNode) else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode)
{ {
char *typeName = typeNode->children[0]->value.string; return LookupCustomType(typeNode->type.typeNode->customType.name);
return LookupCustomType(typeName);
} }
else if (typeNode->children[0]->syntaxKind == ReferenceTypeNode) else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode)
{ {
return LLVMPointerType(ResolveType(typeNode->children[0]->children[0]), 0); return LLVMPointerType(ResolveType(typeNode->type.typeNode->referenceType.type), 0);
} }
else else
{ {
@ -443,24 +442,24 @@ static LLVMValueRef CompileExpression(
static LLVMValueRef CompileNumber( static LLVMValueRef CompileNumber(
Node *numberExpression Node *numberExpression
) { ) {
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0);
} }
static LLVMValueRef CompileString( static LLVMValueRef CompileString(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *stringExpression Node *stringExpression
) { ) {
return LLVMBuildGlobalStringPtr(builder, stringExpression->value.string, "stringConstant"); return LLVMBuildGlobalStringPtr(builder, stringExpression->stringLiteral.string, "stringConstant");
} }
static LLVMValueRef CompileBinaryExpression( static LLVMValueRef CompileBinaryExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *binaryExpression Node *binaryExpression
) { ) {
LLVMValueRef left = CompileExpression(builder, binaryExpression->children[0]); LLVMValueRef left = CompileExpression(builder, binaryExpression->binaryExpression.left);
LLVMValueRef right = CompileExpression(builder, binaryExpression->children[1]); LLVMValueRef right = CompileExpression(builder, binaryExpression->binaryExpression.right);
switch (binaryExpression->operator.binaryOperator) switch (binaryExpression->binaryExpression.operator)
{ {
case Add: case Add:
return LLVMBuildAdd(builder, left, right, "addResult"); return LLVMBuildAdd(builder, left, right, "addResult");
@ -494,11 +493,11 @@ static LLVMValueRef CompileBinaryExpression(
/* FIXME THIS IS ALL BROKEN */ /* FIXME THIS IS ALL BROKEN */
static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileFunctionCallExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *expression Node *functionCallExpression
) { ) {
uint32_t i; uint32_t i;
uint32_t argumentCount = 0; uint32_t argumentCount = 0;
LLVMValueRef args[expression->children[1]->childCount + 1]; LLVMValueRef args[functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count + 1];
LLVMValueRef function; LLVMValueRef function;
uint8_t isStatic; uint8_t isStatic;
LLVMValueRef structInstance; LLVMValueRef structInstance;
@ -506,25 +505,26 @@ static LLVMValueRef CompileFunctionCallExpression(
char *returnName = ""; char *returnName = "";
/* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be recursive on access chains */
if (expression->children[0]->syntaxKind == AccessExpression) /* FIXME: this needs to be able to call same-struct functions implicitly */
if (functionCallExpression->functionCallExpression.identifier->syntaxKind == AccessExpression)
{ {
LLVMTypeRef typeReference = FindStructType( LLVMTypeRef typeReference = FindStructType(
expression->children[0]->children[0]->value.string functionCallExpression->functionCallExpression.identifier->identifier.name
); );
if (typeReference != NULL) if (typeReference != NULL)
{ {
function = LookupFunctionByType( function = LookupFunctionByType(
typeReference, typeReference,
expression->children[0]->children[1]->value.string, functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name,
&functionReturnType, &functionReturnType,
&isStatic &isStatic
); );
} }
else else
{ {
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); structInstance = FindVariablePointer(functionCallExpression->functionCallExpression.identifier->accessExpression.accessee->identifier.name);
function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); function = LookupFunctionByInstance(structInstance, functionCallExpression->functionCallExpression.identifier->accessExpression.accessor->identifier.name, &functionReturnType, &isStatic);
} }
} }
else else
@ -539,9 +539,9 @@ static LLVMValueRef CompileFunctionCallExpression(
argumentCount += 1; argumentCount += 1;
} }
for (i = 0; i < expression->children[1]->childCount; i += 1) for (i = 0; i < functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.count; i += 1)
{ {
args[argumentCount] = CompileExpression(builder, expression->children[1]->children[i]); args[argumentCount] = CompileExpression(builder, functionCallExpression->functionCallExpression.argumentSequence->functionArgumentSequence.sequence[i]);
argumentCount += 1; argumentCount += 1;
} }
@ -555,30 +555,26 @@ static LLVMValueRef CompileFunctionCallExpression(
static LLVMValueRef CompileAccessExpressionForStore( static LLVMValueRef CompileAccessExpressionForStore(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *expression Node *accessExpression
) { ) {
Node *accessee = expression->children[0]; LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name);
Node *accessor = expression->children[1]; return FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name);
LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string);
return FindStructFieldPointer(builder, accesseeValue, accessor->value.string);
} }
static LLVMValueRef CompileAccessExpression( static LLVMValueRef CompileAccessExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *expression Node *accessExpression
) { ) {
Node *accessee = expression->children[0]; LLVMValueRef accesseeValue = FindVariablePointer(accessExpression->accessExpression.accessee->identifier.name);
Node *accessor = expression->children[1]; LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessExpression->accessExpression.accessor->identifier.name);
LLVMValueRef accesseeValue = FindVariablePointer(accessee->value.string); return LLVMBuildLoad(builder, access, accessExpression->accessExpression.accessor->identifier.name);
LLVMValueRef access = FindStructFieldPointer(builder, accesseeValue, accessor->value.string);
return LLVMBuildLoad(builder, access, accessor->value.string);
} }
static LLVMValueRef CompileAllocExpression( static LLVMValueRef CompileAllocExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *expression Node *allocExpression
) { ) {
LLVMTypeRef type = ResolveType(expression->children[0]); LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type);
return LLVMBuildMalloc(builder, type, "allocation"); return LLVMBuildMalloc(builder, type, "allocation");
} }
@ -601,7 +597,7 @@ static LLVMValueRef CompileExpression(
return CompileFunctionCallExpression(builder, expression); return CompileFunctionCallExpression(builder, expression);
case Identifier: case Identifier:
return FindVariableValue(builder, expression->value.string); return FindVariableValue(builder, expression->identifier.name);
case Number: case Number:
return CompileNumber(expression); return CompileNumber(expression);
@ -619,7 +615,7 @@ static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef f
static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{ {
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); LLVMValueRef expression = CompileExpression(builder, returnStatemement->returnStatement.expression);
LLVMBuildRet(builder, expression); LLVMBuildRet(builder, expression);
return LLVMGetLastBasicBlock(function); return LLVMGetLastBasicBlock(function);
} }
@ -634,11 +630,11 @@ static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef
static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration) static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
{ {
LLVMValueRef variable; LLVMValueRef variable;
char *variableName = variableDeclaration->children[1]->value.string; char *variableName = variableDeclaration->declaration.identifier->identifier.name;
char *ptrName = strdup(variableName); char *ptrName = strdup(variableName);
strcat(ptrName, "_ptr"); strcat(ptrName, "_ptr");
variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->children[0]), ptrName); variable = LLVMBuildAlloca(builder, ResolveType(variableDeclaration->declaration.type), ptrName);
free(ptrName); free(ptrName);
@ -649,19 +645,19 @@ static LLVMValueRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, L
static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{ {
LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]); LLVMValueRef result = CompileExpression(builder, assignmentStatement->assignmentStatement.right);
LLVMValueRef identifier; LLVMValueRef identifier;
if (assignmentStatement->children[0]->syntaxKind == AccessExpression) if (assignmentStatement->assignmentStatement.left->syntaxKind == AccessExpression)
{ {
identifier = CompileAccessExpressionForStore(builder, assignmentStatement->children[0]); identifier = CompileAccessExpressionForStore(builder, assignmentStatement->assignmentStatement.left);
} }
else if (assignmentStatement->children[0]->syntaxKind == Identifier) else if (assignmentStatement->assignmentStatement.left->syntaxKind == Identifier)
{ {
identifier = FindVariablePointer(assignmentStatement->children[0]->value.string); identifier = FindVariablePointer(assignmentStatement->assignmentStatement.left->identifier.name);
} }
else if (assignmentStatement->children[0]->syntaxKind == Declaration) else if (assignmentStatement->assignmentStatement.left->syntaxKind == Declaration)
{ {
identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->children[0]); identifier = CompileFunctionVariableDeclaration(builder, function, assignmentStatement->assignmentStatement.left);
} }
else else
{ {
@ -677,7 +673,7 @@ static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef
static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]); LLVMValueRef conditional = CompileExpression(builder, ifStatement->ifStatement.expression);
LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");
@ -686,9 +682,9 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef
LLVMPositionBuilderAtEnd(builder, block); LLVMPositionBuilderAtEnd(builder, block);
for (i = 0; i < ifStatement->children[1]->childCount; i += 1) for (i = 0; i < ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1)
{ {
CompileStatement(builder, function, ifStatement->children[1]->children[i]); CompileStatement(builder, function, ifStatement->ifStatement.statementSequence->statementSequence.sequence[i]);
} }
LLVMBuildBr(builder, afterCond); LLVMBuildBr(builder, afterCond);
@ -700,7 +696,7 @@ static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef
static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]); LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression);
LLVMBasicBlockRef ifBlock = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef ifBlock = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlock(function, "elseBlock"); LLVMBasicBlockRef elseBlock = LLVMAppendBasicBlock(function, "elseBlock");
@ -710,25 +706,25 @@ static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValu
LLVMPositionBuilderAtEnd(builder, ifBlock); LLVMPositionBuilderAtEnd(builder, ifBlock);
for (i = 0; i < ifElseStatement->children[0]->children[1]->childCount; i += 1) for (i = 0; i < ifElseStatement->ifElseStatement.ifStatement->ifStatement.statementSequence->statementSequence.count; i += 1)
{ {
CompileStatement(builder, function, ifElseStatement->children[0]->children[1]->children[i]); CompileStatement(builder, function, ifElseStatement->ifStatement.statementSequence->statementSequence.sequence[i]);
} }
LLVMBuildBr(builder, afterCond); LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, elseBlock); LLVMPositionBuilderAtEnd(builder, elseBlock);
if (ifElseStatement->children[1]->syntaxKind == StatementSequence) if (ifElseStatement->ifElseStatement.elseStatement->syntaxKind == StatementSequence)
{ {
for (i = 0; i < ifElseStatement->children[1]->childCount; i += 1) for (i = 0; i < ifElseStatement->ifElseStatement.elseStatement->statementSequence.count; i += 1)
{ {
CompileStatement(builder, function, ifElseStatement->children[1]->children[i]); CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement->statementSequence.sequence[i]);
} }
} }
else else
{ {
CompileStatement(builder, function, ifElseStatement->children[1]); CompileStatement(builder, function, ifElseStatement->ifElseStatement.elseStatement);
} }
LLVMBuildBr(builder, afterCond); LLVMBuildBr(builder, afterCond);
@ -744,8 +740,8 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck"); LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody"); LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop"); LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
char *iteratorVariableName = forLoopStatement->children[0]->children[1]->value.string; char *iteratorVariableName = forLoopStatement->forLoop.declaration->declaration.identifier->identifier.name;
LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->children[0]->children[0]); LLVMTypeRef iteratorVariableType = ResolveType(forLoopStatement->forLoop.declaration->declaration.type);
PushScopeFrame(scope); PushScopeFrame(scope);
@ -762,13 +758,13 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMValueRef nextValue = LLVMBuildAdd( LLVMValueRef nextValue = LLVMBuildAdd(
builder, builder,
iteratorValue, iteratorValue,
LLVMConstInt(iteratorVariableType, forLoopStatement->children[1]->value.number, 0), LLVMConstInt(iteratorVariableType, 1, 0), /* FIXME: add custom increment value */
"next" "next"
); );
LLVMPositionBuilderAtEnd(builder, checkBlock); LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]); LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->forLoop.endNumber);
LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare"); LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock); LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
@ -776,9 +772,9 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMPositionBuilderAtEnd(builder, bodyBlock); LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMBasicBlockRef lastBlock; LLVMBasicBlockRef lastBlock;
for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1) for (i = 0; i < forLoopStatement->forLoop.statementSequence->statementSequence.count; i += 1)
{ {
lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]); lastBlock = CompileStatement(builder, function, forLoopStatement->forLoop.statementSequence->statementSequence.sequence[i]);
} }
LLVMBuildBr(builder, checkBlock); LLVMBuildBr(builder, checkBlock);
@ -786,7 +782,7 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock)); LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
LLVMValueRef incomingValues[2]; LLVMValueRef incomingValues[2];
incomingValues[0] = CompileNumber(forLoopStatement->children[1]); incomingValues[0] = CompileNumber(forLoopStatement->forLoop.startNumber);
incomingValues[1] = nextValue; incomingValues[1] = nextValue;
LLVMBasicBlockRef incomingBlocks[2]; LLVMBasicBlockRef incomingBlocks[2];
@ -848,17 +844,17 @@ static void CompileFunction(
uint32_t i; uint32_t i;
uint8_t hasReturn = 0; uint8_t hasReturn = 0;
uint8_t isStatic = 0; uint8_t isStatic = 0;
Node *functionSignature = functionDeclaration->children[0]; Node *functionSignature = functionDeclaration->functionDeclaration.functionSignature;
Node *functionBody = functionDeclaration->children[1]; Node *functionBody = functionDeclaration->functionDeclaration.functionBody;
uint32_t argumentCount = functionSignature->children[2]->childCount; uint32_t argumentCount = functionSignature->functionSignature.arguments->functionSignatureArguments.count;
LLVMTypeRef paramTypes[argumentCount + 1]; LLVMTypeRef paramTypes[argumentCount + 1];
uint32_t paramIndex = 0; uint32_t paramIndex = 0;
if (functionSignature->children[3]->childCount > 0) if (functionSignature->functionSignature.modifiers->functionModifiers.count > 0)
{ {
for (i = 0; i < functionSignature->children[3]->childCount; i += 1) for (i = 0; i < functionSignature->functionSignature.modifiers->functionModifiers.count; i += 1)
{ {
if (functionSignature->children[3]->children[i]->syntaxKind == StaticModifier) if (functionSignature->functionSignature.modifiers->functionModifiers.sequence[i]->syntaxKind == StaticModifier)
{ {
isStatic = 1; isStatic = 1;
break; break;
@ -875,22 +871,22 @@ static void CompileFunction(
PushScopeFrame(scope); PushScopeFrame(scope);
/* FIXME: should work for non-primitive types */ /* FIXME: should work for non-primitive types */
for (i = 0; i < functionSignature->children[2]->childCount; i += 1) for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1)
{ {
paramTypes[paramIndex] = WraithTypeToLLVMType(functionSignature->children[2]->children[i]->children[0]->children[0]->primitiveType); paramTypes[paramIndex] = ResolveType(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->declaration.type);
paramIndex += 1; paramIndex += 1;
} }
LLVMTypeRef returnType = WraithTypeToLLVMType(functionSignature->children[1]->children[0]->primitiveType); LLVMTypeRef returnType = ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
char *functionName = strdup(parentStructName); char *functionName = strdup(parentStructName);
strcat(functionName, "_"); strcat(functionName, "_");
strcat(functionName, functionSignature->children[0]->value.string); strcat(functionName, functionSignature->functionSignature.identifier->identifier.name);
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
free(functionName); free(functionName);
DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->children[0]->value.string); DeclareStructFunction(wStructPointerType, function, returnType, isStatic, functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder(); LLVMBuilderRef builder = LLVMCreateBuilder();
@ -902,20 +898,20 @@ static void CompileFunction(
AddStructVariablesToScope(builder, wStructPointer); AddStructVariablesToScope(builder, wStructPointer);
} }
for (i = 0; i < functionSignature->children[2]->childCount; i += 1) for (i = 0; i < functionSignature->functionSignature.arguments->functionSignatureArguments.count; i += 1)
{ {
char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string); char *ptrName = strdup(functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->identifier.name);
strcat(ptrName, "_ptr"); strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy); LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName); free(ptrName);
AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string); AddLocalVariable(scope, argumentCopy, NULL, functionSignature->functionSignature.arguments->functionSignatureArguments.sequence[i]->identifier.name);
} }
for (i = 0; i < functionBody->childCount; i += 1) for (i = 0; i < functionBody->statementSequence.count; i += 1)
{ {
CompileStatement(builder, function, functionBody->children[i]); CompileStatement(builder, function, functionBody->statementSequence.sequence[i]);
} }
hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
@ -938,12 +934,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
{ {
uint32_t i; uint32_t i;
uint32_t fieldCount = 0; uint32_t fieldCount = 0;
uint32_t declarationCount = node->children[1]->childCount; uint32_t declarationCount = node->structDeclaration.declarationSequence->declarationSequence.count;
uint8_t packed = 1; uint8_t packed = 1;
LLVMTypeRef types[declarationCount]; LLVMTypeRef types[declarationCount];
Node *currentDeclarationNode; Node *currentDeclarationNode;
Node *fieldDeclarations[declarationCount]; Node *fieldDeclarations[declarationCount];
char *structName = node->children[0]->value.string; char *structName = node->structDeclaration.identifier->identifier.name;
PushScopeFrame(scope); PushScopeFrame(scope);
@ -953,12 +949,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
/* first, build the structure definition */ /* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1) for (i = 0; i < declarationCount; i += 1)
{ {
currentDeclarationNode = node->children[1]->children[i]; currentDeclarationNode = node->structDeclaration.declarationSequence->declarationSequence.sequence[i];
switch (currentDeclarationNode->syntaxKind) switch (currentDeclarationNode->syntaxKind)
{ {
case Declaration: /* this is badly named */ case Declaration: /* this is badly named */
types[fieldCount] = ResolveType(currentDeclarationNode->children[0]); types[fieldCount] = ResolveType(currentDeclarationNode->declaration.type);
fieldDeclarations[fieldCount] = currentDeclarationNode; fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1; fieldCount += 1;
break; break;
@ -966,12 +962,12 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
} }
LLVMStructSetBody(wStructType, types, fieldCount, packed); LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount); AddStructDeclaration(wStructType, wStructPointerType, structName, fieldDeclarations, fieldCount);
/* now we can wire up the functions */ /* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1) for (i = 0; i < declarationCount; i += 1)
{ {
currentDeclarationNode = node->children[1]->children[i]; currentDeclarationNode = node->structDeclaration.declarationSequence->declarationSequence.sequence[i];
switch (currentDeclarationNode->syntaxKind) switch (currentDeclarationNode->syntaxKind)
{ {
@ -984,15 +980,15 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
PopScopeFrame(scope); PopScopeFrame(scope);
} }
static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *declarationSequenceNode)
{ {
uint32_t i; uint32_t i;
for (i = 0; i < node->childCount; i += 1) for (i = 0; i < declarationSequenceNode->declarationSequence.count; i += 1)
{ {
if (node->children[i]->syntaxKind == StructDeclaration) if (declarationSequenceNode->declarationSequence.sequence[i]->syntaxKind == StructDeclaration)
{ {
CompileStruct(module, context, node->children[i]); CompileStruct(module, context, declarationSequenceNode->declarationSequence.sequence[i]);
} }
else else
{ {

View File

@ -69,7 +69,7 @@ int main(int argc, char *argv[])
IdNode *idTree = MakeIdTree(rootNode, NULL); IdNode *idTree = MakeIdTree(rootNode, NULL);
PrintIdTree(idTree, /*tabCount=*/0); PrintIdTree(idTree, /*tabCount=*/0);
printf("\n"); printf("\n");
PrintTree(rootNode, /*tabCount=*/0); PrintNode(rootNode, /*tabCount=*/0);
} }
exitCode = Codegen(rootNode, optimizationLevel); exitCode = Codegen(rootNode, optimizationLevel);
} }

View File

@ -31,7 +31,7 @@ int Parse(char *inputFilename, Node **pRootNode, uint8_t parseVerbose)
{ {
if (parseVerbose) if (parseVerbose)
{ {
PrintTree(*pRootNode, 0); PrintNode(*pRootNode, 0);
} }
} }
else if (result == 1) else if (result == 1)