#include "ast.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "util.h"

/*
 * Allocation helpers for the node constructors below.
 *
 * The constructors have no way to report failure to their callers, so an
 * exhausted heap is treated as fatal instead of letting a NULL result be
 * dereferenced a few statements later.
 */

/* Return `ptr` unchanged, or print a diagnostic and abort if it is NULL. */
static void* AstAllocOrDie(void *ptr)
{
    if (ptr == NULL)
    {
        fprintf(stderr, "wraith: out of memory while building the AST\n");
        abort();
    }
    return ptr;
}

/* Allocate a zero-filled node of the given kind with no type tag attached. */
static Node* AstNewNode(SyntaxKind syntaxKind)
{
    Node *node = AstAllocOrDie(calloc(1, sizeof(Node)));
    node->syntaxKind = syntaxKind;
    /* Set explicitly for every kind; previously only identifiers had their
     * typeTag initialized and all other nodes carried garbage here. */
    node->typeTag = NULL;
    return node;
}

/*
 * Grow `sequence` (currently holding `count` entries; may be NULL when
 * `count` is 0) by one slot, store `item` at index `count`, and return the
 * possibly-moved array. The caller is responsible for bumping its count.
 */
static Node** AstAppendNode(Node **sequence, size_t count, Node *item)
{
    Node **grown = AstAllocOrDie(realloc(sequence, sizeof(Node*) * (count + 1)));
    grown[count] = item;
    return grown;
}

/*
 * Return a printable name for a SyntaxKind, or "Unknown" for kinds without
 * an entry. The returned string is a literal and must not be freed.
 */
const char* SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
        case AccessExpression:           return "AccessExpression";
        case AllocExpression:            return "Alloc";
        case Assignment:                 return "Assignment";
        case BinaryExpression:           return "BinaryExpression";
        case Comment:                    return "Comment";
        case CustomTypeNode:             return "CustomTypeNode";
        case Declaration:                return "Declaration";
        case DeclarationSequence:        return "DeclarationSequence";
        case ForLoop:                    return "ForLoop";
        case FunctionArgumentSequence:   return "FunctionArgumentSequence";
        case FunctionCallExpression:     return "FunctionCallExpression";
        case FunctionDeclaration:        return "FunctionDeclaration";
        case FunctionModifiers:          return "FunctionModifiers";
        case FunctionSignature:          return "FunctionSignature";
        case FunctionSignatureArguments: return "FunctionSignatureArguments";
        case Identifier:                 return "Identifier";
        case IfStatement:                return "If";
        case IfElseStatement:            return "IfElse";
        case Number:                     return "Number";
        case PrimitiveTypeNode:          return "PrimitiveTypeNode";
        case ReferenceTypeNode:          return "ReferenceTypeNode";
        case Return:                     return "Return";
        /* Was missing: ReturnVoid nodes used to be reported as "Unknown". */
        case ReturnVoid:                 return "ReturnVoid";
        case StatementSequence:          return "StatementSequence";
        case StaticModifier:             return "StaticModifier";
        case StringLiteral:              return "StringLiteral";
        case StructDeclaration:          return "StructDeclaration";
        case Type:                       return "Type";
        case UnaryExpression:            return "UnaryExpression";
        default:                         return "Unknown";
    }
}

/*
 * Return nonzero if the node wrapped by `typeNode` is a PrimitiveTypeNode.
 * Reads the `type` member, so `typeNode` is expected to be a Type node.
 */
uint8_t IsPrimitiveType(Node *typeNode)
{
    return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode;
}

/* Make a PrimitiveTypeNode holding `type`. */
Node* MakePrimitiveTypeNode(PrimitiveType type)
{
    Node *node = AstNewNode(PrimitiveTypeNode);
    node->primitiveType.type = type;
    return node;
}

/* Make a CustomTypeNode owning a private copy of `name`. */
Node* MakeCustomTypeNode(char *name)
{
    Node *node = AstNewNode(CustomTypeNode);
    node->customType.name = AstAllocOrDie(strdup(name));
    return node;
}

/* Make a ReferenceTypeNode pointing at `typeNode` (not copied). */
Node* MakeReferenceTypeNode(Node *typeNode)
{
    Node *node = AstNewNode(ReferenceTypeNode);
    node->referenceType.type = typeNode;
    return node;
}

