wraith-lang/src/ast.c

607 lines
16 KiB
C

#include "ast.h"
#include <stdlib.h>
#include <stdio.h>
#include "util.h"
/*
 * Maps a SyntaxKind to a printable name.
 *
 * The returned pointer refers to a string literal: it must not be modified
 * or freed. Kinds without an explicit entry yield "Unknown".
 */
const char* SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
        case AccessExpression: return "AccessExpression";
        case AllocExpression: return "Alloc";
        case Assignment: return "Assignment";
        case BinaryExpression: return "BinaryExpression";
        case Comment: return "Comment";
        case CustomTypeNode: return "CustomTypeNode";
        case Declaration: return "Declaration";
        case Expression: return "Expression";
        case ForLoop: return "ForLoop";
        case DeclarationSequence: return "DeclarationSequence";
        case FunctionArgumentSequence: return "FunctionArgumentSequence";
        case FunctionCallExpression: return "FunctionCallExpression";
        case FunctionDeclaration: return "FunctionDeclaration";
        case FunctionModifiers: return "FunctionModifiers";
        case FunctionSignature: return "FunctionSignature";
        case FunctionSignatureArguments: return "FunctionSignatureArguments";
        case Identifier: return "Identifier";
        case IfStatement: return "If";
        case IfElseStatement: return "IfElse";
        case Number: return "Number";
        case PrimitiveTypeNode: return "PrimitiveTypeNode";
        case ReferenceTypeNode: return "ReferenceTypeNode";
        case Return: return "Return";
        /* ReturnVoid nodes are produced by MakeReturnVoidStatementNode. */
        case ReturnVoid: return "ReturnVoid";
        case StatementSequence: return "StatementSequence";
        case StaticModifier: return "StaticModifier";
        case StringLiteral: return "StringLiteral";
        case StructDeclaration: return "StructDeclaration";
        case Type: return "Type";
        case UnaryExpression: return "UnaryExpression";
        default: return "Unknown";
    }
}
/*
 * Tells whether the first child of typeNode is a PrimitiveTypeNode.
 * Returns 1 when it is, 0 otherwise. typeNode must have at least one child.
 */
uint8_t IsPrimitiveType(
    Node *typeNode
) {
    Node *wrapped = typeNode->children[0];

    return wrapped->syntaxKind == PrimitiveTypeNode;
}
/*
 * Allocates a childless PrimitiveTypeNode carrying the given primitive type.
 * The allocation result is not checked.
 */
Node* MakePrimitiveTypeNode(
    PrimitiveType type
) {
    Node *node = malloc(sizeof *node);
    node->syntaxKind = PrimitiveTypeNode;
    node->primitiveType = type;
    node->childCount = 0;
    /* malloc does not zero memory; never leave the child array dangling. */
    node->children = NULL;
    return node;
}
Node* MakeCustomTypeNode(
char *name
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = CustomTypeNode;
node->value.string = strdup(name);
node->childCount = 0;
return node;
}
/*
 * Builds a ReferenceTypeNode whose single child is typeNode.
 * The typeNode pointer is stored as-is (no copy). Allocation results are
 * not checked.
 */
Node* MakeReferenceTypeNode(
    Node *typeNode
) {
    Node *result = malloc(sizeof *result);
    Node **slots = malloc(sizeof *slots);

    slots[0] = typeNode;
    result->syntaxKind = ReferenceTypeNode;
    result->children = slots;
    result->childCount = 1;
    return result;
}
/*
 * Builds a Type node whose single child is typeNode.
 * The typeNode pointer is stored as-is (no copy). Allocation results are
 * not checked.
 */
