#include "ast.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "util.h"

/* ---------------------------------------------------------------------------
 * Internal allocation helpers
 * ------------------------------------------------------------------------- */

/* Returns `pointer` unchanged, or terminates the compiler with a diagnostic
 * if it is NULL. The AST builders have no way to report an allocation
 * failure to their callers, so continuing would only move the crash to the
 * first dereference. */
static void *CheckAllocation(void *pointer)
{
    if (pointer == NULL)
    {
        fprintf(stderr, "wraith: Out of memory.\n");
        exit(EXIT_FAILURE);
    }
    return pointer;
}

/* Allocates a zero-filled Node and stamps its syntaxKind.
 * Zero-filling (calloc) means fields a builder does not set explicitly,
 * such as `parent` and `typeTag`, start out as null pointers (on all
 * mainstream platforms) instead of holding indeterminate values. */
static Node *AllocateNode(SyntaxKind syntaxKind)
{
    Node *node = (Node *)CheckAllocation(calloc(1, sizeof(Node)));
    node->syntaxKind = syntaxKind;
    return node;
}

/* Grows `array` (currently holding `count` entries; may be NULL when count
 * is 0) by one slot, stores `item` in the new last slot and returns the
 * possibly-moved array. The old pointer must not be used afterwards. */
static Node **AppendNodePointer(Node **array, size_t count, Node *item)
{
    Node **grown =
        (Node **)CheckAllocation(realloc(array, sizeof(Node *) * (count + 1)));
    grown[count] = item;
    return grown;
}

/* ---------------------------------------------------------------------------
 * SyntaxKind helpers
 * ------------------------------------------------------------------------- */

/* Returns a printable name for a SyntaxKind, used by PrintNode and in
 * diagnostics. The result is a string literal and must not be freed. */
const char *SyntaxKindString(SyntaxKind syntaxKind)
{
    switch (syntaxKind)
    {
    case AccessExpression:
        return "AccessExpression";
    case AllocExpression:
        return "Alloc";
    case Assignment:
        return "Assignment";
    case BinaryExpression:
        return "BinaryExpression";
    case Comment:
        return "Comment";
    case ConcreteGenericTypeNode:
        return "ConcreteGenericTypeNode";
    case CustomTypeNode:
        return "CustomTypeNode";
    case Declaration:
        return "Declaration";
    case ForLoop:
        return "ForLoop";
    case DeclarationSequence:
        return "DeclarationSequence";
    case FieldInit:
        return "FieldInit";
    case FunctionArgumentSequence:
        return "FunctionArgumentSequence";
    case FunctionCallExpression:
        return "FunctionCallExpression";
    case FunctionDeclaration:
        return "FunctionDeclaration";
    case FunctionModifiers:
        return "FunctionModifiers";
    case FunctionSignature:
        return "FunctionSignature";
    case FunctionSignatureArguments:
        return "FunctionSignatureArguments";
    case GenericArgument:
        return "GenericArgument";
    case GenericArguments:
        return "GenericArguments";
    case GenericDeclaration:
        return "GenericDeclaration";
    case GenericDeclarations:
        return "GenericDeclarations";
    case GenericTypeNode:
        return "GenericTypeNode";
    case Identifier:
        return "Identifier";
    case IfStatement:
        return "If";
    case IfElseStatement:
        return "IfElse";
    case Number:
        return "Number";
    case PrimitiveTypeNode:
        return "PrimitiveTypeNode";
    case ReferenceTypeNode:
        return "ReferenceTypeNode";
    case Return:
        return "Return";
    case ReturnVoid:
        return "ReturnVoid";
    case StatementSequence:
        return "StatementSequence";
    case StaticModifier:
        return "StaticModifier";
    case StringLiteral:
        return "StringLiteral";
    case StructDeclaration:
        return "StructDeclaration";
    case StructInit:
        return "StructInit";
    case StructInitFields:
        return "StructInitFields";
    case SystemCall:
        return "SystemCall";
    case Type:
        return "Type";
    case UnaryExpression:
        return "UnaryExpression";
    default:
        return "Unknown";
    }
}

/* Reports whether a Type node wraps a PrimitiveTypeNode.
 * `typeNode` must be a node of kind Type: its `type.typeNode` child is read
 * without checking. */
uint8_t IsPrimitiveType(Node *typeNode)
{
    return typeNode->type.typeNode->syntaxKind == PrimitiveTypeNode;
}

/* ---------------------------------------------------------------------------
 * Type nodes
 * ------------------------------------------------------------------------- */

Node *MakePrimitiveTypeNode(PrimitiveType type)
{
    Node *node = AllocateNode(PrimitiveTypeNode);
    node->primitiveType.type = type;
    return node;
}

/* Builds a CustomTypeNode named after `identifierNode`.
 * Takes ownership of `identifierNode`: its name string is moved into the new
 * node and the identifier node itself is freed, so nothing leaks. */
Node *MakeCustomTypeNode(Node *identifierNode)
{
    Node *node = AllocateNode(CustomTypeNode);
    node->customType.name = identifierNode->identifier.name;
    free(identifierNode);
    return node;
}

Node *MakeReferenceTypeNode(Node *typeNode)
{
    Node *node = AllocateNode(ReferenceTypeNode);
    node->referenceType.type = typeNode;
    return node;
}

/* Builds a ConcreteGenericTypeNode (a named type applied to generic
 * arguments). Takes ownership of `identifierNode` the same way
 * MakeCustomTypeNode does: the name is moved, the node is freed. */
Node *MakeConcreteGenericTypeNode(
    Node *identifierNode,
    Node *genericArgumentsNode)
{
    Node *node = AllocateNode(ConcreteGenericTypeNode);
    node->concreteGenericType.name = identifierNode->identifier.name;
    node->concreteGenericType.genericArguments = genericArgumentsNode;
    free(identifierNode);
    return node;
}

Node *MakeTypeNode(Node *typeNode)
{
    Node *node = AllocateNode(Type);
    node->type.typeNode = typeNode;
    return node;
}

/* ---------------------------------------------------------------------------
 * Leaf nodes
 * ------------------------------------------------------------------------- */

/* Builds an Identifier node holding a private copy of `id`. */
Node *MakeIdentifierNode(const char *id)
{
    Node *node = AllocateNode(Identifier);
    node->identifier.name = strdup(id);
    node->typeTag = NULL;
    return node;
}

/* Builds a Number node by parsing `numberString` as a base-10 unsigned
 * value with strtoul. */
