diff --git a/src/ast.c b/src/ast.c index e705b5b..9604dbf 100644 --- a/src/ast.c +++ b/src/ast.c @@ -726,6 +726,165 @@ void PrintNode(Node *node, uint32_t tabCount) } } +void Recurse(Node *node, void (*func)(Node*)) +{ + uint32_t i; + switch (node->syntaxKind) + { + case AccessExpression: + func(node->accessExpression.accessee); + func(node->accessExpression.accessor); + return; + + case AllocExpression: + func(node->allocExpression.type); + return; + + case Assignment: + func(node->assignmentStatement.left); + func(node->assignmentStatement.right); + return; + + case BinaryExpression: + func(node->binaryExpression.left); + func(node->binaryExpression.right); + return; + + case Comment: + return; + + case CustomTypeNode: + return; + + case Declaration: + func(node->declaration.type); + func(node->declaration.identifier); + return; + + case DeclarationSequence: + for (i = 0; i < node->declarationSequence.count; i += 1) { + func(node->declarationSequence.sequence[i]); + } + return; + + case ForLoop: + func(node->forLoop.declaration); + func(node->forLoop.startNumber); + func(node->forLoop.endNumber); + func(node->forLoop.statementSequence); + return; + + case FunctionArgumentSequence: + for (i = 0; i < node->functionArgumentSequence.count; i += 1) { + func(node->functionArgumentSequence.sequence[i]); + } + return; + + case FunctionCallExpression: + func(node->functionCallExpression.identifier); + func(node->functionCallExpression.argumentSequence); + return; + + case FunctionDeclaration: + func(node->functionDeclaration.functionSignature); + func(node->functionDeclaration.functionBody); + return; + + case FunctionModifiers: + for (i = 0; i < node->functionModifiers.count; i += 1) { + func(node->functionModifiers.sequence[i]); + } + return; + + case FunctionSignature: + func(node->functionSignature.identifier); + func(node->functionSignature.type); + func(node->functionSignature.arguments); + func(node->functionSignature.modifiers); + func(node->functionSignature.genericArguments); + return; + + case FunctionSignatureArguments: + for (i = 0; i < node->functionSignatureArguments.count; i += 1) { + func(node->functionSignatureArguments.sequence[i]); + } + return; + + case GenericArgument: + func(node->genericArgument.identifier); + func(node->genericArgument.constraint); + return; + + case GenericArguments: + for (i = 0; i < node->genericArguments.count; i += 1) { + func(node->genericArguments.arguments[i]); + } + return; + + case GenericTypeNode: + return; + + case Identifier: + return; + + case IfStatement: + func(node->ifStatement.expression); + func(node->ifStatement.statementSequence); + return; + + case IfElseStatement: + func(node->ifElseStatement.ifStatement); + func(node->ifElseStatement.elseStatement); + return; + + case Number: + return; + + case PrimitiveTypeNode: + return; + + case ReferenceTypeNode: + func(node->referenceType.type); + return; + + case Return: + func(node->returnStatement.expression); + return; + + case ReturnVoid: + return; + + case StatementSequence: + for (i = 0; i < node->statementSequence.count; i += 1) { + func(node->statementSequence.sequence[i]); + } + return; + + case StaticModifier: + return; + + case StringLiteral: + return; + + case StructDeclaration: + func(node->structDeclaration.identifier); + func(node->structDeclaration.declarationSequence); + return; + + case Type: + return; + + case UnaryExpression: + func(node->unaryExpression.child); + return; + + default: + fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n", + SyntaxKindString(node->syntaxKind)); + return; + } +} + TypeTag *MakeTypeTag(Node *node) { if (node == NULL) diff --git a/src/ast.h b/src/ast.h index eb3cc10..8ad54a0 100644 --- a/src/ast.h +++ b/src/ast.h @@ -367,6 +367,12 @@ Node *MakeForLoopNode( void PrintNode(Node *node, uint32_t tabCount); const char *SyntaxKindString(SyntaxKind syntaxKind); +/* Helper function for applying a void function generically over the children of an AST node. + * Used for functions that need to traverse the entire tree but only perform operations on a subset + * of node types. Such functions can match the syntaxKinds relevant to their purpose and invoke this + * function in all other cases. */ +void Recurse(Node *node, void (*func)(Node*)); + TypeTag *MakeTypeTag(Node *node); char *TypeTagToString(TypeTag *tag); diff --git a/src/typeutils.c b/src/typeutils.c index 1f20be8..f86762c 100644 --- a/src/typeutils.c +++ b/src/typeutils.c @@ -51,152 +51,32 @@ void ConvertIdCustomsToGenerics(IdNode *node) { void ConvertASTCustomsToGenerics(Node *node) { uint32_t i; - switch (node->syntaxKind) { - case AccessExpression: - ConvertASTCustomsToGenerics(node->accessExpression.accessee); - ConvertASTCustomsToGenerics(node->accessExpression.accessor); - return; - - case AllocExpression: - ConvertASTCustomsToGenerics(node->allocExpression.type); - return; - - case Assignment: - ConvertASTCustomsToGenerics(node->assignmentStatement.left); - ConvertASTCustomsToGenerics(node->assignmentStatement.right); - return; - - case BinaryExpression: - ConvertASTCustomsToGenerics(node->binaryExpression.left); - ConvertASTCustomsToGenerics(node->binaryExpression.right); - return; - - case Comment: - return; - - case CustomTypeNode: - return; - - case Declaration: { - Node *type = node->declaration.type->type.typeNode; - Node *id = node->declaration.identifier; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->declaration.type); - node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - return; + switch (node->syntaxKind) + { + case Declaration: + { + Node *type = node->declaration.type->type.typeNode; + Node *id = node->declaration.identifier; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->declaration.type); + node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType); } + return; + } - case DeclarationSequence: - for (i = 0; i < node->declarationSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->declarationSequence.sequence[i]); - } - return; - - case ForLoop: - ConvertASTCustomsToGenerics(node->forLoop.declaration); - ConvertASTCustomsToGenerics(node->forLoop.startNumber); - ConvertASTCustomsToGenerics(node->forLoop.endNumber); - ConvertASTCustomsToGenerics(node->forLoop.statementSequence); - return; - - case FunctionArgumentSequence: - for (i = 0; i < node->functionArgumentSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->functionArgumentSequence.sequence[i]); - } - return; - - case FunctionCallExpression: - ConvertASTCustomsToGenerics(node->functionCallExpression.identifier); - ConvertASTCustomsToGenerics(node->functionCallExpression.argumentSequence); - return; - - case FunctionDeclaration: - ConvertASTCustomsToGenerics(node->functionDeclaration.functionSignature); - ConvertASTCustomsToGenerics(node->functionDeclaration.functionBody); - return; - - case FunctionModifiers: - return; - - case FunctionSignature:{ - Node *id = node->functionSignature.identifier; - Node *type = node->functionSignature.type; - if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { - free(node->functionSignature.type); - node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); - } - ConvertASTCustomsToGenerics(node->functionSignature.arguments); - return; + case FunctionSignature: + { + Node *id = node->functionSignature.identifier; + Node *type = node->functionSignature.type; + if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) { + free(node->functionSignature.type); + node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType); } + ConvertASTCustomsToGenerics(node->functionSignature.arguments); + return; + } - case FunctionSignatureArguments: - for (i = 0; i < node->functionSignatureArguments.count; i += 1) { - ConvertASTCustomsToGenerics(node->functionSignatureArguments.sequence[i]); - } - return; - - case GenericArgument: - return; - - case GenericArguments: - return; - - case GenericTypeNode: - return; - - case Identifier: - return; - - case IfStatement: - ConvertASTCustomsToGenerics(node->ifStatement.expression); - ConvertASTCustomsToGenerics(node->ifStatement.statementSequence); - return; - - case IfElseStatement: - ConvertASTCustomsToGenerics(node->ifElseStatement.ifStatement); - ConvertASTCustomsToGenerics(node->ifElseStatement.elseStatement); - return; - - case Number: - return; - - case PrimitiveTypeNode: - return; - - case ReferenceTypeNode: - return; - - case Return: - ConvertASTCustomsToGenerics(node->returnStatement.expression); - return; - - case ReturnVoid: - return; - - case StatementSequence: - for (i = 0; i < node->statementSequence.count; i += 1) { - ConvertASTCustomsToGenerics(node->statementSequence.sequence[i]); - } - return; - - case StaticModifier: - return; - - case StringLiteral: - return; - - case StructDeclaration: - /* FIXME: This case will need to be modified to handle type parameters over structs. */ - ConvertASTCustomsToGenerics(node->structDeclaration.identifier); - ConvertASTCustomsToGenerics(node->structDeclaration.declarationSequence); - return; - - case Type: - return; - - case UnaryExpression: - ConvertASTCustomsToGenerics(node->unaryExpression.child); - return; + default: + recurse(node, *ConvertASTCustomsToGenerics); } }