Node* MakeTypeNode(
    Node* typeNode
) {
    Node *result = malloc(sizeof *result);
    Node **slots = malloc(sizeof *slots);

    slots[0] = typeNode;
    result->syntaxKind = Type;
    result->children = slots;
    result->childCount = 1;
    return result;
}
Node* MakeIdentifierNode(
const char *id
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Identifier;
node->value.string = strdup(id);
node->childCount = 0;
return node;
}
Node* MakeNumberNode(
const char *numberString
) {
char *ptr;
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = Number;
node->value.number = strtoul(numberString, &ptr, 10);
node->childCount = 0;
return node;
}
Node* MakeStringNode(
const char *string
) {
size_t slen = strlen(string);
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StringLiteral;
node->value.string = strndup(string + 1, slen - 2);
node->childCount = 0;
return node;
}
Node* MakeStaticNode()
{
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StaticModifier;
node->childCount = 0;
return node;
}
/*
 * Builds a FunctionModifiers node holding modifierCount children.
 *
 * The pointers in pModifierNodes are copied into a freshly allocated child
 * array; the pModifierNodes array itself is not retained. When modifierCount
 * is 0 the node has no child array (children == NULL) and pModifierNodes is
 * not read. Allocation results are not checked.
 */
Node* MakeFunctionModifiersNode(
    Node **pModifierNodes,
    uint32_t modifierCount
) {
    uint32_t i;
    Node *node = malloc(sizeof *node);
    node->syntaxKind = FunctionModifiers;
    node->childCount = modifierCount;
    /* Previously left uninitialised when there were no modifiers. */
    node->children = NULL;
    if (modifierCount > 0)
    {
        node->children = malloc(sizeof *node->children * modifierCount);
        for (i = 0; i < modifierCount; i += 1)
        {
            node->children[i] = pModifierNodes[i];
        }
    }
    return node;
}
/*
 * Builds a UnaryExpression node applying operator to child.
 * child is stored as the only entry of the child array (no copy).
 * Allocation results are not checked.
 */
Node* MakeUnaryNode(
    UnaryOperator operator,
    Node *child
) {
    Node *result = malloc(sizeof *result);

    result->syntaxKind = UnaryExpression;
    result->operator.unaryOperator = operator;

    Node **slots = malloc(sizeof *slots);
    slots[0] = child;
    result->children = slots;
    result->childCount = 1;
    return result;
}
/*
 * Builds a BinaryExpression node for "left operator right".
 * children[0] is left and children[1] is right; both pointers are stored
 * as-is. Allocation results are not checked.
 */
Node* MakeBinaryNode(
    BinaryOperator operator,
    Node *left,
    Node *right
) {
    Node *result = malloc(sizeof *result);

    result->syntaxKind = BinaryExpression;
    result->operator.binaryOperator = operator;

    Node **operands = malloc(2 * sizeof *operands);
    operands[0] = left;
    operands[1] = right;
    result->children = operands;
    result->childCount = 2;
    return result;
}
/*
 * Builds a Declaration node.
 * children[0] is typeNode and children[1] is identifierNode; both pointers
 * are stored as-is. Allocation results are not checked.
 */
Node* MakeDeclarationNode(
    Node* typeNode,
    Node* identifierNode
) {
    Node *result = malloc(sizeof *result);
    Node **parts = malloc(2 * sizeof *parts);

    parts[0] = typeNode;
    parts[1] = identifierNode;
    result->syntaxKind = Declaration;
    result->children = parts;
    result->childCount = 2;
    return result;
}
/*
 * Builds an Assignment node.
 * children[0] is left (the target) and children[1] is right (the value);
 * both pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeAssignmentNode(
    Node *left,
    Node *right
) {
    Node *result = malloc(sizeof *result);
    Node **sides = malloc(2 * sizeof *sides);

    sides[0] = left;
    sides[1] = right;
    result->syntaxKind = Assignment;
    result->children = sides;
    result->childCount = 2;
    return result;
}
/*
 * Creates a StatementSequence node containing statementNode as its first
 * and only child. Allocation results are not checked.
 */