Node *MakeNumberNode(const char *numberString)
{
    Node *node = AllocateNode(Number);
    node->number.value = strtoul(numberString, NULL, 10);
    return node;
}

/* Builds a StringLiteral node from `string` with its first and last
 * characters removed (presumably the surrounding quotes of the lexer token
 * -- TODO confirm against the lexer). */
Node *MakeStringNode(const char *string)
{
    size_t slen = strlen(string);
    Node *node = AllocateNode(StringLiteral);

    /* A token shorter than two characters would make `slen - 2` wrap around
     * to a huge length; treat it as the empty string instead. */
    node->stringLiteral.string =
        slen >= 2 ? strndup(string + 1, slen - 2) : strdup("");
    return node;
}

Node *MakeStaticNode(void)
{
    return AllocateNode(StaticModifier);
}

/* Builds a FunctionModifiers node holding a copy of the first
 * `modifierCount` pointers from `pModifierNodes` (which may be NULL when
 * the count is 0). The array is copied; the modifier nodes are not.
 * FIXME: this sucks */
Node *MakeFunctionModifiersNode(Node **pModifierNodes, uint32_t modifierCount)
{
    Node *node = AllocateNode(FunctionModifiers);
    node->functionModifiers.count = modifierCount;
    node->functionModifiers.sequence = NULL;

    if (modifierCount > 0)
    {
        node->functionModifiers.sequence = (Node **)CheckAllocation(
            malloc(sizeof(Node *) * modifierCount));
        memcpy(
            node->functionModifiers.sequence,
            pModifierNodes,
            sizeof(Node *) * modifierCount);
    }

    return node;
}

/* ---------------------------------------------------------------------------
 * Expressions and statements
 * ------------------------------------------------------------------------- */

Node *MakeUnaryNode(UnaryOperator operator, Node * child)
{
    Node *node = AllocateNode(UnaryExpression);
    node->unaryExpression.operator= operator;
    node->unaryExpression.child = child;
    return node;
}

Node *MakeBinaryNode(BinaryOperator operator, Node * left, Node *right)
{
    Node *node = AllocateNode(BinaryExpression);
    node->binaryExpression.left = left;
    node->binaryExpression.right = right;
    node->binaryExpression.operator= operator;
    return node;
}

Node *MakeDeclarationNode(Node *typeNode, Node *identifierNode)
{
    Node *node = AllocateNode(Declaration);
    node->declaration.type = typeNode;
    node->declaration.identifier = identifierNode;
    return node;
}

Node *MakeAssignmentNode(Node *left, Node *right)
{
    Node *node = AllocateNode(Assignment);
    node->assignmentStatement.left = left;
    node->assignmentStatement.right = right;
    return node;
}

/* Starts a StatementSequence containing just `statementNode`. */
Node *StartStatementSequenceNode(Node *statementNode)
{
    Node *node = AllocateNode(StatementSequence);
    node->statementSequence.sequence =
        AppendNodePointer(NULL, 0, statementNode);
    node->statementSequence.count = 1;
    return node;
}

/* Appends `statementNode` to an existing StatementSequence and returns the
 * same sequence node. */
Node *AddStatement(Node *statementSequenceNode, Node *statementNode)
{
    statementSequenceNode->statementSequence.sequence = AppendNodePointer(
        statementSequenceNode->statementSequence.sequence,
        statementSequenceNode->statementSequence.count,
        statementNode);
    statementSequenceNode->statementSequence.count += 1;
    return statementSequenceNode;
}

Node *MakeReturnStatementNode(Node *expressionNode)
{
    Node *node = AllocateNode(Return);
    node->returnStatement.expression = expressionNode;
    return node;
}

Node *MakeReturnVoidStatementNode(void)
{
    return AllocateNode(ReturnVoid);
}

/* ---------------------------------------------------------------------------
 * Function signatures and declarations
 * ------------------------------------------------------------------------- */

Node *StartFunctionSignatureArgumentsNode(Node *argumentNode)
{
    Node *node = AllocateNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence =
        AppendNodePointer(NULL, 0, argumentNode);
    node->functionSignatureArguments.count = 1;
    return node;
}

Node *AddFunctionSignatureArgumentNode(Node *argumentsNode, Node *argumentNode)
{
    argumentsNode->functionSignatureArguments.sequence = AppendNodePointer(
        argumentsNode->functionSignatureArguments.sequence,
        argumentsNode->functionSignatureArguments.count,
        argumentNode);
    argumentsNode->functionSignatureArguments.count += 1;
    return argumentsNode;
}

Node *MakeEmptyFunctionSignatureArgumentsNode(void)
{
    Node *node = AllocateNode(FunctionSignatureArguments);
    node->functionSignatureArguments.sequence = NULL;
    node->functionSignatureArguments.count = 0;
    return node;
}

/* Builds a FunctionSignature node. `genericArgumentsNode` is stored as the
 * signature's generic *declarations*. */
Node *MakeFunctionSignatureNode(
    Node *identifierNode,
    Node *typeNode,
    Node *arguments,
    Node *modifiersNode,
    Node *genericArgumentsNode)
{
    Node *node = AllocateNode(FunctionSignature);
    node->functionSignature.identifier = identifierNode;
    node->functionSignature.type = typeNode;
    node->functionSignature.arguments = arguments;
    node->functionSignature.modifiers = modifiersNode;
    node->functionSignature.genericDeclarations = genericArgumentsNode;
    return node;
}

Node *MakeFunctionDeclarationNode(
    Node *functionSignatureNode,
    Node *functionBodyNode)
{
    Node *node = AllocateNode(FunctionDeclaration);
    node->functionDeclaration.functionSignature = functionSignatureNode;
    node->functionDeclaration.functionBody = functionBodyNode;
    return node;
}

/* ---------------------------------------------------------------------------
 * Structs and declaration sequences
 * ------------------------------------------------------------------------- */

/* Builds a StructDeclaration node. `genericArgumentsNode` is stored as the
 * struct's generic *declarations*. */
Node *MakeStructDeclarationNode(
    Node *identifierNode,
    Node *declarationSequenceNode,
    Node *genericArgumentsNode)
{
    Node *node = AllocateNode(StructDeclaration);
    node->structDeclaration.identifier = identifierNode;
    node->structDeclaration.declarationSequence = declarationSequenceNode;
    node->structDeclaration.genericDeclarations = genericArgumentsNode;
    return node;
}