/* Make a Type node wrapping `typeNode` (not copied). */
Node* MakeTypeNode(Node* typeNode)
{
    Node *node = AstNewNode(Type);
    node->type.typeNode = typeNode;
    return node;
}

/* Make an Identifier node owning a private copy of `id`; typeTag is NULL. */
Node* MakeIdentifierNode(const char *id)
{
    Node *node = AstNewNode(Identifier);
    node->identifier.name = AstAllocOrDie(strdup(id));
    return node;
}

/*
 * Make a Number node from `numberString`, parsed with strtoul in base 10.
 * No error checking is performed on the text.
 */
Node* MakeNumberNode(const char *numberString)
{
    Node *node = AstNewNode(Number);
    node->number.value = strtoul(numberString, NULL, 10);
    return node;
}

/*
 * Make a StringLiteral node from `string` with its first and last characters
 * removed (presumably the quote delimiters from the lexer — confirm against
 * the grammar). The node owns the copy.
 */
Node* MakeStringNode(const char *string)
{
    size_t slen = strlen(string);
    Node *node = AstNewNode(StringLiteral);

    /* `slen - 2` would wrap for inputs shorter than two characters and read
     * past the terminator, so such inputs become an empty string. */
    if (slen >= 2)
    {
        node->stringLiteral.string = AstAllocOrDie(strndup(string + 1, slen - 2));
    }
    else
    {
        node->stringLiteral.string = AstAllocOrDie(strdup(""));
    }
    return node;
}

/* Make a StaticModifier node. */
Node* MakeStaticNode(void)
{
    return AstNewNode(StaticModifier);
}

/*
 * Make a FunctionModifiers node holding a copy of the first `modifierCount`
 * pointers in `pModifierNodes` (the nodes themselves are not copied).
 * `sequence` stays NULL when there are no modifiers.
 *
 * FIXME: carried over from the original ("this sucks") — unlike the other
 * sequence nodes, this one copies a caller-provided array instead of being
 * built incrementally.
 */
Node* MakeFunctionModifiersNode(Node **pModifierNodes, uint32_t modifierCount)
{
    Node *node = AstNewNode(FunctionModifiers);
    node->functionModifiers.count = modifierCount;
    node->functionModifiers.sequence = NULL;

    if (modifierCount > 0)
    {
        size_t bytes = sizeof(Node*) * (size_t) modifierCount;
        node->functionModifiers.sequence = AstAllocOrDie(malloc(bytes));
        memcpy(node->functionModifiers.sequence, pModifierNodes, bytes);
    }
    return node;
}

/* Make a UnaryExpression node applying `operator` to `child`. */
Node* MakeUnaryNode(UnaryOperator operator, Node *child)
{
    Node *node = AstNewNode(UnaryExpression);
    node->unaryExpression.operator = operator;
    node->unaryExpression.child = child;
    return node;
}

/* Make a BinaryExpression node `left operator right`. */
Node* MakeBinaryNode(BinaryOperator operator, Node *left, Node *right)
{
    Node *node = AstNewNode(BinaryExpression);
    node->binaryExpression.left = left;
    node->binaryExpression.right = right;
    node->binaryExpression.operator = operator;
    return node;
}

/* Make a Declaration node pairing a type with an identifier. */
Node* MakeDeclarationNode(Node* typeNode, Node* identifierNode)
{
    Node *node = AstNewNode(Declaration);
    node->declaration.type = typeNode;
    node->declaration.identifier = identifierNode;
    return node;
}

/* Make an Assignment node `left = right`. */
Node* MakeAssignmentNode(Node *left, Node *right)
{
    Node *node = AstNewNode(Assignment);
    node->assignmentStatement.left = left;
    node->assignmentStatement.right = right;
    return node;
}

/* Start a StatementSequence containing just `statementNode`. */
Node* StartStatementSequenceNode(Node *statementNode)
{
    Node *node = AstNewNode(StatementSequence);
    node->statementSequence.sequence = AstAppendNode(NULL, 0, statementNode);
    node->statementSequence.count = 1;
    return node;
}

/* Append `statementNode` to a StatementSequence; returns the sequence node. */
Node* AddStatement(Node* statementSequenceNode, Node *statementNode)
{
    statementSequenceNode->statementSequence.sequence = AstAppendNode(
        statementSequenceNode->statementSequence.sequence,
        statementSequenceNode->statementSequence.count,
        statementNode);
    statementSequenceNode->statementSequence.count += 1;
    return statementSequenceNode;
}