Node* StartStatementSequenceNode(
    Node *statementNode
) {
    Node *sequence = malloc(sizeof *sequence);
    Node **items = malloc(sizeof *items);

    items[0] = statementNode;
    sequence->syntaxKind = StatementSequence;
    sequence->children = items;
    sequence->childCount = 1;
    return sequence;
}
/*
 * Appends statementNode to the child array of statementSequenceNode.
 *
 * Returns statementSequenceNode on success. If the child array cannot be
 * grown, an error is printed to stderr, the sequence is left unmodified
 * (its existing children stay valid) and NULL is returned.
 */
Node* AddStatement(
    Node* statementSequenceNode,
    Node *statementNode
) {
    /* Grow through a temporary so the old array is not lost on failure. */
    Node **grown = realloc(
        statementSequenceNode->children,
        sizeof(Node*) * (statementSequenceNode->childCount + 1)
    );
    if (grown == NULL)
    {
        fprintf(stderr, "wraith: Out of memory while adding a statement.\n");
        return NULL;
    }
    grown[statementSequenceNode->childCount] = statementNode;
    statementSequenceNode->children = grown;
    statementSequenceNode->childCount += 1;
    return statementSequenceNode;
}
/*
 * Builds a Return node whose single child is expressionNode (the returned
 * expression). Allocation results are not checked.
 */
Node* MakeReturnStatementNode(
    Node *expressionNode
) {
    Node *result = malloc(sizeof *result);
    Node **slots = malloc(sizeof *slots);

    slots[0] = expressionNode;
    result->syntaxKind = Return;
    result->children = slots;
    result->childCount = 1;
    return result;
}
/*
 * Builds a ReturnVoid node: a return statement without an expression.
 * The node has no children. The allocation result is not checked.
 */
Node* MakeReturnVoidStatementNode()
{
    Node *result = malloc(sizeof *result);

    result->children = NULL;
    result->childCount = 0;
    result->syntaxKind = ReturnVoid;
    return result;
}
/*
 * Creates a FunctionSignatureArguments node containing argumentNode as its
 * first and only child. Allocation results are not checked.
 */
Node *StartFunctionSignatureArgumentsNode(
    Node *argumentNode
) {
    Node *arguments = malloc(sizeof *arguments);
    Node **items = malloc(sizeof *items);

    items[0] = argumentNode;
    arguments->syntaxKind = FunctionSignatureArguments;
    arguments->children = items;
    arguments->childCount = 1;
    return arguments;
}
/*
 * Appends argumentNode to the child array of argumentsNode.
 *
 * Returns argumentsNode on success. If the child array cannot be grown, an
 * error is printed to stderr, argumentsNode is left unmodified (its existing
 * children stay valid) and NULL is returned.
 */
Node* AddFunctionSignatureArgumentNode(
    Node *argumentsNode,
    Node *argumentNode
) {
    /* Grow through a temporary so the old array is not lost on failure. */
    Node **grown = realloc(
        argumentsNode->children,
        sizeof(Node*) * (argumentsNode->childCount + 1)
    );
    if (grown == NULL)
    {
        fprintf(stderr, "wraith: Out of memory while adding a signature argument.\n");
        return NULL;
    }
    grown[argumentsNode->childCount] = argumentNode;
    argumentsNode->children = grown;
    argumentsNode->childCount += 1;
    return argumentsNode;
}
/*
 * Creates a FunctionSignatureArguments node with no children
 * (childCount 0, children NULL). The allocation result is not checked.
 */
