#ifndef WRAITH_AST_H
#define WRAITH_AST_H

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Abstract syntax tree for the Wraith front end.
 *
 * Every Make*Node constructor returns a heap-allocated Node, or NULL when an
 * allocation fails. Child pointers handed to a constructor are stored as-is
 * (they are not copied). Nothing in this header frees nodes.
 */

/* Discriminates what a Node represents. */
typedef enum
{
    Assignment,
    BinaryExpression,
    Comment,
    Declaration,
    DeclarationSequence,
    Expression,
    ForLoop,
    FunctionDeclaration,
    FunctionSignature,
    Identifier,
    Number,
    Return,
    StatementSequence,
    StringLiteral,
    StructDeclaration,
    Type,
    UnaryExpression
} SyntaxKind;

typedef enum
{
    Negate
} UnaryOperator;

typedef enum
{
    Add,
    Subtract
} BinaryOperator;

typedef enum
{
    Bool,
    Int,
    UInt,
    Float,
    Double,
    String
} PrimitiveType;

typedef union
{
    UnaryOperator unaryOperator;
    BinaryOperator binaryOperator;
} Operator;

/*
 * A single tree node. Which of `operator`, `value` and `type` is meaningful
 * depends on `syntaxKind`; the constructors below zero the fields they do
 * not set.
 */
typedef struct Node
{
    SyntaxKind syntaxKind;
    struct Node **children;   /* array of childCount pointers, NULL when childCount is 0 */
    uint32_t childCount;
    union
    {
        UnaryOperator unaryOperator;
        BinaryOperator binaryOperator;
    } operator;
    union
    {
        char *string;
        uint64_t number;
    } value;
    PrimitiveType type;
} Node;

/*
 * Returns a malloc'd copy of the NUL-terminated string `s`, or NULL if the
 * allocation fails. The caller owns the returned buffer.
 */
char* strdup (const char* s)
{
    size_t slen = strlen(s);
    char* result = malloc(slen + 1);
    if (result == NULL)
    {
        return NULL;
    }

    /* slen + 1 so the terminating NUL is copied as well. */
    memcpy(result, s, slen + 1);
    return result;
}

/* Returns the enumerator name of `syntaxKind`, or "Unknown" for other values. */
const char* SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
        case Assignment: return "Assignment";
        case BinaryExpression: return "BinaryExpression";
        case Comment: return "Comment";
        case Declaration: return "Declaration";
        case DeclarationSequence: return "DeclarationSequence";
        case Expression: return "Expression";
        case ForLoop: return "ForLoop";
        case FunctionDeclaration: return "FunctionDeclaration";
        case FunctionSignature: return "FunctionSignature";
        case Identifier: return "Identifier";
        case Number: return "Number";
        case Return: return "Return";
        case StatementSequence: return "StatementSequence";
        case StringLiteral: return "StringLiteral";
        case StructDeclaration: return "StructDeclaration";
        case Type: return "Type";
        case UnaryExpression: return "UnaryExpression";
        default: return "Unknown";
    }
}

/*
 * Allocates a zero-initialised node of the given kind with room for
 * `childCount` child pointers (all NULL). `children` is NULL when
 * `childCount` is 0. Returns NULL if either allocation fails.
 */
static Node* AllocateNode(SyntaxKind syntaxKind, uint32_t childCount)
{
    Node* node = calloc(1, sizeof *node);
    if (node == NULL)
    {
        return NULL;
    }

    node->syntaxKind = syntaxKind;
    node->children = NULL;
    node->childCount = 0;

    if (childCount > 0)
    {
        /* calloc also rejects a count * size product that would overflow. */
        node->children = calloc(childCount, sizeof *node->children);
        if (node->children == NULL)
        {
            free(node);
            return NULL;
        }
        node->childCount = childCount;
    }

    return node;
}

/*
 * Shared implementation of the two sequence constructors: builds a node of
 * `syntaxKind` whose children are the `nodeCount` entries of `pNodes` in
 * reverse order.
 */
