#include "ast.h"

#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Constructors for the AST Node type declared in ast.h, plus a debug printer.
 *
 * Ownership: every constructor returns a freshly heap-allocated Node. Child
 * nodes are stored by pointer (never copied); string payloads are heap copies
 * of the caller's text.
 */

/* ------------------------------------------------------------------------ */
/* Internal helpers                                                         */
/* ------------------------------------------------------------------------ */

/*
 * Reports an allocation failure and terminates. The public constructors return
 * the new Node* directly and have no error channel, so running out of memory is
 * treated as fatal instead of handing back a NULL that would be dereferenced
 * a few instructions later.
 */
static void OutOfMemory(void)
{
    fputs("ast: out of memory\n", stderr);
    abort();
}

/*
 * Returns zero-filled storage for `count` objects of `size` bytes, or aborts.
 * calloc also rejects a `count * size` product that would overflow.
 */
static void* AllocZeroed(size_t count, size_t size)
{
    void *ptr = calloc(count, size);
    if (ptr == NULL && count != 0 && size != 0)
    {
        OutOfMemory();
    }
    return ptr;
}

/*
 * Copies exactly `length` bytes of `source` into a new NUL-terminated heap
 * string. Returns NULL if the allocation fails (same contract as strdup).
 * Used instead of POSIX strndup, which strict ISO C builds do not declare.
 */
static char* CopyChars(const char *source, size_t length)
{
    char *result = malloc(length + 1);
    if (result == NULL)
    {
        return NULL;
    }
    memcpy(result, source, length);
    result[length] = '\0';
    return result;
}

/*
 * Allocates a node of kind `syntaxKind` with room for `childCount` child
 * pointers. All bytes of the node start zeroed and `children` is explicitly
 * NULL when there are no children, so callers only fill in the fields that
 * matter for their kind. Sizing the child array here, in one place, keeps it
 * in sync with childCount.
 */
static Node* NewNode(SyntaxKind syntaxKind, uint32_t childCount)
{
    Node *node = AllocZeroed(1, sizeof(Node));
    node->syntaxKind = syntaxKind;
    node->childCount = childCount;
    node->children = (childCount > 0)
        ? AllocZeroed(childCount, sizeof(Node*))
        : NULL;
    return node;
}

/*
 * Appends `child` to the end of `parent`'s child array, growing it by one, and
 * returns `parent`. Also works for nodes built with zero children
 * (children == NULL), because realloc(NULL, n) behaves like malloc(n).
 */
static Node* AppendChild(Node *parent, Node *child)
{
    /* Computed in size_t so the byte count cannot wrap in 32-bit arithmetic. */
    size_t newCount = (size_t) parent->childCount + 1;
    Node **grown = realloc(parent->children, sizeof(Node*) * newCount);

    if (grown == NULL)
    {
        /* Only assign on success so the original array is never lost. */
        OutOfMemory();
    }
    parent->children = grown;
    parent->children[parent->childCount] = child;
    parent->childCount += 1;
    return parent;
}

/* ------------------------------------------------------------------------ */
/* Public API                                                               */
/* ------------------------------------------------------------------------ */

/*
 * Returns a heap-allocated copy of the NUL-terminated string `s`, or NULL if
 * allocation fails. Presumably defined here because strdup is POSIX rather
 * than ISO C. NOTE(review): names starting with "str" + lowercase letter are
 * reserved for <string.h>; this definition may clash with a libc that also
 * provides strdup — consider renaming once all callers are known.
 */
char* strdup(const char* s)
{
    return CopyChars(s, strlen(s));
}