Node *StartDeclarationSequenceNode(Node *declarationNode)
{
    Node *node = AllocateNode(DeclarationSequence);
    node->declarationSequence.sequence =
        AppendNodePointer(NULL, 0, declarationNode);
    node->declarationSequence.count = 1;
    return node;
}

Node *AddDeclarationNode(Node *declarationSequenceNode, Node *declarationNode)
{
    declarationSequenceNode->declarationSequence.sequence = AppendNodePointer(
        declarationSequenceNode->declarationSequence.sequence,
        declarationSequenceNode->declarationSequence.count,
        declarationNode);
    declarationSequenceNode->declarationSequence.count += 1;
    return declarationSequenceNode;
}

/* ---------------------------------------------------------------------------
 * Call arguments
 * ------------------------------------------------------------------------- */

Node *StartFunctionArgumentSequenceNode(Node *argumentNode)
{
    Node *node = AllocateNode(FunctionArgumentSequence);
    node->functionArgumentSequence.sequence =
        AppendNodePointer(NULL, 0, argumentNode);
    node->functionArgumentSequence.count = 1;
    return node;
}

Node *AddFunctionArgumentNode(Node *argumentSequenceNode, Node *argumentNode)
{
    argumentSequenceNode->functionArgumentSequence.sequence = AppendNodePointer(
        argumentSequenceNode->functionArgumentSequence.sequence,
        argumentSequenceNode->functionArgumentSequence.count,
        argumentNode);
    argumentSequenceNode->functionArgumentSequence.count += 1;
    return argumentSequenceNode;
}

Node *MakeEmptyFunctionArgumentSequenceNode(void)
{
    Node *node = AllocateNode(FunctionArgumentSequence);
    node->functionArgumentSequence.count = 0;
    node->functionArgumentSequence.sequence = NULL;
    return node;
}

/* ---------------------------------------------------------------------------
 * Generics
 * ------------------------------------------------------------------------- */

Node *MakeGenericDeclarationNode(Node *identifierNode, Node *constraintNode)
{
    Node *node = AllocateNode(GenericDeclaration);
    node->genericDeclaration.identifier = identifierNode;
    node->genericDeclaration.constraint = constraintNode;
    return node;
}

Node *StartGenericDeclarationsNode(Node *genericArgumentNode)
{
    Node *node = AllocateNode(GenericDeclarations);
    node->genericDeclarations.declarations =
        AppendNodePointer(NULL, 0, genericArgumentNode);
    node->genericDeclarations.count = 1;
    return node;
}

Node *AddGenericDeclaration(
    Node *genericDeclarationsNode,
    Node *genericDeclarationNode)
{
    genericDeclarationsNode->genericDeclarations.declarations =
        AppendNodePointer(
            genericDeclarationsNode->genericDeclarations.declarations,
            genericDeclarationsNode->genericDeclarations.count,
            genericDeclarationNode);
    genericDeclarationsNode->genericDeclarations.count += 1;
    return genericDeclarationsNode;
}

Node *MakeEmptyGenericDeclarationsNode(void)
{
    Node *node = AllocateNode(GenericDeclarations);
    node->genericDeclarations.declarations = NULL;
    node->genericDeclarations.count = 0;
    return node;
}

Node *MakeGenericArgumentNode(Node *typeNode)
{
    Node *node = AllocateNode(GenericArgument);
    node->genericArgument.type = typeNode;
    return node;
}

Node *StartGenericArgumentsNode(Node *genericArgumentNode)
{
    Node *node = AllocateNode(GenericArguments);
    node->genericArguments.arguments =
        AppendNodePointer(NULL, 0, genericArgumentNode);
    node->genericArguments.count = 1;
    return node;
}

/* Appends `genericArgumentNode` to the GenericArguments collection and
 * returns the collection. The count that is incremented belongs to the
 * collection (`genericArgumentsNode`), not to the element being added. */
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode)
{
    genericArgumentsNode->genericArguments.arguments = AppendNodePointer(
        genericArgumentsNode->genericArguments.arguments,
        genericArgumentsNode->genericArguments.count,
        genericArgumentNode);
    genericArgumentsNode->genericArguments.count += 1;
    return genericArgumentsNode;
}

Node *MakeEmptyGenericArgumentsNode(void)
{
    Node *node = AllocateNode(GenericArguments);
    node->genericArguments.arguments = NULL;
    node->genericArguments.count = 0;
    return node;
}

/* Builds a GenericTypeNode holding a private copy of `name`. */
Node *MakeGenericTypeNode(char *name)
{
    Node *node = AllocateNode(GenericTypeNode);
    node->genericType.name = strdup(name);
    return node;
}

/* ---------------------------------------------------------------------------
 * Calls, access, allocation
 * ------------------------------------------------------------------------- */

Node *MakeFunctionCallExpressionNode(
    Node *identifierNode,
    Node *argumentSequenceNode,
    Node *genericArgumentsNode)
{
    Node *node = AllocateNode(FunctionCallExpression);
    node->functionCallExpression.identifier = identifierNode;
    node->functionCallExpression.argumentSequence = argumentSequenceNode;
    node->functionCallExpression.genericArguments = genericArgumentsNode;
    return node;
}

Node *MakeSystemCallExpressionNode(
    Node *identifierNode,
    Node *argumentSequenceNode,
    Node *genericArgumentsNode)
{
    Node *node = AllocateNode(SystemCall);
    node->systemCall.identifier = identifierNode;
    node->systemCall.argumentSequence = argumentSequenceNode;
    node->systemCall.genericArguments = genericArgumentsNode;
    return node;
}

Node *MakeAccessExpressionNode(Node *accessee, Node *accessor)
{
    Node *node = AllocateNode(AccessExpression);
    node->accessExpression.accessee = accessee;
    node->accessExpression.accessor = accessor;
    return node;
}

Node *MakeAllocNode(Node *typeNode)
{
    Node *node = AllocateNode(AllocExpression);
    node->allocExpression.type = typeNode;
    return node;
}

/* ---------------------------------------------------------------------------
 * Control flow
 * ------------------------------------------------------------------------- */

Node *MakeIfNode(Node *expressionNode, Node *statementSequenceNode)
{
    Node *node = AllocateNode(IfStatement);
    node->ifStatement.expression = expressionNode;
    node->ifStatement.statementSequence = statementSequenceNode;
    return node;
}

Node *MakeIfElseNode(Node *ifNode, Node *elseNode)
{
    Node *node = AllocateNode(IfElseStatement);
    node->ifElseStatement.ifStatement = ifNode;
    node->ifElseStatement.elseStatement = elseNode;
    return node;
}

