#include "ast.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "util.h"

/*
 * Returns a printable name for a SyntaxKind, or "Unknown" for kinds not
 * listed. The result is a string literal; callers must not free it.
 */
const char *SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
    case AccessExpression:
        return "AccessExpression";
    case AllocExpression:
        return "Alloc";
    case Assignment:
        return "Assignment";
    case BinaryExpression:
        return "BinaryExpression";
    case Comment:
        return "Comment";
    case CustomTypeNode:
        return "CustomTypeNode";
    case Declaration:
        return "Declaration";
    case ForLoop:
        return "ForLoop";
    case DeclarationSequence:
        return "DeclarationSequence";
    case FunctionArgumentSequence:
        return "FunctionArgumentSequence";
    case FunctionCallExpression:
        return "FunctionCallExpression";
    case FunctionDeclaration:
        return "FunctionDeclaration";
    case FunctionModifiers:
        return "FunctionModifiers";
    case FunctionSignature:
        return "FunctionSignature";
    case FunctionSignatureArguments:
        return "FunctionSignatureArguments";
    case GenericArgument:
        return "GenericArgument";
    case GenericArguments:
        return "GenericArguments";
    case GenericTypeNode:
        return "GenericTypeNode";
    case Identifier:
        return "Identifier";
    case IfStatement:
        return "If";
    case IfElseStatement:
        return "IfElse";
    case Number:
        return "Number";
    case PrimitiveTypeNode:
        return "PrimitiveTypeNode";
    case ReferenceTypeNode:
        return "ReferenceTypeNode";
    case Return:
        return "Return";
    case ReturnVoid:
        return "ReturnVoid";
    case StatementSequence:
        return "StatementSequence";
    case StaticModifier:
        return "StaticModifier";
    case StringLiteral:
        return "StringLiteral";
    case StructDeclaration:
        return "StructDeclaration";
    case Type:
        return "Type";
    case UnaryExpression:
        return "UnaryExpression";
    default:
        return "Unknown";
    }
}

/*
 * calloc() wrapper that terminates the process when memory is exhausted.
 * The AST builder has no recovery path for out-of-memory, so failing loudly
 * here is preferable to a NULL dereference somewhere downstream.
 */
static void *AstCallocOrDie(size_t count, size_t size)
{
    void *memory = calloc(count, size);
    if (memory == NULL && count != 0 && size != 0)
    {
        fprintf(stderr, "wraith: Out of memory while building the AST.\n");
        exit(EXIT_FAILURE);
    }
    return memory;
}

/*
 * Allocates a zero-initialized Node of the given kind. Zeroing guarantees
 * that fields a constructor does not set explicitly (e.g. typeTag) start out
 * NULL/0 rather than holding indeterminate values.
 */
static Node *NewAstNode(SyntaxKind kind)
{
    Node *node = AstCallocOrDie(1, sizeof *node);
    node->syntaxKind = kind;
    return node;
}

/*
 * Grows a heap array of Node pointers by one slot and stores item at index
 * count. sequence may be NULL when count is 0. Returns the (possibly moved)
 * array; terminates the process if the reallocation fails.
 */
static Node **AppendAstNode(Node **sequence, uint32_t count, Node *item)
{
    Node **grown = realloc(sequence, ((size_t)count + 1) * sizeof *grown);
    if (grown == NULL)
    {
        fprintf(stderr, "wraith: Out of memory while building the AST.\n");
        exit(EXIT_FAILURE);
    }
    grown[count] = item;
    return grown;
}

/* Returns nonzero when the Type node typeNode wraps a PrimitiveTypeNode. */
uint8_t IsPrimitiveType(Node *typeNode)
{
    return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode;
}

/* Builds a PrimitiveTypeNode holding the given primitive type. */
Node *MakePrimitiveTypeNode(PrimitiveType type)
{
    Node *node = NewAstNode(PrimitiveTypeNode);
    node->primitiveType.type = type;
    return node;
}

/* Builds a CustomTypeNode; name is copied with strdup. */
Node *MakeCustomTypeNode(char *name)
{
    Node *node = NewAstNode(CustomTypeNode);
    node->customType.name = strdup(name);
    return node;
}