static Node* MakeSequenceNode(SyntaxKind syntaxKind, Node** pNodes, uint32_t nodeCount)
{
    uint32_t i;
    Node* node = AllocateNode(syntaxKind, nodeCount);
    if (node == NULL)
    {
        return NULL;
    }

    /* Unsigned index: a signed counter would misbehave for very large counts. */
    for (i = 0; i < nodeCount; i += 1)
    {
        node->children[i] = pNodes[nodeCount - 1 - i];
    }

    return node;
}

/* Creates a childless Type node carrying `type`. Returns NULL on allocation failure. */
Node* MakeTypeNode( PrimitiveType type )
{
    Node* node = AllocateNode(Type, 0);
    if (node == NULL)
    {
        return NULL;
    }

    node->type = type;
    return node;
}

/*
 * Creates a childless Identifier node holding a private copy of `id`.
 * Returns NULL if `id` is NULL or an allocation fails.
 */
Node* MakeIdentifierNode( const char *id )
{
    Node* node;

    if (id == NULL)
    {
        return NULL;
    }

    node = AllocateNode(Identifier, 0);
    if (node == NULL)
    {
        return NULL;
    }

    node->value.string = strdup(id);
    if (node->value.string == NULL)
    {
        free(node);
        return NULL;
    }

    return node;
}

/*
 * Creates a childless Number node from the base-10 text `numberString`.
 * Parsing uses strtoull so the full uint64_t range is available even where
 * unsigned long is 32 bits; text with no leading digits yields 0.
 * Returns NULL if `numberString` is NULL or the allocation fails.
 */
Node* MakeNumberNode( const char *numberString )
{
    Node* node;

    if (numberString == NULL)
    {
        return NULL;
    }

    node = AllocateNode(Number, 0);
    if (node == NULL)
    {
        return NULL;
    }

    node->value.number = strtoull(numberString, NULL, 10);
    return node;
}

/*
 * Creates a childless StringLiteral node holding a private copy of `string`.
 * Returns NULL if `string` is NULL or an allocation fails.
 */
Node* MakeStringNode( const char *string )
{
    Node* node;

    if (string == NULL)
    {
        return NULL;
    }

    node = AllocateNode(StringLiteral, 0);
    if (node == NULL)
    {
        return NULL;
    }

    node->value.string = strdup(string);
    if (node->value.string == NULL)
    {
        free(node);
        return NULL;
    }

    return node;
}

/* Creates a UnaryExpression node applying `operator` to `child`. Returns NULL on allocation failure. */
Node* MakeUnaryNode( UnaryOperator operator, Node *child )
{
    Node* node = AllocateNode(UnaryExpression, 1);
    if (node == NULL)
    {
        return NULL;
    }

    node->operator.unaryOperator = operator;
    node->children[0] = child;
    return node;
}

/*
 * Creates a BinaryExpression node; children are [left, right].
 * Returns NULL on allocation failure.
 */
Node* MakeBinaryNode( BinaryOperator operator, Node *left, Node *right )
{
    Node* node = AllocateNode(BinaryExpression, 2);
    if (node == NULL)
    {
        return NULL;
    }

    node->operator.binaryOperator = operator;
    node->children[0] = left;
    node->children[1] = right;
    return node;
}

/*
 * Creates a Declaration node; children are [typeNode, identifierNode].
 * Returns NULL on allocation failure.
 */
Node* MakeDeclarationNode( Node* typeNode, Node* identifierNode )
{
    Node* node = AllocateNode(Declaration, 2);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = typeNode;
    node->children[1] = identifierNode;
    return node;
}

/*
 * Creates an Assignment node; children are [left, right].
 * Returns NULL on allocation failure.
 */
Node* MakeAssignmentNode( Node *left, Node *right )
{
    Node* node = AllocateNode(Assignment, 2);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = left;
    node->children[1] = right;
    return node;
}

/*
 * Creates a StatementSequence node whose children are the `nodeCount`
 * entries of `pNodes` in reverse order. `pNodes` is not read when
 * `nodeCount` is 0. Returns NULL on allocation failure.
 */
Node* MakeStatementSequenceNode( Node** pNodes, uint32_t nodeCount )
{
    return MakeSequenceNode(StatementSequence, pNodes, nodeCount);
}