Node *MakeForLoopNode(
    Node *declarationNode,
    Node *startNumberNode,
    Node *endNumberNode,
    Node *statementSequenceNode)
{
    Node *node = AllocateNode(ForLoop);
    node->forLoop.declaration = declarationNode;
    node->forLoop.startNumber = startNumberNode;
    node->forLoop.endNumber = endNumberNode;
    node->forLoop.statementSequence = statementSequenceNode;
    return node;
}

/* ---------------------------------------------------------------------------
 * Struct initialisers
 * ------------------------------------------------------------------------- */

Node *MakeFieldInitNode(Node *identifierNode, Node *expressionNode)
{
    Node *node = AllocateNode(FieldInit);
    node->fieldInit.identifier = identifierNode;
    node->fieldInit.expression = expressionNode;
    return node;
}

Node *StartStructInitFieldsNode(Node *fieldInitNode)
{
    Node *node = AllocateNode(StructInitFields);
    node->structInitFields.fieldInits = AppendNodePointer(NULL, 0, fieldInitNode);
    node->structInitFields.count = 1;
    return node;
}

Node *AddFieldInitNode(Node *structInitFieldsNode, Node *fieldInitNode)
{
    structInitFieldsNode->structInitFields.fieldInits = AppendNodePointer(
        structInitFieldsNode->structInitFields.fieldInits,
        structInitFieldsNode->structInitFields.count,
        fieldInitNode);
    structInitFieldsNode->structInitFields.count += 1;
    return structInitFieldsNode;
}

Node *MakeEmptyFieldInitNode(void)
{
    Node *node = AllocateNode(StructInitFields);
    node->structInitFields.fieldInits = NULL;
    node->structInitFields.count = 0;
    return node;
}

Node *MakeStructInitExpressionNode(Node *typeNode, Node *structInitFieldsNode)
{
    Node *node = AllocateNode(StructInit);
    node->structInit.type = typeNode;
    node->structInit.initFields = structInitFieldsNode;
    return node;
}

/* ---------------------------------------------------------------------------
 * Debug printing
 * ------------------------------------------------------------------------- */

/* Returns a printable name for a PrimitiveType (a string literal). */
static const char *PrimitiveTypeToString(PrimitiveType type)
{
    switch (type)
    {
    case Int:
        return "Int";
    case UInt:
        return "UInt";
    case Bool:
        return "Bool";
    case MemoryAddress:
        return "MemoryAddress";
    case Void:
        return "Void";
    }

    return "Unknown";
}

static void PrintUnaryOperator(UnaryOperator operator)
{
    switch (operator)
    {
    case Negate:
        printf("!");
        break;
    default:
        /* Make operators added to the enum later visible instead of
         * silently printing nothing. */
        printf("(?)");
        break;
    }
}

static void PrintBinaryOperator(BinaryOperator operator)
{
    switch (operator)
    {
    case Add:
        printf("(+)");
        break;
    case Subtract:
        printf("(-)");
        break;
    case Multiply:
        printf("(*)");
        break;
    default:
        /* See PrintUnaryOperator. */
        printf("(?)");
        break;
    }
}

/* Prints `node` and its subtree to stdout. Each node gets one line,
 * prefixed by `tabCount` single-space indent characters and its kind name.
 * Children are printed at `tabCount + 1`. A NULL node (an absent optional
 * child) prints nothing. */