Node *MakeEmptyFunctionSignatureArgumentsNode()
{
    Node *arguments = malloc(sizeof *arguments);

    arguments->children = NULL;
    arguments->childCount = 0;
    arguments->syntaxKind = FunctionSignatureArguments;
    return arguments;
}
/*
 * Builds a FunctionSignature node with four children, in this order:
 *   [0] identifierNode, [1] typeNode, [2] arguments, [3] modifiersNode.
 * All pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeFunctionSignatureNode(
    Node *identifierNode,
    Node* typeNode,
    Node* arguments,
    Node* modifiersNode
) {
    Node *signature = malloc(sizeof *signature);

    signature->syntaxKind = FunctionSignature;
    signature->childCount = 4;

    Node **parts = malloc(sizeof *parts * (signature->childCount));
    parts[0] = identifierNode;
    parts[1] = typeNode;
    parts[2] = arguments;
    parts[3] = modifiersNode;
    signature->children = parts;
    return signature;
}
/*
 * Builds a FunctionDeclaration node.
 * children[0] is functionSignatureNode and children[1] is functionBodyNode;
 * both pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeFunctionDeclarationNode(
    Node* functionSignatureNode,
    Node* functionBodyNode
) {
    Node *declaration = malloc(sizeof *declaration);
    Node **parts = malloc(2 * sizeof *parts);

    parts[0] = functionSignatureNode;
    parts[1] = functionBodyNode;
    declaration->syntaxKind = FunctionDeclaration;
    declaration->children = parts;
    declaration->childCount = 2;
    return declaration;
}
/*
 * Builds a StructDeclaration node.
 * children[0] is identifierNode and children[1] is declarationSequenceNode;
 * both pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeStructDeclarationNode(
    Node *identifierNode,
    Node *declarationSequenceNode
) {
    Node *declaration = malloc(sizeof *declaration);
    Node **parts = malloc(2 * sizeof *parts);

    parts[0] = identifierNode;
    parts[1] = declarationSequenceNode;
    declaration->syntaxKind = StructDeclaration;
    declaration->children = parts;
    declaration->childCount = 2;
    return declaration;
}
/*
 * Creates a DeclarationSequence node containing declarationNode as its
 * first and only child. Allocation results are not checked.
 */
Node* StartDeclarationSequenceNode(
    Node *declarationNode
) {
    Node *sequence = malloc(sizeof *sequence);
    Node **items = malloc(sizeof *items);

    items[0] = declarationNode;
    sequence->syntaxKind = DeclarationSequence;
    sequence->children = items;
    sequence->childCount = 1;
    return sequence;
}
/*
 * Appends declarationNode to the child array of declarationSequenceNode.
 *
 * Returns declarationSequenceNode on success. If the child array cannot be
 * grown, an error is printed to stderr, the sequence is left unmodified
 * (its existing children stay valid) and NULL is returned.
 */
Node* AddDeclarationNode(
    Node *declarationSequenceNode,
    Node *declarationNode
) {
    /* Grow through a temporary so the old array is not lost on failure. */
    Node **grown = realloc(
        declarationSequenceNode->children,
        sizeof(Node*) * (declarationSequenceNode->childCount + 1)
    );
    if (grown == NULL)
    {
        fprintf(stderr, "wraith: Out of memory while adding a declaration.\n");
        return NULL;
    }
    grown[declarationSequenceNode->childCount] = declarationNode;
    declarationSequenceNode->children = grown;
    declarationSequenceNode->childCount += 1;
    return declarationSequenceNode;
}
/*
 * Creates a FunctionArgumentSequence node containing argumentNode as its
 * first and only child. Allocation results are not checked.
 */
Node* StartFunctionArgumentSequenceNode(
    Node *argumentNode
) {
    Node *sequence = malloc(sizeof *sequence);
    Node **items = malloc(sizeof *items);

    items[0] = argumentNode;
    sequence->syntaxKind = FunctionArgumentSequence;
    sequence->children = items;
    sequence->childCount = 1;
    return sequence;
}
/*
 * Appends argumentNode to the child array of argumentSequenceNode.
 *
 * Returns argumentSequenceNode on success. If the child array cannot be
 * grown, an error is printed to stderr, the sequence is left unmodified
 * (its existing children stay valid) and NULL is returned.
 */