/*
 * Creates a Return node with `expressionNode` as its only child.
 * Returns NULL on allocation failure.
 */
Node* MakeReturnStatementNode( Node *expressionNode )
{
    Node* node = AllocateNode(Return, 1);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = expressionNode;
    return node;
}

/*
 * Creates a FunctionSignature node; children are
 * [identifierNode, typeNode, arguments]. Returns NULL on allocation failure.
 */
Node* MakeFunctionSignatureNode( Node *identifierNode, Node* typeNode, Node* arguments )
{
    Node* node = AllocateNode(FunctionSignature, 3);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = identifierNode;
    node->children[1] = typeNode;
    node->children[2] = arguments;
    return node;
}

/*
 * Creates a FunctionDeclaration node; children are
 * [functionSignatureNode, functionBodyNode]. Returns NULL on allocation failure.
 */
Node* MakeFunctionDeclarationNode( Node* functionSignatureNode, Node* functionBodyNode )
{
    Node* node = AllocateNode(FunctionDeclaration, 2);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = functionSignatureNode;
    node->children[1] = functionBodyNode;
    return node;
}

/*
 * Creates a StructDeclaration node; children are
 * [identifierNode, declarationSequenceNode]. Returns NULL on allocation failure.
 */
Node* MakeStructDeclarationNode( Node *identifierNode, Node *declarationSequenceNode )
{
    Node* node = AllocateNode(StructDeclaration, 2);
    if (node == NULL)
    {
        return NULL;
    }

    node->children[0] = identifierNode;
    node->children[1] = declarationSequenceNode;
    return node;
}

/*
 * Creates a DeclarationSequence node whose children are the `nodeCount`
 * entries of `pNodes` in reverse order. `pNodes` is not read when
 * `nodeCount` is 0. Returns NULL on allocation failure.
 */
Node* MakeDeclarationSequenceNode( Node **pNodes, uint32_t nodeCount )
{
    return MakeSequenceNode(DeclarationSequence, pNodes, nodeCount);
}

/* Returns the enumerator name of `type`, or "Unknown" for other values. */
static const char* PrimitiveTypeToString(PrimitiveType type)
{
    switch (type)
    {
        case Bool: return "Bool";
        case Int: return "Int";
        case UInt: return "UInt";
        case Float: return "Float";
        case Double: return "Double";
        case String: return "String";
        default: return "Unknown";
    }
}

/* Writes the symbol for `expression` to stdout; unknown operators print nothing. */
static void PrintBinaryOperator(BinaryOperator expression)
{
    switch (expression)
    {
        case Add:
            printf("+");
            break;

        case Subtract:
            printf("-");
            break;

        default:
            break;
    }
}

/*
 * Writes one line for `node` to stdout: `tabCount` indent units, the kind
 * name, then a kind-specific payload (operator, type name, identifier text
 * or number) where one exists.
 */
static void PrintNode(Node *node, uint32_t tabCount)
{
    uint32_t i;

    for (i = 0; i < tabCount; i += 1)
    {
        printf(" ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));

    switch (node->syntaxKind)
    {
        case BinaryExpression:
            PrintBinaryOperator(node->operator.binaryOperator);
            break;

        case Type:
            printf("%s", PrimitiveTypeToString(node->type));
            break;

        case Identifier:
            printf("%s", node->value.string);
            break;

        case Number:
            /* PRIu64 matches uint64_t on every ABI; "%lu" does not. */
            printf("%" PRIu64, node->value.number);
            break;

        default:
            break;
    }

    printf("\n");
}

/*
 * Prints `node` and, recursively, its children to stdout, indenting each
 * level one unit further than its parent. NULL nodes (including NULL
 * children) are skipped.
 */
void PrintTree(Node *node, uint32_t tabCount)
{
    uint32_t i;

    if (node == NULL)
    {
        return;
    }

    PrintNode(node, tabCount);
    for (i = 0; i < node->childCount; i += 1)
    {
        PrintTree(node->children[i], tabCount + 1);
    }
}

#endif /* WRAITH_AST_H */