void PrintNode(Node *node, uint32_t tabCount)
{
    uint32_t i;

    if (node == NULL)
    {
        return;
    }

    for (i = 0; i < tabCount; i += 1)
    {
        printf(" ");
    }

    printf("%s: ", SyntaxKindString(node->syntaxKind));

    switch (node->syntaxKind)
    {
    case AccessExpression:
        printf("\n");
        PrintNode(node->accessExpression.accessee, tabCount + 1);
        PrintNode(node->accessExpression.accessor, tabCount + 1);
        return;

    case AllocExpression:
        printf("\n");
        PrintNode(node->allocExpression.type, tabCount + 1);
        return;

    case Assignment:
        printf("\n");
        PrintNode(node->assignmentStatement.left, tabCount + 1);
        PrintNode(node->assignmentStatement.right, tabCount + 1);
        return;

    case BinaryExpression:
        PrintBinaryOperator(node->binaryExpression.operator);
        printf("\n");
        PrintNode(node->binaryExpression.left, tabCount + 1);
        PrintNode(node->binaryExpression.right, tabCount + 1);
        return;

    case Comment:
        printf("\n");
        return;

    case ConcreteGenericTypeNode:
        printf("%s\n", node->concreteGenericType.name);
        PrintNode(node->concreteGenericType.genericArguments, tabCount + 1);
        return;

    case CustomTypeNode:
        printf("%s\n", node->customType.name);
        return;

    case Declaration:
        printf("\n");
        PrintNode(node->declaration.identifier, tabCount + 1);
        PrintNode(node->declaration.type, tabCount + 1);
        return;

    case DeclarationSequence:
        printf("\n");
        for (i = 0; i < node->declarationSequence.count; i += 1)
        {
            PrintNode(node->declarationSequence.sequence[i], tabCount + 1);
        }
        return;

    case FieldInit:
        printf("\n");
        PrintNode(node->fieldInit.identifier, tabCount + 1);
        PrintNode(node->fieldInit.expression, tabCount + 1);
        return;

    case ForLoop:
        printf("\n");
        PrintNode(node->forLoop.declaration, tabCount + 1);
        PrintNode(node->forLoop.startNumber, tabCount + 1);
        PrintNode(node->forLoop.endNumber, tabCount + 1);
        PrintNode(node->forLoop.statementSequence, tabCount + 1);
        return;

    case FunctionArgumentSequence:
        printf("\n");
        for (i = 0; i < node->functionArgumentSequence.count; i += 1)
        {
            PrintNode(
                node->functionArgumentSequence.sequence[i],
                tabCount + 1);
        }
        return;

    case FunctionCallExpression:
        printf("\n");
        PrintNode(node->functionCallExpression.identifier, tabCount + 1);
        PrintNode(node->functionCallExpression.argumentSequence, tabCount + 1);
        PrintNode(node->functionCallExpression.genericArguments, tabCount + 1);
        return;

    case FunctionDeclaration:
        printf("\n");
        PrintNode(node->functionDeclaration.functionSignature, tabCount + 1);
        PrintNode(node->functionDeclaration.functionBody, tabCount + 1);
        return;

    case FunctionModifiers:
        printf("\n");
        for (i = 0; i < node->functionModifiers.count; i += 1)
        {
            PrintNode(node->functionModifiers.sequence[i], tabCount + 1);
        }
        return;

    case FunctionSignature:
        printf("\n");
        PrintNode(node->functionSignature.identifier, tabCount + 1);
        PrintNode(node->functionSignature.genericDeclarations, tabCount + 1);
        PrintNode(node->functionSignature.arguments, tabCount + 1);
        PrintNode(node->functionSignature.type, tabCount + 1);
        PrintNode(node->functionSignature.modifiers, tabCount + 1);
        return;

    case FunctionSignatureArguments:
        printf("\n");
        for (i = 0; i < node->functionSignatureArguments.count; i += 1)
        {
            PrintNode(
                node->functionSignatureArguments.sequence[i],
                tabCount + 1);
        }
        return;

    case GenericArgument:
        printf("\n");
        PrintNode(node->genericArgument.type, tabCount + 1);
        return;

    case GenericArguments:
        printf("\n");
        for (i = 0; i < node->genericArguments.count; i += 1)
        {
            PrintNode(node->genericArguments.arguments[i], tabCount + 1);
        }
        return;

    case GenericDeclaration:
        printf("\n");
        PrintNode(node->genericDeclaration.identifier, tabCount + 1);
        /* Constraint nodes are not implemented yet, so the constraint child
         * is intentionally not printed. */
        return;

    case GenericDeclarations:
        printf("\n");
        for (i = 0; i < node->genericDeclarations.count; i += 1)
        {
            PrintNode(node->genericDeclarations.declarations[i], tabCount + 1);
        }
        return;

    case GenericTypeNode:
        printf("%s\n", node->genericType.name);
        return;

    case Identifier:
        if (node->typeTag == NULL)
        {
            printf("%s\n", node->identifier.name);
        }
        else
        {
            /* TypeTagToString returns a heap string owned by the caller. */
            char *type = TypeTagToString(node->typeTag);
            printf(
                "%s<%s>\n",
                node->identifier.name,
                type != NULL ? type : "Unknown");
            free(type);
        }
        return;

    case IfStatement:
        printf("\n");
        PrintNode(node->ifStatement.expression, tabCount + 1);
        PrintNode(node->ifStatement.statementSequence, tabCount + 1);
        return;

    case IfElseStatement:
        printf("\n");
        PrintNode(node->ifElseStatement.ifStatement, tabCount + 1);
        PrintNode(node->ifElseStatement.elseStatement, tabCount + 1);
        return;

    case Number:
        /* Cast so the format specifier matches whatever integer type
         * `number.value` is declared with. */
        printf("%llu\n", (unsigned long long)node->number.value);
        return;

    case PrimitiveTypeNode:
        printf("%s\n", PrimitiveTypeToString(node->primitiveType.type));
        return;

    case ReferenceTypeNode:
        printf("\n");
        PrintNode(node->referenceType.type, tabCount + 1);
        return;

    case Return:
        printf("\n");
        PrintNode(node->returnStatement.expression, tabCount + 1);
        return;

    case ReturnVoid:
        printf("\n");
        return;

    case StatementSequence:
        printf("\n");
        for (i = 0; i < node->statementSequence.count; i += 1)
        {
            PrintNode(node->statementSequence.sequence[i], tabCount + 1);
        }
        return;

    case StaticModifier:
        printf("\n");
        return;

    case StringLiteral:
        printf("%s\n", node->stringLiteral.string);
        return;

    case StructDeclaration:
        printf("\n");
        PrintNode(node->structDeclaration.identifier, tabCount + 1);
        PrintNode(node->structDeclaration.genericDeclarations, tabCount + 1);
        PrintNode(node->structDeclaration.declarationSequence, tabCount + 1);
        return;

    case StructInit:
        printf("\n");
        PrintNode(node->structInit.type, tabCount + 1);
        PrintNode(node->structInit.initFields, tabCount + 1);
        return;

    case StructInitFields:
        printf("\n");
        for (i = 0; i < node->structInitFields.count; i += 1)
        {
            PrintNode(node->structInitFields.fieldInits[i], tabCount + 1);
        }
        return;

    case SystemCall:
        printf("\n");
        PrintNode(node->systemCall.identifier, tabCount + 1);
        PrintNode(node->systemCall.argumentSequence, tabCount + 1);
        PrintNode(node->systemCall.genericArguments, tabCount + 1);
        return;

    case Type:
        printf("\n");
        PrintNode(node->type.typeNode, tabCount + 1);
        return;

    case UnaryExpression:
        PrintUnaryOperator(node->unaryExpression.operator);
        printf("\n");
        PrintNode(node->unaryExpression.child, tabCount + 1);
        return;

    default:
        /* Terminate the line so the next node does not run into this one. */
        printf("\n");
        return;
    }
}

/* ---------------------------------------------------------------------------
 * Generic traversal
 * ------------------------------------------------------------------------- */

/* Applies `func` to `child` unless it is NULL. Optional children (e.g. the
 * not-yet-implemented generic constraint) may be NULL and must not be
 * handed to callbacks that dereference them. */
static void RecurseChild(Node *child, void (*func)(Node *))
{
    if (child != NULL)
    {
        func(child);
    }
}

/* Calls `func` once on each direct, non-NULL child of `node` (not on `node`
 * itself and not recursively). Prints a diagnostic for unhandled kinds. */