/* Builds a ReferenceTypeNode that points at (does not copy) typeNode. */
Node *MakeReferenceTypeNode(Node *typeNode)
{
    Node *node = NewAstNode(ReferenceTypeNode);
    node->referenceType.type = typeNode;
    return node;
}

/* Builds a Type wrapper node around typeNode. */
Node *MakeTypeNode(Node *typeNode)
{
    Node *node = NewAstNode(Type);
    node->type.typeNode = typeNode;
    return node;
}

/* Builds an Identifier node; id is copied and the node starts untyped. */
Node *MakeIdentifierNode(const char *id)
{
    Node *node = NewAstNode(Identifier);
    node->identifier.name = strdup(id);
    node->typeTag = NULL;
    return node;
}

/* Builds a Number node by parsing numberString as a base-10 unsigned value. */
Node *MakeNumberNode(const char *numberString)
{
    Node *node = NewAstNode(Number);
    node->number.value = strtoul(numberString, NULL, 10);
    return node;
}

/*
 * Builds a StringLiteral node from a lexeme whose first and last characters
 * are dropped (presumably the surrounding quotes). A lexeme shorter than two
 * characters yields an empty string instead of reading out of bounds.
 */
Node *MakeStringNode(const char *string)
{
    size_t slen = strlen(string);
    Node *node = NewAstNode(StringLiteral);
    node->stringLiteral.string =
        (slen >= 2) ? strndup(string + 1, slen - 2) : strdup("");
    return node;
}

/* Builds a StaticModifier node (no payload). */
Node *MakeStaticNode()
{
    return NewAstNode(StaticModifier);
}

/*
 * Builds a FunctionModifiers node. The modifierCount pointers in
 * pModifierNodes are copied into a new array (the nodes themselves are
 * shared, not cloned); the caller keeps ownership of pModifierNodes.
 */
Node *MakeFunctionModifiersNode(Node **pModifierNodes, uint32_t modifierCount)
{
    Node *node = NewAstNode(FunctionModifiers);
    node->functionModifiers.count = modifierCount;
    node->functionModifiers.sequence = NULL;
    if (modifierCount > 0)
    {
        node->functionModifiers.sequence =
            AstCallocOrDie(modifierCount, sizeof(Node *));
        memcpy(
            node->functionModifiers.sequence,
            pModifierNodes,
            (size_t)modifierCount * sizeof(Node *));
    }
    return node;
}

/* Builds a UnaryExpression applying operator to child. */
Node *MakeUnaryNode(UnaryOperator operator, Node *child)
{
    Node *node = NewAstNode(UnaryExpression);
    node->unaryExpression.operator = operator;
    node->unaryExpression.child = child;
    return node;
}

/* Builds a BinaryExpression "left operator right". */
Node *MakeBinaryNode(BinaryOperator operator, Node *left, Node *right)
{
    Node *node = NewAstNode(BinaryExpression);
    node->binaryExpression.left = left;
    node->binaryExpression.right = right;
    node->binaryExpression.operator = operator;
    return node;
}

/* Builds a Declaration pairing a type node with an identifier node. */
Node *MakeDeclarationNode(Node *typeNode, Node *identifierNode)
{
    Node *node = NewAstNode(Declaration);
    node->declaration.type = typeNode;
    node->declaration.identifier = identifierNode;
    return node;
}

/* Builds an Assignment of right to left. */
Node *MakeAssignmentNode(Node *left, Node *right)
{
    Node *node = NewAstNode(Assignment);
    node->assignmentStatement.left = left;
    node->assignmentStatement.right = right;
    return node;
}

/* Starts a StatementSequence containing a single statement. */
Node *StartStatementSequenceNode(Node *statementNode)
{
    Node *node = NewAstNode(StatementSequence);
    node->statementSequence.sequence = AppendAstNode(NULL, 0, statementNode);
    node->statementSequence.count = 1;
    return node;
}