/* Returns a static, human-readable name for `syntaxKind` ("Unknown" if unmapped). */
const char* SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
        case AccessExpression: return "AccessExpression";
        case AllocExpression: return "Alloc";
        case Assignment: return "Assignment";
        case BinaryExpression: return "BinaryExpression";
        case Comment: return "Comment";
        case CustomTypeNode: return "CustomTypeNode";
        case Declaration: return "Declaration";
        case DeclarationSequence: return "DeclarationSequence";
        case ForLoop: return "ForLoop";
        case FunctionArgumentSequence: return "FunctionArgumentSequence";
        case FunctionCallExpression: return "FunctionCallExpression";
        case FunctionDeclaration: return "FunctionDeclaration";
        case FunctionModifiers: return "FunctionModifiers";
        case FunctionSignature: return "FunctionSignature";
        case FunctionSignatureArguments: return "FunctionSignatureArguments";
        case Identifier: return "Identifier";
        case IfStatement: return "If";
        case IfElseStatement: return "IfElse";
        case Number: return "Number";
        case PrimitiveTypeNode: return "PrimitiveTypeNode";
        case ReferenceTypeNode: return "ReferenceTypeNode";
        case Return: return "Return";
        case ReturnVoid: return "ReturnVoid";
        case StatementSequence: return "StatementSequence";
        case StaticModifier: return "StaticModifier";
        case StringLiteral: return "StringLiteral";
        case StructDeclaration: return "StructDeclaration";
        case Type: return "Type";
        case UnaryExpression: return "UnaryExpression";
        default: return "Unknown";
    }
}

/*
 * Returns non-zero if the first child of `typeNode` is a PrimitiveTypeNode.
 * Assumes `typeNode` has at least one child (e.g. a node from MakeTypeNode).
 */
uint8_t IsPrimitiveType(Node *typeNode)
{
    return typeNode->children[0]->syntaxKind == PrimitiveTypeNode;
}

/* Leaf node holding a built-in type. */
Node* MakePrimitiveTypeNode(PrimitiveType type)
{
    Node *node = NewNode(PrimitiveTypeNode, 0);
    node->primitiveType = type;
    return node;
}

/* Leaf node holding a user-defined type name (stored as a heap copy). */
Node* MakeCustomTypeNode(char *name)
{
    Node *node = NewNode(CustomTypeNode, 0);
    node->value.string = strdup(name);
    return node;
}

/* Wraps `typeNode` as its single child. */
Node* MakeReferenceTypeNode(Node *typeNode)
{
    Node *node = NewNode(ReferenceTypeNode, 1);
    node->children[0] = typeNode;
    return node;
}

/* Type node whose single child is the concrete primitive/custom/reference type. */
Node* MakeTypeNode(Node* typeNode)
{
    Node *node = NewNode(Type, 1);
    node->children[0] = typeNode;
    return node;
}

/* Leaf node holding an identifier name (stored as a heap copy). */
Node* MakeIdentifierNode(const char *id)
{
    Node *node = NewNode(Identifier, 0);
    node->value.string = strdup(id);
    return node;
}

/* Leaf node holding the base-10 value of `numberString` (parsed with strtoul). */
Node* MakeNumberNode(const char *numberString)
{
    Node *node = NewNode(Number, 0);
    node->value.number = strtoul(numberString, NULL, 10);
    return node;
}

/*
 * Leaf node for a string literal. The first and last characters of `string`
 * (presumably the surrounding quotes) are dropped; the text between them is
 * stored as a heap copy. Inputs shorter than two characters yield an empty
 * string instead of underflowing the length computation.
 */
Node* MakeStringNode(const char *string)
{
    size_t slen = strlen(string);
    Node *node = NewNode(StringLiteral, 0);

    node->value.string = (slen >= 2)
        ? CopyChars(string + 1, slen - 2)
        : CopyChars("", 0);
    return node;
}

/* Leaf node for the `static` function modifier. */
Node* MakeStaticNode(void)
{
    return NewNode(StaticModifier, 0);
}

/* Groups `modifierCount` modifier nodes (pointers are copied, nodes are not). */
Node* MakeFunctionModifiersNode(Node **pModifierNodes, uint32_t modifierCount)
{
    uint32_t i;
    Node *node = NewNode(FunctionModifiers, modifierCount);

    for (i = 0; i < modifierCount; i += 1)
    {
        node->children[i] = pModifierNodes[i];
    }
    return node;
}

/* Unary expression: `operator` applied to the single child `child`. */
Node* MakeUnaryNode(UnaryOperator operator, Node *child)
{
    Node *node = NewNode(UnaryExpression, 1);
    node->operator.unaryOperator = operator;
    node->children[0] = child;
    return node;
}