/* Make a Return node carrying `expressionNode`. */
Node* MakeReturnStatementNode(Node *expressionNode)
{
    Node *node = AstNewNode(Return);
    node->returnStatement.expression = expressionNode;
    return node;
}

/* Make a ReturnVoid node (no payload). */
Node* MakeReturnVoidStatementNode(void)
{
    return AstNewNode(ReturnVoid);
}

/* Start a FunctionSignatureArguments list containing just `argumentNode`. */
Node *StartFunctionSignatureArgumentsNode(Node *argumentNode)
{
    Node *node = AstNewNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence = AstAppendNode(NULL, 0, argumentNode);
    node->functionSignatureArguments.count = 1;
    return node;
}

/* Append `argumentNode` to a FunctionSignatureArguments list; returns the list. */
Node* AddFunctionSignatureArgumentNode(Node *argumentsNode, Node *argumentNode)
{
    argumentsNode->functionSignatureArguments.sequence = AstAppendNode(
        argumentsNode->functionSignatureArguments.sequence,
        argumentsNode->functionSignatureArguments.count,
        argumentNode);
    argumentsNode->functionSignatureArguments.count += 1;
    return argumentsNode;
}

/* Make an empty FunctionSignatureArguments list (NULL sequence, count 0). */
Node *MakeEmptyFunctionSignatureArgumentsNode(void)
{
    Node *node = AstNewNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence = NULL;
    node->functionSignatureArguments.count = 0;
    return node;
}

/* Make a FunctionSignature node from its name, return type, arguments and modifiers. */
Node* MakeFunctionSignatureNode(Node *identifierNode, Node* typeNode, Node* arguments, Node* modifiersNode)
{
    Node *node = AstNewNode(FunctionSignature);
    node->functionSignature.identifier = identifierNode;
    node->functionSignature.type = typeNode;
    node->functionSignature.arguments = arguments;
    node->functionSignature.modifiers = modifiersNode;
    return node;
}

/* Make a FunctionDeclaration node from a signature and a body. */
Node* MakeFunctionDeclarationNode(Node* functionSignatureNode, Node* functionBodyNode)
{
    Node *node = AstNewNode(FunctionDeclaration);
    node->functionDeclaration.functionSignature = functionSignatureNode;
    node->functionDeclaration.functionBody = functionBodyNode;
    return node;
}

/* Make a StructDeclaration node from a name and its field declarations. */
Node* MakeStructDeclarationNode(Node *identifierNode, Node *declarationSequenceNode)
{
    Node *node = AstNewNode(StructDeclaration);
    node->structDeclaration.identifier = identifierNode;
    node->structDeclaration.declarationSequence = declarationSequenceNode;
    return node;
}