/* Appends statementNode to a StatementSequence; returns the same node. */
Node *AddStatement(Node *statementSequenceNode, Node *statementNode)
{
    statementSequenceNode->statementSequence.sequence = AppendAstNode(
        statementSequenceNode->statementSequence.sequence,
        statementSequenceNode->statementSequence.count,
        statementNode);
    statementSequenceNode->statementSequence.count += 1;
    return statementSequenceNode;
}

/* Builds a Return node carrying expressionNode. */
Node *MakeReturnStatementNode(Node *expressionNode)
{
    Node *node = NewAstNode(Return);
    node->returnStatement.expression = expressionNode;
    return node;
}

/* Builds a ReturnVoid node (no payload). */
Node *MakeReturnVoidStatementNode()
{
    return NewAstNode(ReturnVoid);
}

/* Starts a FunctionSignatureArguments list with one argument. */
Node *StartFunctionSignatureArgumentsNode(Node *argumentNode)
{
    Node *node = NewAstNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence =
        AppendAstNode(NULL, 0, argumentNode);
    node->functionSignatureArguments.count = 1;
    return node;
}

/* Appends argumentNode to a FunctionSignatureArguments list. */
Node *AddFunctionSignatureArgumentNode(Node *argumentsNode, Node *argumentNode)
{
    argumentsNode->functionSignatureArguments.sequence = AppendAstNode(
        argumentsNode->functionSignatureArguments.sequence,
        argumentsNode->functionSignatureArguments.count,
        argumentNode);
    argumentsNode->functionSignatureArguments.count += 1;
    return argumentsNode;
}

/* Builds an empty FunctionSignatureArguments list. */
Node *MakeEmptyFunctionSignatureArgumentsNode()
{
    Node *node = NewAstNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence = NULL;
    node->functionSignatureArguments.count = 0;
    return node;
}

/* Builds a FunctionSignature from its component nodes (stored, not copied). */
Node *MakeFunctionSignatureNode(
    Node *identifierNode,
    Node *typeNode,
    Node *arguments,
    Node *modifiersNode,
    Node *genericArgumentsNode)
{
    Node *node = NewAstNode(FunctionSignature);
    node->functionSignature.identifier = identifierNode;
    node->functionSignature.type = typeNode;
    node->functionSignature.arguments = arguments;
    node->functionSignature.modifiers = modifiersNode;
    node->functionSignature.genericArguments = genericArgumentsNode;
    return node;
}

/* Builds a FunctionDeclaration from a signature and a body. */
Node *MakeFunctionDeclarationNode(
    Node *functionSignatureNode,
    Node *functionBodyNode)
{
    Node *node = NewAstNode(FunctionDeclaration);
    node->functionDeclaration.functionSignature = functionSignatureNode;
    node->functionDeclaration.functionBody = functionBodyNode;
    return node;
}

/* Builds a StructDeclaration from a name and its field declarations. */
Node *MakeStructDeclarationNode(
    Node *identifierNode,
    Node *declarationSequenceNode)
{
    Node *node = NewAstNode(StructDeclaration);
    node->structDeclaration.identifier = identifierNode;
    node->structDeclaration.declarationSequence = declarationSequenceNode;
    return node;
}

/* Starts a DeclarationSequence containing one declaration. */
Node *StartDeclarationSequenceNode(Node *declarationNode)
{
    Node *node = NewAstNode(DeclarationSequence);
    node->declarationSequence.sequence = AppendAstNode(NULL, 0, declarationNode);
    node->declarationSequence.count = 1;
    return node;
}

/* Appends declarationNode to a DeclarationSequence. */
Node *AddDeclarationNode(Node *declarationSequenceNode, Node *declarationNode)
{
    declarationSequenceNode->declarationSequence.sequence = AppendAstNode(
        declarationSequenceNode->declarationSequence.sequence,
        declarationSequenceNode->declarationSequence.count,
        declarationNode);
    declarationSequenceNode->declarationSequence.count += 1;
    return declarationSequenceNode;
}