Node* AddFunctionArgumentNode(
    Node *argumentSequenceNode,
    Node *argumentNode
) {
    /* Grow through a temporary so the old array is not lost on failure. */
    Node **grown = realloc(
        argumentSequenceNode->children,
        sizeof(Node*) * (argumentSequenceNode->childCount + 1)
    );
    if (grown == NULL)
    {
        fprintf(stderr, "wraith: Out of memory while adding a call argument.\n");
        return NULL;
    }
    grown[argumentSequenceNode->childCount] = argumentNode;
    argumentSequenceNode->children = grown;
    argumentSequenceNode->childCount += 1;
    return argumentSequenceNode;
}
/*
 * Creates a FunctionArgumentSequence node with no children
 * (childCount 0, children NULL). The allocation result is not checked.
 */
Node *MakeEmptyFunctionArgumentSequenceNode()
{
    Node *sequence = malloc(sizeof *sequence);

    sequence->children = NULL;
    sequence->childCount = 0;
    sequence->syntaxKind = FunctionArgumentSequence;
    return sequence;
}
/*
 * Builds a FunctionCallExpression node.
 * children[0] is identifierNode and children[1] is argumentSequenceNode;
 * both pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeFunctionCallExpressionNode(
    Node *identifierNode,
    Node *argumentSequenceNode
) {
    Node *call = malloc(sizeof *call);
    Node **parts = malloc(2 * sizeof *parts);

    parts[0] = identifierNode;
    parts[1] = argumentSequenceNode;
    call->syntaxKind = FunctionCallExpression;
    call->children = parts;
    call->childCount = 2;
    return call;
}
/*
 * Builds an AccessExpression node.
 * children[0] is accessee and children[1] is accessor; both pointers are
 * stored as-is. Allocation results are not checked.
 */
Node* MakeAccessExpressionNode(
    Node *accessee,
    Node *accessor
) {
    Node *access = malloc(sizeof *access);
    Node **parts = malloc(2 * sizeof *parts);

    parts[0] = accessee;
    parts[1] = accessor;
    access->syntaxKind = AccessExpression;
    access->children = parts;
    access->childCount = 2;
    return access;
}
/*
 * Builds an AllocExpression node whose single child is typeNode.
 * The pointer is stored as-is. Allocation results are not checked.
 */
Node* MakeAllocNode(Node *typeNode)
{
    Node *result = malloc(sizeof *result);
    Node **slots = malloc(sizeof *slots);

    slots[0] = typeNode;
    result->syntaxKind = AllocExpression;
    result->children = slots;
    result->childCount = 1;
    return result;
}
/*
 * Builds an IfStatement node.
 * children[0] is expressionNode (the condition) and children[1] is
 * statementSequenceNode (the body); both pointers are stored as-is.
 * Allocation results are not checked.
 */
Node* MakeIfNode(
    Node *expressionNode,
    Node *statementSequenceNode
) {
    Node *node = malloc(sizeof *node);
    node->syntaxKind = IfStatement;
    node->childCount = 2;
    /* Two slots are written below, so room for two pointers is required. */
    node->children = malloc(sizeof *node->children * 2);
    node->children[0] = expressionNode;
    node->children[1] = statementSequenceNode;
    return node;
}
/*
 * Builds an IfElseStatement node.
 * children[0] is ifNode and children[1] is statementSequenceNode (the else
 * body); both pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeIfElseNode(
    Node *ifNode,
    Node *statementSequenceNode
) {
    Node *node = malloc(sizeof *node);
    node->syntaxKind = IfElseStatement;
    node->childCount = 2;
    /* Two slots are written below, so room for two pointers is required. */
    node->children = malloc(sizeof *node->children * 2);
    node->children[0] = ifNode;
    node->children[1] = statementSequenceNode;
    return node;
}
/*
 * Builds a ForLoop node with four children, in this order:
 *   [0] identifierNode, [1] startNumberNode, [2] endNumberNode,
 *   [3] statementSequenceNode.
 * All pointers are stored as-is. Allocation results are not checked.
 */