/* Binary expression: children are [left, right]. */
Node* MakeBinaryNode(BinaryOperator operator, Node *left, Node *right)
{
    Node *node = NewNode(BinaryExpression, 2);
    node->operator.binaryOperator = operator;
    node->children[0] = left;
    node->children[1] = right;
    return node;
}

/* Declaration: children are [type, identifier]. */
Node* MakeDeclarationNode(Node* typeNode, Node* identifierNode)
{
    Node *node = NewNode(Declaration, 2);
    node->children[0] = typeNode;
    node->children[1] = identifierNode;
    return node;
}

/* Assignment: children are [left, right]. */
Node* MakeAssignmentNode(Node *left, Node *right)
{
    Node *node = NewNode(Assignment, 2);
    node->children[0] = left;
    node->children[1] = right;
    return node;
}

/* Starts a statement sequence containing `statementNode`. */
Node* StartStatementSequenceNode(Node *statementNode)
{
    Node *node = NewNode(StatementSequence, 1);
    node->children[0] = statementNode;
    return node;
}

/* Appends `statementNode`; returns the same sequence node. */
Node* AddStatement(Node* statementSequenceNode, Node *statementNode)
{
    return AppendChild(statementSequenceNode, statementNode);
}

/* `return <expression>`: single child is the expression. */
Node* MakeReturnStatementNode(Node *expressionNode)
{
    Node *node = NewNode(Return, 1);
    node->children[0] = expressionNode;
    return node;
}

/* `return` with no value: no children. */
Node* MakeReturnVoidStatementNode(void)
{
    return NewNode(ReturnVoid, 0);
}

/* Starts a function-signature argument list containing `argumentNode`. */
Node *StartFunctionSignatureArgumentsNode(Node *argumentNode)
{
    Node *node = NewNode(FunctionSignatureArguments, 1);
    node->children[0] = argumentNode;
    return node;
}

/* Appends `argumentNode`; returns the same arguments node. */
Node* AddFunctionSignatureArgumentNode(Node *argumentsNode, Node *argumentNode)
{
    return AppendChild(argumentsNode, argumentNode);
}

/* Function-signature argument list with no arguments. */
Node *MakeEmptyFunctionSignatureArgumentsNode(void)
{
    return NewNode(FunctionSignatureArguments, 0);
}

/* Function signature: children are [identifier, return type, arguments, modifiers]. */
Node* MakeFunctionSignatureNode(Node *identifierNode, Node* typeNode, Node* arguments, Node* modifiersNode)
{
    Node *node = NewNode(FunctionSignature, 4);
    node->children[0] = identifierNode;
    node->children[1] = typeNode;
    node->children[2] = arguments;
    node->children[3] = modifiersNode;
    return node;
}

/* Function declaration: children are [signature, body]. */
Node* MakeFunctionDeclarationNode(Node* functionSignatureNode, Node* functionBodyNode)
{
    Node *node = NewNode(FunctionDeclaration, 2);
    node->children[0] = functionSignatureNode;
    node->children[1] = functionBodyNode;
    return node;
}

/* Struct declaration: children are [identifier, declaration sequence]. */
Node* MakeStructDeclarationNode(Node *identifierNode, Node *declarationSequenceNode)
{
    Node *node = NewNode(StructDeclaration, 2);
    node->children[0] = identifierNode;
    node->children[1] = declarationSequenceNode;
    return node;
}

/* Starts a declaration sequence containing `declarationNode`. */
Node* StartDeclarationSequenceNode(Node *declarationNode)
{
    Node *node = NewNode(DeclarationSequence, 1);
    node->children[0] = declarationNode;
    return node;
}

/* Appends `declarationNode`; returns the same sequence node. */
Node* AddDeclarationNode(Node *declarationSequenceNode, Node *declarationNode)
{
    return AppendChild(declarationSequenceNode, declarationNode);
}

/* Starts a call-argument sequence containing `argumentNode`. */
Node* StartFunctionArgumentSequenceNode(Node *argumentNode)
{
    Node *node = NewNode(FunctionArgumentSequence, 1);
    node->children[0] = argumentNode;
    return node;
}