void Recurse(Node *node, void (*func)(Node *))
{
    uint32_t i;

    if (node == NULL)
    {
        return;
    }

    switch (node->syntaxKind)
    {
    case AccessExpression:
        RecurseChild(node->accessExpression.accessee, func);
        RecurseChild(node->accessExpression.accessor, func);
        return;

    case AllocExpression:
        RecurseChild(node->allocExpression.type, func);
        return;

    case Assignment:
        RecurseChild(node->assignmentStatement.left, func);
        RecurseChild(node->assignmentStatement.right, func);
        return;

    case BinaryExpression:
        RecurseChild(node->binaryExpression.left, func);
        RecurseChild(node->binaryExpression.right, func);
        return;

    case Comment:
        return;

    case ConcreteGenericTypeNode:
        RecurseChild(node->concreteGenericType.genericArguments, func);
        return;

    case CustomTypeNode:
        return;

    case Declaration:
        RecurseChild(node->declaration.type, func);
        RecurseChild(node->declaration.identifier, func);
        return;

    case DeclarationSequence:
        for (i = 0; i < node->declarationSequence.count; i += 1)
        {
            RecurseChild(node->declarationSequence.sequence[i], func);
        }
        return;

    case FieldInit:
        RecurseChild(node->fieldInit.identifier, func);
        RecurseChild(node->fieldInit.expression, func);
        return;

    case ForLoop:
        RecurseChild(node->forLoop.declaration, func);
        RecurseChild(node->forLoop.startNumber, func);
        RecurseChild(node->forLoop.endNumber, func);
        RecurseChild(node->forLoop.statementSequence, func);
        return;

    case FunctionArgumentSequence:
        for (i = 0; i < node->functionArgumentSequence.count; i += 1)
        {
            RecurseChild(node->functionArgumentSequence.sequence[i], func);
        }
        return;

    case FunctionCallExpression:
        RecurseChild(node->functionCallExpression.identifier, func);
        RecurseChild(node->functionCallExpression.argumentSequence, func);
        RecurseChild(node->functionCallExpression.genericArguments, func);
        return;

    case FunctionDeclaration:
        RecurseChild(node->functionDeclaration.functionSignature, func);
        RecurseChild(node->functionDeclaration.functionBody, func);
        return;

    case FunctionModifiers:
        for (i = 0; i < node->functionModifiers.count; i += 1)
        {
            RecurseChild(node->functionModifiers.sequence[i], func);
        }
        return;

    case FunctionSignature:
        RecurseChild(node->functionSignature.identifier, func);
        RecurseChild(node->functionSignature.type, func);
        RecurseChild(node->functionSignature.arguments, func);
        RecurseChild(node->functionSignature.modifiers, func);
        RecurseChild(node->functionSignature.genericDeclarations, func);
        return;

    case FunctionSignatureArguments:
        for (i = 0; i < node->functionSignatureArguments.count; i += 1)
        {
            RecurseChild(node->functionSignatureArguments.sequence[i], func);
        }
        return;

    case GenericArgument:
        RecurseChild(node->genericArgument.type, func);
        return;

    case GenericArguments:
        for (i = 0; i < node->genericArguments.count; i += 1)
        {
            RecurseChild(node->genericArguments.arguments[i], func);
        }
        return;

    case GenericDeclaration:
        RecurseChild(node->genericDeclaration.identifier, func);
        RecurseChild(node->genericDeclaration.constraint, func);
        return;

    case GenericDeclarations:
        for (i = 0; i < node->genericDeclarations.count; i += 1)
        {
            RecurseChild(node->genericDeclarations.declarations[i], func);
        }
        return;

    case GenericTypeNode:
    case Identifier:
        return;

    case IfStatement:
        RecurseChild(node->ifStatement.expression, func);
        RecurseChild(node->ifStatement.statementSequence, func);
        return;

    case IfElseStatement:
        RecurseChild(node->ifElseStatement.ifStatement, func);
        RecurseChild(node->ifElseStatement.elseStatement, func);
        return;

    case Number:
    case PrimitiveTypeNode:
        return;

    case ReferenceTypeNode:
        RecurseChild(node->referenceType.type, func);
        return;

    case Return:
        RecurseChild(node->returnStatement.expression, func);
        return;

    case ReturnVoid:
        return;

    case StatementSequence:
        for (i = 0; i < node->statementSequence.count; i += 1)
        {
            RecurseChild(node->statementSequence.sequence[i], func);
        }
        return;

    case StaticModifier:
    case StringLiteral:
        return;

    case StructDeclaration:
        RecurseChild(node->structDeclaration.identifier, func);
        RecurseChild(node->structDeclaration.declarationSequence, func);
        return;

    case StructInit:
        RecurseChild(node->structInit.type, func);
        RecurseChild(node->structInit.initFields, func);
        return;

    case StructInitFields:
        for (i = 0; i < node->structInitFields.count; i += 1)
        {
            RecurseChild(node->structInitFields.fieldInits[i], func);
        }
        return;

    case SystemCall:
        RecurseChild(node->systemCall.identifier, func);
        RecurseChild(node->systemCall.argumentSequence, func);
        RecurseChild(node->systemCall.genericArguments, func);
        return;

    case Type:
        RecurseChild(node->type.typeNode, func);
        return;

    case UnaryExpression:
        RecurseChild(node->unaryExpression.child, func);
        return;

    default:
        fprintf(
            stderr,
            "wraith: Unhandled SyntaxKind %s in recurse function.\n",
            SyntaxKindString(node->syntaxKind));
        return;
    }
}

/* ---------------------------------------------------------------------------
 * Type tags
 * ------------------------------------------------------------------------- */

/* Builds a freshly allocated TypeTag describing the type denoted by `node`.
 * Wrapper nodes (Type, Declaration, FunctionDeclaration, AllocExpression)
 * forward to the node that carries the actual type. Returns NULL, after
 * printing a diagnostic, for NULL input or an unsupported SyntaxKind. */
TypeTag *MakeTypeTag(Node *node)
{
    TypeTag *tag;
    uint32_t i;

    if (node == NULL)
    {
        fprintf(
            stderr,
            "wraith: Attempted to call MakeTypeTag on null value.\n");
        return NULL;
    }

    /* Forwarding cases are handled before allocating, so no tag is leaked
     * when the recursive call produces the result. */
    switch (node->syntaxKind)
    {
    case Type:
        return MakeTypeTag(node->type.typeNode);
    case Declaration:
        return MakeTypeTag(node->declaration.type);
    case FunctionDeclaration:
        return MakeTypeTag(node->functionDeclaration.functionSignature
                               ->functionSignature.type);
    case AllocExpression:
        return MakeTypeTag(node->allocExpression.type);
    default:
        break;
    }

    tag = (TypeTag *)CheckAllocation(malloc(sizeof(TypeTag)));

    switch (node->syntaxKind)
    {
    case PrimitiveTypeNode:
        tag->type = Primitive;
        tag->value.primitiveType = node->primitiveType.type;
        break;

    case ReferenceTypeNode:
        tag->type = Reference;
        tag->value.referenceType = MakeTypeTag(node->referenceType.type);
        break;

    case CustomTypeNode:
        tag->type = Custom;
        tag->value.customType = strdup(node->customType.name);
        break;

    case ConcreteGenericTypeNode:
    {
        Node *arguments = node->concreteGenericType.genericArguments;
        uint32_t count = arguments->genericArguments.count;

        tag->type = ConcreteGeneric;
        tag->value.concreteGenericType.name =
            strdup(node->concreteGenericType.name);
        tag->value.concreteGenericType.genericArgumentCount = count;
        tag->value.concreteGenericType.genericArguments = NULL;

        /* malloc(0) may legitimately return NULL, so only allocate when
         * there is something to store. */
        if (count > 0)
        {
            tag->value.concreteGenericType.genericArguments =
                (TypeTag **)CheckAllocation(malloc(sizeof(TypeTag *) * count));
            for (i = 0; i < count; i += 1)
            {
                tag->value.concreteGenericType.genericArguments[i] =
                    MakeTypeTag(arguments->genericArguments.arguments[i]
                                    ->genericArgument.type);
            }
        }
        break;
    }

    case StructDeclaration:
        tag->type = Custom;
        tag->value.customType =
            strdup(node->structDeclaration.identifier->identifier.name);
        break;

    case GenericDeclaration:
        tag->type = Generic;
        tag->value.genericType =
            strdup(node->genericDeclaration.identifier->identifier.name);
        break;

    case GenericTypeNode:
        tag->type = Generic;
        tag->value.genericType = strdup(node->genericType.name);
        break;

    default:
        fprintf(
            stderr,
            "wraith: Attempted to call MakeTypeTag on"
            " node with unsupported SyntaxKind: %s\n",
            SyntaxKindString(node->syntaxKind));
        free(tag);
        return NULL;
    }

    return tag;
}