Node* MakeForLoopNode(
    Node *identifierNode,
    Node *startNumberNode,
    Node *endNumberNode,
    Node *statementSequenceNode
) {
    Node *loop = malloc(sizeof *loop);
    Node **parts = malloc(4 * sizeof *parts);

    parts[0] = identifierNode;
    parts[1] = startNumberNode;
    parts[2] = endNumberNode;
    parts[3] = statementSequenceNode;
    loop->syntaxKind = ForLoop;
    loop->children = parts;
    loop->childCount = 4;
    return loop;
}
/*
 * Returns the printable name of a primitive type as a string literal.
 * Values other than Int, UInt, Bool and Void yield "Unknown".
 */
static const char* PrimitiveTypeToString(PrimitiveType type)
{
    const char *name = "Unknown";

    switch (type)
    {
        case Int:
            name = "Int";
            break;
        case UInt:
            name = "UInt";
            break;
        case Bool:
            name = "Bool";
            break;
        case Void:
            name = "Void";
            break;
    }
    return name;
}
/*
 * Writes the symbol of a binary operator to stdout: "+" for Add, "-" for
 * Subtract, "*" for Multiply. Nothing is printed for any other value.
 */
static void PrintBinaryOperator(BinaryOperator expression)
{
    const char *symbol = NULL;

    switch (expression)
    {
        case Add:
            symbol = "+";
            break;
        case Subtract:
            symbol = "-";
            break;
        case Multiply:
            symbol = "*";
            break;
    }

    if (symbol != NULL)
    {
        printf("%s", symbol);
    }
}
/*
 * Prints a single line describing node to stdout.
 *
 * The line consists of tabCount indentation units, the node's syntax kind
 * followed by ": ", and a kind-specific detail: the operator symbol for
 * binary expressions, the name for custom types and identifiers, the type
 * name for primitive types and the value for numbers. Other kinds print no
 * detail. Children are not visited here.
 */