/* Start a DeclarationSequence containing just `declarationNode`. */
Node* StartDeclarationSequenceNode(Node *declarationNode)
{
    Node *node = AstNewNode(DeclarationSequence);
    node->declarationSequence.sequence = AstAppendNode(NULL, 0, declarationNode);
    node->declarationSequence.count = 1;
    return node;
}
Node* AddDeclarationNode( Node *declarationSequenceNode, Node *declarationNode ) { declarationSequenceNode->declarationSequence.sequence = (Node**) realloc(declarationSequenceNode->declarationSequence.sequence, sizeof(Node*) * (declarationSequenceNode->declarationSequence.count + 1)); declarationSequenceNode->declarationSequence.sequence[declarationSequenceNode->declarationSequence.count] = declarationNode; declarationSequenceNode->declarationSequence.count += 1; return declarationSequenceNode; } Node* StartFunctionArgumentSequenceNode( Node *argumentNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionArgumentSequence; node->functionArgumentSequence.sequence = (Node**) malloc(sizeof(Node*)); node->functionArgumentSequence.sequence[0] = argumentNode; node->functionArgumentSequence.count = 1; return node; } Node* AddFunctionArgumentNode( Node *argumentSequenceNode, Node *argumentNode ) { argumentSequenceNode->functionArgumentSequence.sequence = (Node**) realloc(argumentSequenceNode->functionArgumentSequence.sequence, sizeof(Node*) * (argumentSequenceNode->functionArgumentSequence.count + 1)); argumentSequenceNode->functionArgumentSequence.sequence[argumentSequenceNode->functionArgumentSequence.count] = argumentNode; argumentSequenceNode->functionArgumentSequence.count += 1; return argumentSequenceNode; } Node *MakeEmptyFunctionArgumentSequenceNode() { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionArgumentSequence; node->functionArgumentSequence.count = 0; node->functionArgumentSequence.sequence = NULL; return node; } Node* MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = FunctionCallExpression; node->functionCallExpression.identifier = identifierNode; node->functionCallExpression.argumentSequence = argumentSequenceNode; return node; } Node* MakeAccessExpressionNode( Node *accessee, Node *accessor ) { Node* node = (Node*) 
malloc(sizeof(Node)); node->syntaxKind = AccessExpression; node->accessExpression.accessee = accessee; node->accessExpression.accessor = accessor; return node; } Node* MakeAllocNode(Node *typeNode) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = AllocExpression; node->allocExpression.type = typeNode; return node; } Node* MakeIfNode( Node *expressionNode, Node *statementSequenceNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = IfStatement; node->ifStatement.expression = expressionNode; node->ifStatement.statementSequence = statementSequenceNode; return node; } Node* MakeIfElseNode( Node *ifNode, Node *elseNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = IfElseStatement; node->ifElseStatement.ifStatement = ifNode; node->ifElseStatement.elseStatement = elseNode; return node; } Node* MakeForLoopNode( Node *declarationNode, Node *startNumberNode, Node *endNumberNode, Node *statementSequenceNode ) { Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = ForLoop; node->forLoop.declaration = declarationNode; node->forLoop.startNumber = startNumberNode; node->forLoop.endNumber = endNumberNode; node->forLoop.statementSequence = statementSequenceNode; return node; } static const char* PrimitiveTypeToString(PrimitiveType type) { switch (type) { case Int: return "Int"; case UInt: return "UInt"; case Bool: return "Bool"; case Void: return "Void"; } return "Unknown"; } static void PrintUnaryOperator(UnaryOperator operator) { switch (operator) { case Negate: printf("!"); break; } } static void PrintBinaryOperator(BinaryOperator operator) { switch (operator) { case Add: printf("(+)"); break; case Subtract: printf("(-)"); break; case Multiply: printf("(*)"); break; } } void PrintNode(Node *node, uint32_t tabCount) { uint32_t i; for (i = 0; i < tabCount; i += 1) { printf(" "); } printf("%s: ", SyntaxKindString(node->syntaxKind)); switch (node->syntaxKind) { case AccessExpression: printf("\n"); 
PrintNode(node->accessExpression.accessee, tabCount + 1); PrintNode(node->accessExpression.accessor, tabCount + 1); return; case AllocExpression: printf("\n"); PrintNode(node->allocExpression.type, tabCount + 1); return; case Assignment: printf("\n"); PrintNode(node->assignmentStatement.left, tabCount + 1); PrintNode(node->assignmentStatement.right, tabCount + 1); return; case BinaryExpression: PrintBinaryOperator(node->binaryExpression.operator); printf("\n"); PrintNode(node->binaryExpression.left, tabCount + 1); PrintNode(node->binaryExpression.right, tabCount + 1); return; case CustomTypeNode: printf("%s\n", node->customType.name); return; case Declaration: printf("\n"); PrintNode(node->declaration.identifier, tabCount + 1); PrintNode(node->declaration.type, tabCount + 1); return; case DeclarationSequence: printf("\n"); for (i = 0; i < node->declarationSequence.count; i += 1) { PrintNode(node->declarationSequence.sequence[i], tabCount + 1); } return; case ForLoop: printf("\n"); PrintNode(node->forLoop.declaration, tabCount + 1); PrintNode(node->forLoop.startNumber, tabCount + 1); PrintNode(node->forLoop.endNumber, tabCount + 1); PrintNode(node->forLoop.statementSequence, tabCount + 1); return; case FunctionArgumentSequence: printf("\n"); for (i = 0; i < node->functionArgumentSequence.count; i += 1) { PrintNode(node->functionArgumentSequence.sequence[i], tabCount + 1); } return; case FunctionCallExpression: printf("\n"); PrintNode(node->functionCallExpression.identifier, tabCount + 1); PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1); return; case FunctionDeclaration: printf("\n"); PrintNode(node->functionDeclaration.functionSignature, tabCount + 1); PrintNode(node->functionDeclaration.functionBody, tabCount + 1); return; case FunctionModifiers: printf("\n"); for (i = 0; i < node->functionModifiers.count; i += 1) { PrintNode(node->functionModifiers.sequence[i], tabCount + 1); } return; case FunctionSignature: printf("\n"); 
PrintNode(node->functionSignature.identifier, tabCount + 1); PrintNode(node->functionSignature.arguments, tabCount + 1); PrintNode(node->functionSignature.type, tabCount + 1); PrintNode(node->functionSignature.modifiers, tabCount + 1); return; case FunctionSignatureArguments: printf("\n"); for (i = 0; i < node->functionSignatureArguments.count; i += 1) { PrintNode(node->functionSignatureArguments.sequence[i], tabCount + 1); } return; case Identifier: if (node->typeTag == NULL) { printf("%s\n", node->identifier.name); } else { char *type = TypeTagToString(node->typeTag); printf("%s<%s>\n", node->identifier.name, type); } return; case IfStatement: printf("\n"); PrintNode(node->ifStatement.expression, tabCount + 1); PrintNode(node->ifStatement.statementSequence, tabCount + 1); return; case IfElseStatement: printf("\n"); PrintNode(node->ifElseStatement.ifStatement, tabCount + 1); PrintNode(node->ifElseStatement.elseStatement, tabCount + 1); return; case Number: printf("%lu\n", node->number.value); return; case PrimitiveTypeNode: printf("%s\n", PrimitiveTypeToString(node->primitiveType.type)); return; case ReferenceTypeNode: printf("\n"); PrintNode(node->referenceType.type, tabCount + 1); return; case Return: printf("\n"); PrintNode(node->returnStatement.expression, tabCount + 1); return; case ReturnVoid: return; case StatementSequence: printf("\n"); for (i = 0; i < node->statementSequence.count; i += 1) { PrintNode(node->statementSequence.sequence[i], tabCount + 1); } return; case StaticModifier: printf("\n"); return; case StringLiteral: printf("%s", node->stringLiteral.string); return; case StructDeclaration: printf("\n"); PrintNode(node->structDeclaration.identifier, tabCount + 1); PrintNode(node->structDeclaration.declarationSequence, tabCount + 1); return; case Type: printf("\n"); PrintNode(node->type.typeNode, tabCount + 1); return; case UnaryExpression: PrintUnaryOperator(node->unaryExpression.operator); PrintNode(node->unaryExpression.child, tabCount + 1); 
return; } } TypeTag* MakeTypeTag(Node *node) { if (node == NULL) { fprintf(stderr, "wraith: Attempted to call MakeTypeTag on null value.\n"); return NULL; } TypeTag *tag = (TypeTag*)malloc(sizeof(TypeTag)); switch (node->syntaxKind) { case Type: tag = MakeTypeTag(node->type.typeNode); break; case PrimitiveTypeNode: tag->type = Primitive; tag->value.primitiveType = node->primitiveType.type; break; case ReferenceTypeNode: tag->type = Reference; tag->value.referenceType = MakeTypeTag(node->referenceType.type); break; case CustomTypeNode: tag->type = Custom; tag->value.customType = strdup(node->customType.name); break; case Declaration: tag = MakeTypeTag(node->declaration.type); break; case StructDeclaration: tag->type = Custom; tag->value.customType = strdup(node->structDeclaration.identifier->identifier.name); break; case FunctionDeclaration: tag = MakeTypeTag(node->functionDeclaration.functionSignature->functionSignature.type); break; default: fprintf(stderr, "wraith: Attempted to call MakeTypeTag on" " node with unsupported SyntaxKind: %s\n", SyntaxKindString(node->syntaxKind)); return NULL; } return tag; } char* TypeTagToString(TypeTag *tag) { if (tag == NULL) { fprintf(stderr, "wraith: Attempted to call TypeTagToString with null value\n"); return NULL; } switch (tag->type) { case Unknown: return "Unknown"; case Primitive: return PrimitiveTypeToString(tag->value.primitiveType); case Reference: { char *inner = TypeTagToString(tag->value.referenceType); size_t innerStrLen = strlen(inner); char *result = malloc(sizeof(char) * (innerStrLen + 5)); sprintf(result, "Ref<%s>", inner); return result; } case Custom: return tag->value.customType; } }