/* Returns a newly allocated "<prefix><inner>>" string, e.g.
 * WrapTypeString("Ref<", "Int") yields "Ref<Int>". A NULL `inner` (a nested
 * tag that could not be rendered) is shown as "Unknown". */
static char *WrapTypeString(const char *prefix, const char *inner)
{
    const char *shown = inner != NULL ? inner : "Unknown";
    size_t size = strlen(prefix) + strlen(shown) + 2; /* '>' and NUL */
    char *result = (char *)CheckAllocation(malloc(size));

    snprintf(result, size, "%s%s>", prefix, shown);
    return result;
}

/* Renders a ConcreteGeneric tag as "Name<Arg1, Arg2, ...>" into a newly
 * allocated string. Every argument is rendered first so the buffer can be
 * sized exactly, separators included. */
static char *ConcreteGenericTypeTagToString(TypeTag *tag)
{
    uint32_t count = tag->value.concreteGenericType.genericArgumentCount;
    const char *name = tag->value.concreteGenericType.name;
    char **arguments = NULL;
    size_t size = strlen(name) + 3; /* '<', '>' and NUL */
    size_t offset;
    char *result;
    uint32_t i;

    if (count > 0)
    {
        arguments = (char **)CheckAllocation(malloc(sizeof(char *) * count));
    }

    for (i = 0; i < count; i += 1)
    {
        arguments[i] =
            TypeTagToString(tag->value.concreteGenericType.genericArguments[i]);
        if (arguments[i] == NULL)
        {
            arguments[i] = (char *)CheckAllocation(strdup("Unknown"));
        }
        size += strlen(arguments[i]);
        if (i > 0)
        {
            size += 2; /* ", " separator */
        }
    }

    result = (char *)CheckAllocation(malloc(size));
    offset = (size_t)snprintf(result, size, "%s<", name);

    for (i = 0; i < count; i += 1)
    {
        offset += (size_t)snprintf(
            result + offset,
            size - offset,
            "%s%s",
            i > 0 ? ", " : "",
            arguments[i]);
        free(arguments[i]);
    }

    snprintf(result + offset, size - offset, ">");
    free(arguments);
    return result;
}

/* Renders `tag` as a human-readable string, e.g. "Int", "Ref<Int>",
 * "Custom<Foo>", "Generic<T>" or "Pair<Int, Bool>".
 * The result is always heap-allocated and owned by the caller, who should
 * free() it. Returns NULL (after printing a diagnostic) when `tag` is NULL
 * or has an unrecognised type. */
char *TypeTagToString(TypeTag *tag)
{
    if (tag == NULL)
    {
        fprintf(
            stderr,
            "wraith: Attempted to call TypeTagToString with null value\n");
        return NULL;
    }

    switch (tag->type)
    {
    case Unknown:
        return (char *)CheckAllocation(strdup("Unknown"));

    case Primitive:
        return (char *)CheckAllocation(
            strdup(PrimitiveTypeToString(tag->value.primitiveType)));

    case Reference:
    {
        char *inner = TypeTagToString(tag->value.referenceType);
        char *result = WrapTypeString("Ref<", inner);
        free(inner);
        return result;
    }

    case Custom:
        return WrapTypeString("Custom<", tag->value.customType);

    case Generic:
        return WrapTypeString("Generic<", tag->value.genericType);

    case ConcreteGeneric:
        return ConcreteGenericTypeTagToString(tag);
    }

    fprintf(stderr, "wraith: Unhandled TypeTag kind in TypeTagToString\n");
    return NULL;
}

/* Structural equality of two type tags. Returns 1 when they denote the same
 * type and 0 otherwise. A NULL tag never compares equal. Tags of kind
 * Unknown are reported as an invalid comparison and compare unequal. */
uint8_t TypeTagEqual(TypeTag *typeTagA, TypeTag *typeTagB)
{
    uint32_t i;

    if (typeTagA == NULL || typeTagB == NULL)
    {
        return 0;
    }

    if (typeTagA->type != typeTagB->type)
    {
        return 0;
    }

    switch (typeTagA->type)
    {
    case Primitive:
        return typeTagA->value.primitiveType == typeTagB->value.primitiveType;

    case Reference:
        return TypeTagEqual(
            typeTagA->value.referenceType,
            typeTagB->value.referenceType);

    case Custom:
        return strcmp(typeTagA->value.customType, typeTagB->value.customType) ==
               0;

    case Generic:
        return strcmp(
                   typeTagA->value.genericType,
                   typeTagB->value.genericType) == 0;

    case ConcreteGeneric:
        if (strcmp(
                typeTagA->value.concreteGenericType.name,
                typeTagB->value.concreteGenericType.name) != 0 ||
            typeTagA->value.concreteGenericType.genericArgumentCount !=
                typeTagB->value.concreteGenericType.genericArgumentCount)
        {
            return 0;
        }
        for (i = 0;
             i < typeTagA->value.concreteGenericType.genericArgumentCount;
             i += 1)
        {
            if (!TypeTagEqual(
                    typeTagA->value.concreteGenericType.genericArguments[i],
                    typeTagB->value.concreteGenericType.genericArguments[i]))
            {
                return 0;
            }
        }
        return 1;

    default:
        fprintf(stderr, "wraith: Invalid type comparison!\n");
        return 0;
    }
}