static void PrintNode(Node *node, int tabCount)
{
    uint32_t column;

    for (column = 0; column < tabCount; column += 1)
    {
        printf(" ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));

    switch (node->syntaxKind)
    {
        case BinaryExpression:
            PrintBinaryOperator(node->operator.binaryOperator);
            break;

        /* Both kinds keep their name in value.string. */
        case CustomTypeNode:
        case Identifier:
            printf("%s", node->value.string);
            break;

        case PrimitiveTypeNode:
            printf("%s", PrimitiveTypeToString(node->primitiveType));
            break;

        case Number:
            printf("%lu", node->value.number);
            break;

        default:
            /* Declaration and every other kind have no extra detail. */
            break;
    }

    printf("\n");
}
/*
 * Prints node and, recursively, all of its descendants to stdout, one node
 * per line. Each level of nesting is printed with tabCount increased by one.
 */
void PrintTree(Node *node, uint32_t tabCount)
{
    uint32_t childIndex = 0;

    PrintNode(node, tabCount);

    while (childIndex < node->childCount)
    {
        PrintTree(node->children[childIndex], tabCount + 1);
        childIndex += 1;
    }
}
/*
 * Derives a newly allocated TypeTag from an AST node.
 *
 * Supported kinds:
 *   - Type, Declaration:   the tag of children[0].
 *   - FunctionDeclaration: the tag of children[0]->children[1] (the second
 *                          child of the signature).
 *   - PrimitiveTypeNode:   a Primitive tag with the node's primitive type.
 *   - ReferenceTypeNode:   a Reference tag wrapping the tag of children[0]
 *                          (which may be NULL if that call failed).
 *   - CustomTypeNode:      a Custom tag with a copy of value.string.
 *   - StructDeclaration:   a Custom tag with a copy of the string held by
 *                          children[0].
 *
 * Returns NULL, after printing a message to stderr, when node is NULL, has
 * an unsupported kind, or memory cannot be allocated.
 */
TypeTag* MakeTypeTag(Node *node) {
    TypeTag *tag;

    if (node == NULL) {
        fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n");
        return NULL;
    }

    /*
     * Resolve every kind that does not need a tag of its own first, so that
     * nothing is allocated only to be overwritten or dropped.
     */
    switch (node->syntaxKind) {
        case Type:
        case Declaration:
            return MakeTypeTag(node->children[0]);

        case FunctionDeclaration:
            return MakeTypeTag(node->children[0]->children[1]);

        case PrimitiveTypeNode:
        case ReferenceTypeNode:
        case CustomTypeNode:
        case StructDeclaration:
            break;

        default:
            fprintf(stderr,
                "wraith: Attempted to call MakeTypeTag on"
                " node with unsupported SyntaxKind: %s\n",
                SyntaxKindString(node->syntaxKind));
            return NULL;
    }

    tag = malloc(sizeof *tag);
    if (tag == NULL) {
        fprintf(stderr, "wraith: Out of memory in MakeTypeTag.\n");
        return NULL;
    }

    if (node->syntaxKind == PrimitiveTypeNode) {
        tag->type = Primitive;
        tag->value.primitiveType = node->primitiveType;
    } else if (node->syntaxKind == ReferenceTypeNode) {
        tag->type = Reference;
        tag->value.referenceType = MakeTypeTag(node->children[0]);
    } else if (node->syntaxKind == CustomTypeNode) {
        tag->type = Custom;
        tag->value.customType = strdup(node->value.string);
    } else {
        /* StructDeclaration: named after its identifier child. */
        tag->type = Custom;
        tag->value.customType = strdup(node->children[0]->value.string);
    }

    return tag;
}
/*
 * Produces a printable description of a type tag.
 *
 * Ownership of the result depends on the tag's type:
 *   - Unknown / Primitive: a string literal; must not be modified or freed.
 *   - Custom:              the tag's own customType string; still owned by
 *                          the tag.
 *   - Reference:           a newly malloc'd "Ref<inner>" string that the
 *                          caller must free.
 *
 * Returns NULL when tag is NULL (a message is printed to stderr), when the
 * inner tag of a Reference cannot be described, or when memory cannot be
 * allocated. Tags with an unrecognised type yield "Unknown".
 */
char* TypeTagToString(TypeTag *tag) {
    if (tag == NULL) {
        fprintf(stderr, "wraith: Attempted to call TypeTagToString with null value\n");
        return NULL;
    }

    switch (tag->type) {
        case Unknown:
            return "Unknown";

        case Primitive:
            /* The helper returns a literal; callers must treat it as read-only. */
            return (char*) PrimitiveTypeToString(tag->value.primitiveType);

        case Reference: {
            TypeTag *innerTag = tag->value.referenceType;
            char *inner = TypeTagToString(innerTag);
            char *result;
            size_t size;

            if (inner == NULL) {
                return NULL;
            }

            /* sizeof "Ref<>" covers "Ref<", ">" and the terminating NUL. */
            size = strlen(inner) + sizeof "Ref<>";
            result = malloc(size);
            if (result != NULL) {
                snprintf(result, size, "Ref<%s>", inner);
            }

            /*
             * Only a nested Reference yields a heap string owned by us; any
             * other inner string is a literal or belongs to its tag and must
             * not be freed here.
             */
            if (innerTag->type == Reference) {
                free(inner);
            }
            return result;
        }

        case Custom:
            return tag->value.customType;
    }

    /* Unrecognised tag type: never fall off the end of the function. */
    return "Unknown";
}
/*
 * Placeholder for attaching type tags to an AST.
 * Currently it only reports on stderr that it is not implemented; ast is
 * not examined.
 */
void AddTypeTags(Node *ast)
{
    (void) ast;
    fputs("wraith: AddTypeTags not implemented yet.\n", stderr);
}