/* Starts a FunctionArgumentSequence (call-site arguments) with one entry. */
Node *StartFunctionArgumentSequenceNode(Node *argumentNode)
{
    Node *node = NewAstNode(FunctionArgumentSequence);
    node->functionArgumentSequence.sequence =
        AppendAstNode(NULL, 0, argumentNode);
    node->functionArgumentSequence.count = 1;
    return node;
}

/* Appends argumentNode to a FunctionArgumentSequence. */
Node *AddFunctionArgumentNode(Node *argumentSequenceNode, Node *argumentNode)
{
    argumentSequenceNode->functionArgumentSequence.sequence = AppendAstNode(
        argumentSequenceNode->functionArgumentSequence.sequence,
        argumentSequenceNode->functionArgumentSequence.count,
        argumentNode);
    argumentSequenceNode->functionArgumentSequence.count += 1;
    return argumentSequenceNode;
}

/* Builds an empty FunctionArgumentSequence. */
Node *MakeEmptyFunctionArgumentSequenceNode()
{
    Node *node = NewAstNode(FunctionArgumentSequence);
    node->functionArgumentSequence.count = 0;
    node->functionArgumentSequence.sequence = NULL;
    return node;
}

/* Builds a GenericArgument from an identifier and an optional constraint. */
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode)
{
    Node *node = NewAstNode(GenericArgument);
    node->genericArgument.identifier = identifierNode;
    node->genericArgument.constraint = constraintNode;
    return node;
}

/* Starts a GenericArguments list with one argument. */
Node *StartGenericArgumentsNode(Node *genericArgumentNode)
{
    Node *node = NewAstNode(GenericArguments);
    node->genericArguments.arguments =
        AppendAstNode(NULL, 0, genericArgumentNode);
    node->genericArguments.count = 1;
    return node;
}

/* Appends genericArgumentNode to a GenericArguments list. */
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode)
{
    genericArgumentsNode->genericArguments.arguments = AppendAstNode(
        genericArgumentsNode->genericArguments.arguments,
        genericArgumentsNode->genericArguments.count,
        genericArgumentNode);
    genericArgumentsNode->genericArguments.count += 1;
    return genericArgumentsNode;
}

/* Builds an empty GenericArguments list. */
Node *MakeEmptyGenericArgumentsNode()
{
    Node *node = NewAstNode(GenericArguments);
    node->genericArguments.arguments = NULL;
    node->genericArguments.count = 0;
    return node;
}

/* Builds a GenericTypeNode; name is copied with strdup. */
Node *MakeGenericTypeNode(char *name)
{
    Node *node = NewAstNode(GenericTypeNode);
    node->genericType.name = strdup(name);
    return node;
}

/* Builds a FunctionCallExpression of identifierNode with the given args. */
Node *MakeFunctionCallExpressionNode(
    Node *identifierNode,
    Node *argumentSequenceNode)
{
    Node *node = NewAstNode(FunctionCallExpression);
    node->functionCallExpression.identifier = identifierNode;
    node->functionCallExpression.argumentSequence = argumentSequenceNode;
    return node;
}

/* Builds an AccessExpression "accessee.accessor". */
Node *MakeAccessExpressionNode(Node *accessee, Node *accessor)
{
    Node *node = NewAstNode(AccessExpression);
    node->accessExpression.accessee = accessee;
    node->accessExpression.accessor = accessor;
    return node;
}

/* Builds an AllocExpression for the given type node. */
Node *MakeAllocNode(Node *typeNode)
{
    Node *node = NewAstNode(AllocExpression);
    node->allocExpression.type = typeNode;
    return node;
}

/* Builds an IfStatement with a condition and a body. */
Node *MakeIfNode(Node *expressionNode, Node *statementSequenceNode)
{
    Node *node = NewAstNode(IfStatement);
    node->ifStatement.expression = expressionNode;
    node->ifStatement.statementSequence = statementSequenceNode;
    return node;
}

/* Builds an IfElseStatement from an if node and an else node. */
Node *MakeIfElseNode(Node *ifNode, Node *elseNode)
{
    Node *node = NewAstNode(IfElseStatement);
    node->ifElseStatement.ifStatement = ifNode;
    node->ifElseStatement.elseStatement = elseNode;
    return node;
}