/* ---------------------------------------------------------------------------
 * Parent links
 * ------------------------------------------------------------------------- */

/* Sets `node->parent = prev` and recursively links every descendant to its
 * parent. NULL nodes (absent optional children) are ignored. */
void LinkParentPointers(Node *node, Node *prev)
{
    uint32_t i;

    if (node == NULL)
    {
        return;
    }

    node->parent = prev;

    switch (node->syntaxKind)
    {
    case AccessExpression:
        LinkParentPointers(node->accessExpression.accessee, node);
        LinkParentPointers(node->accessExpression.accessor, node);
        return;

    case AllocExpression:
        LinkParentPointers(node->allocExpression.type, node);
        return;

    case Assignment:
        LinkParentPointers(node->assignmentStatement.left, node);
        LinkParentPointers(node->assignmentStatement.right, node);
        return;

    case BinaryExpression:
        LinkParentPointers(node->binaryExpression.left, node);
        LinkParentPointers(node->binaryExpression.right, node);
        return;

    case Comment:
        return;

    case ConcreteGenericTypeNode:
        LinkParentPointers(node->concreteGenericType.genericArguments, node);
        return;

    case CustomTypeNode:
        return;

    case Declaration:
        LinkParentPointers(node->declaration.type, node);
        LinkParentPointers(node->declaration.identifier, node);
        return;

    case DeclarationSequence:
        for (i = 0; i < node->declarationSequence.count; i += 1)
        {
            LinkParentPointers(node->declarationSequence.sequence[i], node);
        }
        return;

    case FieldInit:
        LinkParentPointers(node->fieldInit.identifier, node);
        LinkParentPointers(node->fieldInit.expression, node);
        return;

    case ForLoop:
        LinkParentPointers(node->forLoop.declaration, node);
        LinkParentPointers(node->forLoop.startNumber, node);
        LinkParentPointers(node->forLoop.endNumber, node);
        LinkParentPointers(node->forLoop.statementSequence, node);
        return;

    case FunctionArgumentSequence:
        for (i = 0; i < node->functionArgumentSequence.count; i += 1)
        {
            LinkParentPointers(
                node->functionArgumentSequence.sequence[i],
                node);
        }
        return;

    case FunctionCallExpression:
        LinkParentPointers(node->functionCallExpression.identifier, node);
        LinkParentPointers(node->functionCallExpression.argumentSequence, node);
        LinkParentPointers(node->functionCallExpression.genericArguments, node);
        return;

    case FunctionDeclaration:
        LinkParentPointers(node->functionDeclaration.functionSignature, node);
        LinkParentPointers(node->functionDeclaration.functionBody, node);
        return;

    case FunctionModifiers:
        for (i = 0; i < node->functionModifiers.count; i += 1)
        {
            LinkParentPointers(node->functionModifiers.sequence[i], node);
        }
        return;

    case FunctionSignature:
        LinkParentPointers(node->functionSignature.identifier, node);
        LinkParentPointers(node->functionSignature.type, node);
        LinkParentPointers(node->functionSignature.arguments, node);
        LinkParentPointers(node->functionSignature.modifiers, node);
        LinkParentPointers(node->functionSignature.genericDeclarations, node);
        return;

    case FunctionSignatureArguments:
        for (i = 0; i < node->functionSignatureArguments.count; i += 1)
        {
            LinkParentPointers(
                node->functionSignatureArguments.sequence[i],
                node);
        }
        return;

    case GenericArgument:
        LinkParentPointers(node->genericArgument.type, node);
        return;

    case GenericArguments:
        for (i = 0; i < node->genericArguments.count; i += 1)
        {
            LinkParentPointers(node->genericArguments.arguments[i], node);
        }
        return;

    case GenericDeclaration:
        LinkParentPointers(node->genericDeclaration.identifier, node);
        LinkParentPointers(node->genericDeclaration.constraint, node);
        return;

    case GenericDeclarations:
        for (i = 0; i < node->genericDeclarations.count; i += 1)
        {
            LinkParentPointers(node->genericDeclarations.declarations[i], node);
        }
        return;

    case GenericTypeNode:
    case Identifier:
        return;

    case IfStatement:
        LinkParentPointers(node->ifStatement.expression, node);
        LinkParentPointers(node->ifStatement.statementSequence, node);
        return;

    case IfElseStatement:
        LinkParentPointers(node->ifElseStatement.ifStatement, node);
        LinkParentPointers(node->ifElseStatement.elseStatement, node);
        return;

    case Number:
    case PrimitiveTypeNode:
        return;

    case ReferenceTypeNode:
        LinkParentPointers(node->referenceType.type, node);
        return;

    case Return:
        LinkParentPointers(node->returnStatement.expression, node);
        return;

    case ReturnVoid:
        return;

    case StatementSequence:
        for (i = 0; i < node->statementSequence.count; i += 1)
        {
            LinkParentPointers(node->statementSequence.sequence[i], node);
        }
        return;

    case StaticModifier:
    case StringLiteral:
        return;

    case StructDeclaration:
        LinkParentPointers(node->structDeclaration.identifier, node);
        LinkParentPointers(node->structDeclaration.genericDeclarations, node);
        LinkParentPointers(node->structDeclaration.declarationSequence, node);
        return;

    case StructInit:
        LinkParentPointers(node->structInit.type, node);
        LinkParentPointers(node->structInit.initFields, node);
        return;

    case StructInitFields:
        for (i = 0; i < node->structInitFields.count; i += 1)
        {
            LinkParentPointers(node->structInitFields.fieldInits[i], node);
        }
        return;

    case SystemCall:
        LinkParentPointers(node->systemCall.identifier, node);
        LinkParentPointers(node->systemCall.argumentSequence, node);
        LinkParentPointers(node->systemCall.genericArguments, node);
        return;

    case Type:
        /* Link the wrapped type too, so nodes inside types are not left with
         * unset parent pointers. */
        LinkParentPointers(node->type.typeNode, node);
        return;

    case UnaryExpression:
        LinkParentPointers(node->unaryExpression.child, node);
        return;

    default:
        fprintf(
            stderr,
            "wraith: Unhandled SyntaxKind %s in LinkParentPointers.\n",
            SyntaxKindString(node->syntaxKind));
        return;
    }
}