/* Appends `argumentNode`; returns the same sequence node. */
Node* AddFunctionArgumentNode(Node *argumentSequenceNode, Node *argumentNode)
{
    return AppendChild(argumentSequenceNode, argumentNode);
}

/* Call-argument sequence with no arguments. */
Node *MakeEmptyFunctionArgumentSequenceNode(void)
{
    return NewNode(FunctionArgumentSequence, 0);
}

/* Function call: children are [callee identifier, argument sequence]. */
Node* MakeFunctionCallExpressionNode(Node *identifierNode, Node *argumentSequenceNode)
{
    Node *node = NewNode(FunctionCallExpression, 2);
    node->children[0] = identifierNode;
    node->children[1] = argumentSequenceNode;
    return node;
}

/* Member access: children are [accessee, accessor]. */
Node* MakeAccessExpressionNode(Node *accessee, Node *accessor)
{
    Node *node = NewNode(AccessExpression, 2);
    node->children[0] = accessee;
    node->children[1] = accessor;
    return node;
}

/* Alloc expression: single child is the type to allocate. */
Node* MakeAllocNode(Node *typeNode)
{
    Node *node = NewNode(AllocExpression, 1);
    node->children[0] = typeNode;
    return node;
}

/*
 * If statement: children are [condition, body]. The child array is sized by
 * NewNode from the count of 2 (the previous version allocated room for only
 * one pointer and wrote past it).
 */
Node* MakeIfNode(Node *expressionNode, Node *statementSequenceNode)
{
    Node *node = NewNode(IfStatement, 2);
    node->children[0] = expressionNode;
    node->children[1] = statementSequenceNode;
    return node;
}

/* If/else statement: children are [if node, else body]. */
Node* MakeIfElseNode(Node *ifNode, Node *statementSequenceNode)
{
    Node *node = NewNode(IfElseStatement, 2);
    node->children[0] = ifNode;
    node->children[1] = statementSequenceNode;
    return node;
}

/* For loop: children are [loop identifier, start, end, body]. */
Node* MakeForLoopNode(Node *identifierNode, Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode)
{
    Node *node = NewNode(ForLoop, 4);
    node->children[0] = identifierNode;
    node->children[1] = startNumberNode;
    node->children[2] = endNumberNode;
    node->children[3] = statementSequenceNode;
    return node;
}

/* Returns a static name for `type`, or "Unknown". */
static const char* PrimitiveTypeToString(PrimitiveType type)
{
    switch (type)
    {
        case Int: return "Int";
        case UInt: return "UInt";
        case Bool: return "Bool";
        case Void: return "Void";
    }
    return "Unknown";
}

/* Prints the symbol for `expression`; prints nothing for unmapped operators. */
static void PrintBinaryOperator(BinaryOperator expression)
{
    switch (expression)
    {
        case Add: printf("+"); break;
        case Subtract: printf("-"); break;
        case Multiply: printf("*"); break;
        default: break;
    }
}

/* Prints one line for `node`, indented by `tabCount` single spaces. */
static void PrintNode(Node *node, uint32_t tabCount)
{
    uint32_t i;

    for (i = 0; i < tabCount; i += 1)
    {
        printf(" ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));

    switch (node->syntaxKind)
    {
        case BinaryExpression:
            PrintBinaryOperator(node->operator.binaryOperator);
            break;

        case Declaration:
            break;

        case CustomTypeNode:
            printf("%s", node->value.string);
            break;

        case PrimitiveTypeNode:
            printf("%s", PrimitiveTypeToString(node->primitiveType));
            break;

        case Identifier:
            printf("%s", node->value.string);
            break;

        case Number:
            /* Cast so the conversion matches whatever unsigned type ast.h uses. */
            printf("%llu", (unsigned long long) node->value.number);
            break;

        default:
            break;
    }

    printf("\n");
}

/* Recursively prints `node` and its descendants, one level of indent per depth. */
void PrintTree(Node *node, uint32_t tabCount)
{
    uint32_t i;

    PrintNode(node, tabCount);

    for (i = 0; i < node->childCount; i += 1)
    {
        PrintTree(node->children[i], tabCount + 1);
    }
}