/* Builds a ForLoop over [startNumberNode, endNumberNode] (bounds as parsed). */
Node *MakeForLoopNode(
    Node *declarationNode,
    Node *startNumberNode,
    Node *endNumberNode,
    Node *statementSequenceNode)
{
    Node *node = NewAstNode(ForLoop);
    node->forLoop.declaration = declarationNode;
    node->forLoop.startNumber = startNumberNode;
    node->forLoop.endNumber = endNumberNode;
    node->forLoop.statementSequence = statementSequenceNode;
    return node;
}

/* Returns a literal name for a PrimitiveType; "Unknown" for other values. */
static const char *PrimitiveTypeToString(PrimitiveType type)
{
    switch (type)
    {
    case Int:
        return "Int";
    case UInt:
        return "UInt";
    case Bool:
        return "Bool";
    case Void:
        return "Void";
    }

    return "Unknown";
}

/* Prints the symbol for a unary operator (no trailing newline). */
static void PrintUnaryOperator(UnaryOperator operator)
{
    switch (operator)
    {
    case Negate:
        printf("!");
        break;
    }
}

/* Prints the symbol for a binary operator (no trailing newline). */
static void PrintBinaryOperator(BinaryOperator operator)
{
    switch (operator)
    {
    case Add:
        printf("(+)");
        break;

    case Subtract:
        printf("(-)");
        break;

    case Multiply:
        printf("(*)");
        break;
    }
}

/*
 * Prints node and its children to stdout as an indented tree, one node per
 * line, indenting each level by one space. Every case ends its own line so
 * that children always start on a fresh line.
 */
void PrintNode(Node *node, uint32_t tabCount)
{
    uint32_t i;
    for (i = 0; i < tabCount; i += 1)
    {
        printf(" ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));
    switch (node->syntaxKind)
    {
    case AccessExpression:
        printf("\n");
        PrintNode(node->accessExpression.accessee, tabCount + 1);
        PrintNode(node->accessExpression.accessor, tabCount + 1);
        return;

    case AllocExpression:
        printf("\n");
        PrintNode(node->allocExpression.type, tabCount + 1);
        return;

    case Assignment:
        printf("\n");
        PrintNode(node->assignmentStatement.left, tabCount + 1);
        PrintNode(node->assignmentStatement.right, tabCount + 1);
        return;

    case BinaryExpression:
        PrintBinaryOperator(node->binaryExpression.operator);
        printf("\n");
        PrintNode(node->binaryExpression.left, tabCount + 1);
        PrintNode(node->binaryExpression.right, tabCount + 1);
        return;

    case CustomTypeNode:
        printf("%s\n", node->customType.name);
        return;

    case Declaration:
        printf("\n");
        PrintNode(node->declaration.identifier, tabCount + 1);
        PrintNode(node->declaration.type, tabCount + 1);
        return;

    case DeclarationSequence:
        printf("\n");
        for (i = 0; i < node->declarationSequence.count; i += 1)
        {
            PrintNode(node->declarationSequence.sequence[i], tabCount + 1);
        }
        return;

    case ForLoop:
        printf("\n");
        PrintNode(node->forLoop.declaration, tabCount + 1);
        PrintNode(node->forLoop.startNumber, tabCount + 1);
        PrintNode(node->forLoop.endNumber, tabCount + 1);
        PrintNode(node->forLoop.statementSequence, tabCount + 1);
        return;

    case FunctionArgumentSequence:
        printf("\n");
        for (i = 0; i < node->functionArgumentSequence.count; i += 1)
        {
            PrintNode(
                node->functionArgumentSequence.sequence[i],
                tabCount + 1);
        }
        return;

    case FunctionCallExpression:
        printf("\n");
        PrintNode(node->functionCallExpression.identifier, tabCount + 1);
        PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1);
        return;

    case FunctionDeclaration:
        printf("\n");
        PrintNode(node->functionDeclaration.functionSignature, tabCount + 1);
        PrintNode(node->functionDeclaration.functionBody, tabCount + 1);
        return;

    case FunctionModifiers:
        printf("\n");
        for (i = 0; i < node->functionModifiers.count; i += 1)
        {
            PrintNode(node->functionModifiers.sequence[i], tabCount + 1);
        }
        return;

    case FunctionSignature:
        printf("\n");
        PrintNode(node->functionSignature.identifier, tabCount + 1);
        PrintNode(node->functionSignature.genericArguments, tabCount + 1);
        PrintNode(node->functionSignature.arguments, tabCount + 1);
        PrintNode(node->functionSignature.type, tabCount + 1);
        PrintNode(node->functionSignature.modifiers, tabCount + 1);
        return;

    case FunctionSignatureArguments:
        printf("\n");
        for (i = 0; i < node->functionSignatureArguments.count; i += 1)
        {
            PrintNode(
                node->functionSignatureArguments.sequence[i],
                tabCount + 1);
        }
        return;

    case GenericArgument:
        printf("\n");
        PrintNode(node->genericArgument.identifier, tabCount + 1);
        /* Constraint nodes are not implemented. */
        /* PrintNode(node->genericArgument.constraint, tabCount + 1); */
        return;

    case GenericArguments:
        printf("\n");
        for (i = 0; i < node->genericArguments.count; i += 1)
        {
            PrintNode(node->genericArguments.arguments[i], tabCount + 1);
        }
        return;

    case GenericTypeNode:
        printf("%s\n", node->genericType.name);
        return;

    case Identifier:
        if (node->typeTag == NULL)
        {
            printf("%s\n", node->identifier.name);
        }
        else
        {
            /* TypeTagToString returns a heap string we own; free it after. */
            char *type = TypeTagToString(node->typeTag);
            printf(
                "%s<%s>\n",
                node->identifier.name,
                type != NULL ? type : "Unknown");
            free(type);
        }
        return;

    case IfStatement:
        printf("\n");
        PrintNode(node->ifStatement.expression, tabCount + 1);
        PrintNode(node->ifStatement.statementSequence, tabCount + 1);
        return;

    case IfElseStatement:
        printf("\n");
        PrintNode(node->ifElseStatement.ifStatement, tabCount + 1);
        PrintNode(node->ifElseStatement.elseStatement, tabCount + 1);
        return;

    case Number:
        printf("%lu\n", node->number.value);
        return;

    case PrimitiveTypeNode:
        printf("%s\n", PrimitiveTypeToString(node->primitiveType.type));
        return;

    case ReferenceTypeNode:
        printf("\n");
        PrintNode(node->referenceType.type, tabCount + 1);
        return;

    case Return:
        printf("\n");
        PrintNode(node->returnStatement.expression, tabCount + 1);
        return;

    case ReturnVoid:
        printf("\n");
        return;

    case StatementSequence:
        printf("\n");
        for (i = 0; i < node->statementSequence.count; i += 1)
        {
            PrintNode(node->statementSequence.sequence[i], tabCount + 1);
        }
        return;

    case StaticModifier:
        printf("\n");
        return;

    case StringLiteral:
        printf("%s\n", node->stringLiteral.string);
        return;

    case StructDeclaration:
        printf("\n");
        PrintNode(node->structDeclaration.identifier, tabCount + 1);
        PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
        return;

    case Type:
        printf("\n");
        PrintNode(node->type.typeNode, tabCount + 1);
        return;

    case UnaryExpression:
        PrintUnaryOperator(node->unaryExpression.operator);
        printf("\n");
        PrintNode(node->unaryExpression.child, tabCount + 1);
        return;

    default:
        /* Leaf kinds without a payload (e.g. Comment): just end the line. */
        printf("\n");
        return;
    }
}

/*
 * Calls func once on each direct child of node (not recursively; func may
 * itself call Recurse to descend further). Leaf kinds have no children.
 */
void Recurse(Node *node, void (*func)(Node *))
{
    uint32_t i;
    switch (node->syntaxKind)
    {
    case AccessExpression:
        func(node->accessExpression.accessee);
        func(node->accessExpression.accessor);
        return;

    case AllocExpression:
        func(node->allocExpression.type);
        return;

    case Assignment:
        func(node->assignmentStatement.left);
        func(node->assignmentStatement.right);
        return;

    case BinaryExpression:
        func(node->binaryExpression.left);
        func(node->binaryExpression.right);
        return;

    case Comment:
        return;

    case CustomTypeNode:
        return;

    case Declaration:
        func(node->declaration.type);
        func(node->declaration.identifier);
        return;

    case DeclarationSequence:
        for (i = 0; i < node->declarationSequence.count; i += 1)
        {
            func(node->declarationSequence.sequence[i]);
        }
        return;

    case ForLoop:
        func(node->forLoop.declaration);
        func(node->forLoop.startNumber);
        func(node->forLoop.endNumber);
        func(node->forLoop.statementSequence);
        return;

    case FunctionArgumentSequence:
        for (i = 0; i < node->functionArgumentSequence.count; i += 1)
        {
            func(node->functionArgumentSequence.sequence[i]);
        }
        return;

    case FunctionCallExpression:
        func(node->functionCallExpression.identifier);
        func(node->functionCallExpression.argumentSequence);
        return;

    case FunctionDeclaration:
        func(node->functionDeclaration.functionSignature);
        func(node->functionDeclaration.functionBody);
        return;

    case FunctionModifiers:
        for (i = 0; i < node->functionModifiers.count; i += 1)
        {
            func(node->functionModifiers.sequence[i]);
        }
        return;

    case FunctionSignature:
        func(node->functionSignature.identifier);
        func(node->functionSignature.type);
        func(node->functionSignature.arguments);
        func(node->functionSignature.modifiers);
        func(node->functionSignature.genericArguments);
        return;

    case FunctionSignatureArguments:
        for (i = 0; i < node->functionSignatureArguments.count; i += 1)
        {
            func(node->functionSignatureArguments.sequence[i]);
        }
        return;

    case GenericArgument:
        func(node->genericArgument.identifier);
        /* Constraint nodes are not implemented; the field is presumably
         * NULL in practice, so don't hand a NULL child to func. */
        if (node->genericArgument.constraint != NULL)
        {
            func(node->genericArgument.constraint);
        }
        return;

    case GenericArguments:
        for (i = 0; i < node->genericArguments.count; i += 1)
        {
            func(node->genericArguments.arguments[i]);
        }
        return;

    case GenericTypeNode:
        return;

    case Identifier:
        return;

    case IfStatement:
        func(node->ifStatement.expression);
        func(node->ifStatement.statementSequence);
        return;

    case IfElseStatement:
        func(node->ifElseStatement.ifStatement);
        func(node->ifElseStatement.elseStatement);
        return;

    case Number:
        return;

    case PrimitiveTypeNode:
        return;

    case ReferenceTypeNode:
        func(node->referenceType.type);
        return;

    case Return:
        func(node->returnStatement.expression);
        return;

    case ReturnVoid:
        return;

    case StatementSequence:
        for (i = 0; i < node->statementSequence.count; i += 1)
        {
            func(node->statementSequence.sequence[i]);
        }
        return;

    case StaticModifier:
        return;

    case StringLiteral:
        return;

    case StructDeclaration:
        func(node->structDeclaration.identifier);
        func(node->structDeclaration.declarationSequence);
        return;

    case Type:
        /* NOTE(review): unlike ReferenceTypeNode, the wrapped type.typeNode
         * is not visited here; confirm whether callers rely on that. */
        return;

    case UnaryExpression:
        func(node->unaryExpression.child);
        return;

    default:
        fprintf(
            stderr,
            "wraith: Unhandled SyntaxKind %s in recurse function.\n",
            SyntaxKindString(node->syntaxKind));
        return;
    }
}

/*
 * Builds a newly allocated TypeTag describing the type denoted by node.
 * Wrapper kinds (Type, Declaration, FunctionDeclaration, AllocExpression)
 * delegate to the type node they carry. Returns NULL (after printing a
 * diagnostic) for a NULL node or an unsupported kind.
 */
TypeTag *MakeTypeTag(Node *node)
{
    TypeTag *tag;

    if (node == NULL)
    {
        fprintf(
            stderr,
            "wraith: Attempted to call MakeTypeTag on null value.\n");
        return NULL;
    }

    switch (node->syntaxKind)
    {
    /* Delegating kinds: return the wrapped type's tag directly so that no
     * intermediate tag is allocated and lost. */
    case Type:
        return MakeTypeTag(node->type.typeNode);

    case Declaration:
        return MakeTypeTag(node->declaration.type);

    case FunctionDeclaration:
        return MakeTypeTag(node->functionDeclaration.functionSignature
                               ->functionSignature.type);

    case AllocExpression:
        return MakeTypeTag(node->allocExpression.type);

    /* Leaf kinds: allocate and fill in a tag. */
    case PrimitiveTypeNode:
        tag = AstCallocOrDie(1, sizeof *tag);
        tag->type = Primitive;
        tag->value.primitiveType = node->primitiveType.type;
        return tag;

    case ReferenceTypeNode:
        tag = AstCallocOrDie(1, sizeof *tag);
        tag->type = Reference;
        tag->value.referenceType = MakeTypeTag(node->referenceType.type);
        return tag;

    case CustomTypeNode:
        tag = AstCallocOrDie(1, sizeof *tag);
        tag->type = Custom;
        tag->value.customType = strdup(node->customType.name);
        return tag;

    case StructDeclaration:
        tag = AstCallocOrDie(1, sizeof *tag);
        tag->type = Custom;
        tag->value.customType =
            strdup(node->structDeclaration.identifier->identifier.name);
        return tag;

    case GenericTypeNode:
        tag = AstCallocOrDie(1, sizeof *tag);
        tag->type = Generic;
        tag->value.genericType = strdup(node->genericType.name);
        return tag;

    default:
        fprintf(
            stderr,
            "wraith: Attempted to call MakeTypeTag on"
            " node with unsupported SyntaxKind: %s\n",
            SyntaxKindString(node->syntaxKind));
        return NULL;
    }
}

/*
 * Returns a newly allocated string "<wrapper><<inner>>", e.g. "Ref<Int>",
 * sized exactly via a measuring snprintf call. Returns NULL on failure.
 */
static char *FormatWrappedTypeName(const char *wrapper, const char *inner)
{
    int length;
    char *result;

    length = snprintf(NULL, 0, "%s<%s>", wrapper, inner);
    if (length < 0)
    {
        return NULL;
    }

    result = malloc((size_t)length + 1);
    if (result == NULL)
    {
        return NULL;
    }

    snprintf(result, (size_t)length + 1, "%s<%s>", wrapper, inner);
    return result;
}

/*
 * Renders tag as a human-readable string such as "Int", "Ref<Int>",
 * "Custom<Foo>" or "Generic<T>".
 *
 * Ownership: the result is ALWAYS heap-allocated and must be released by the
 * caller with free(). Returns NULL (with a diagnostic for a NULL tag or an
 * unrecognized tag kind) on failure.
 */
char *TypeTagToString(TypeTag *tag)
{
    if (tag == NULL)
    {
        fprintf(
            stderr,
            "wraith: Attempted to call TypeTagToString with null value\n");
        return NULL;
    }

    switch (tag->type)
    {
    case Unknown:
        return strdup("Unknown");

    case Primitive:
        return strdup(PrimitiveTypeToString(tag->value.primitiveType));

    case Reference:
    {
        char *inner = TypeTagToString(tag->value.referenceType);
        char *result;

        if (inner == NULL)
        {
            return NULL;
        }

        result = FormatWrappedTypeName("Ref", inner);
        free(inner);
        return result;
    }

    case Custom:
        return FormatWrappedTypeName("Custom", tag->value.customType);

    case Generic:
        return FormatWrappedTypeName("Generic", tag->value.genericType);
    }

    fprintf(
        stderr,
        "wraith: Attempted to call TypeTagToString with unsupported tag kind\n");
    return NULL;
}