diff --git a/generators/wraith.y b/generators/wraith.y index 7aaf70a..3176d60 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE $$ = $2; } -FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type +GenericArgument : Identifier { - $$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0)); + $$ = MakeGenericArgumentNode($1, NULL); } - | STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + +GenericArguments : GenericArgument + { + $$ = StartGenericArgumentsNode($1); + } + | GenericArguments COMMA GenericArgument + { + $$ = AddGenericArgument($1, $3); + } + +GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN + { + $$ = $2; + } + | + { + $$ = MakeEmptyGenericArgumentsNode(); + } + + +FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type + { + $$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2); + } + | STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type { Node *modifier = MakeStaticNode(); - $$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1)); + $$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3); } FunctionDeclaration : FunctionSignature Body diff --git a/src/ast.c b/src/ast.c index 74ee4d9..4c578e9 100644 --- a/src/ast.c +++ b/src/ast.c @@ -271,7 +271,8 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *arguments, - Node *modifiersNode) + Node *modifiersNode, + Node *genericArgumentsNode) { Node *node = (Node *)malloc(sizeof(Node)); node->syntaxKind = FunctionSignature; @@ -279,6 +280,7 @@ Node *MakeFunctionSignatureNode( node->functionSignature.type = typeNode; node->functionSignature.arguments = arguments; node->functionSignature.modifiers = modifiersNode; + node->functionSignature.genericArguments = genericArgumentsNode; return node; } @@ -359,6 +361,46 @@ Node *MakeEmptyFunctionArgumentSequenceNode() return node; } +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArgument; + node->genericArgument.identifier = identifierNode; + node->genericArgument.constraint = constraintNode; + return node; +} + +Node *StartGenericArgumentsNode(Node *genericArgumentNode) +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = (Node **)malloc(sizeof(Node *)); + node->genericArguments.arguments[0] = genericArgumentNode; + node->genericArguments.count = 1; + return node; +} + +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode) +{ + genericArgumentsNode->genericArguments.arguments = (Node **)realloc( + genericArgumentsNode->genericArguments.arguments, + sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1)); + genericArgumentsNode->genericArguments + .arguments[genericArgumentsNode->genericArguments.count] = + genericArgumentNode; + genericArgumentsNode->genericArguments.count += 1; + return genericArgumentsNode; +} + +Node *MakeEmptyGenericArgumentsNode() +{ + Node *node = (Node *)malloc(sizeof(Node)); + node->syntaxKind = GenericArguments; + node->genericArguments.arguments = NULL; + node->genericArguments.count = 0; + return node; +} + Node *MakeFunctionCallExpressionNode( Node *identifierNode, Node *argumentSequenceNode) diff --git a/src/ast.h b/src/ast.h index 60e954d..a4b367c 100644 --- a/src/ast.h +++ b/src/ast.h @@ -30,6 +30,8 @@ typedef enum FunctionModifiers, FunctionSignature, FunctionSignatureArguments, + GenericArgument, + GenericArguments, Identifier, IfStatement, IfElseStatement, @@ -192,6 +194,7 @@ struct Node Node *type; Node *arguments; Node *modifiers; + Node *genericArguments; } functionSignature; struct @@ -200,6 +203,18 @@ struct Node uint32_t count; } functionSignatureArguments; + struct + { + Node *identifier; + Node *constraint; + } genericArgument; + + struct + { + Node **arguments; + uint32_t count; + } genericArguments; + struct { char *name; @@ -306,10 +321,15 @@ Node *MakeFunctionSignatureNode( Node *identifierNode, Node *typeNode, Node *argumentsNode, - Node *modifiersNode); + Node *modifiersNode, + Node *genericArgumentsNode); Node *MakeFunctionDeclarationNode( Node *functionSignatureNode, Node *functionBodyNode); +Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode); +Node *MakeEmptyGenericArgumentsNode(); +Node *StartGenericArgumentsNode(Node *genericArgumentNode); +Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode); Node *MakeStructDeclarationNode( Node *identifierNode, Node *declarationSequenceNode); diff --git a/src/codegen.c b/src/codegen.c index e8a5fb9..c24eb2e 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -56,6 +56,24 @@ typedef struct StructTypeFunction uint8_t isStatic; } StructTypeFunction; +typedef struct StructTypeGenericFunction +{ + char *name; + Node *functionDeclarationNode; +} StructTypeGenericFunction; + +typedef struct MonomorphizedGenericFunctionHashEntry +{ + uint64_t key; + StructTypeFunction function; +} MonomorphizedGenericFunctionHashEntry; + +typedef struct MonomorphizedGenericFunctionHashArray +{ + MonomorphizedGenericFunctionHashEntry *elements; + uint32_t count; +} MonomorphizedGenericFunctionHashArray; + typedef struct StructTypeDeclaration { char *name; @@ -66,6 +84,11 @@ typedef struct StructTypeDeclaration StructTypeFunction *functions; uint32_t functionCount; + + StructTypeGenericFunction *genericFunctions; + uint32_t genericFunctionCount; + + MonomorphizedGenericFunctionHashArray monomorphizedGenericFunctions; } StructTypeDeclaration; StructTypeDeclaration *structTypeDeclarations; @@ -271,6 +294,10 @@ static void AddStructDeclaration( structTypeDeclarations[index].fieldCount = 0; structTypeDeclarations[index].functions = NULL; structTypeDeclarations[index].functionCount = 0; + structTypeDeclarations[index].genericFunctions = NULL; + structTypeDeclarations[index].genericFunctionCount = 0; + structTypeDeclarations[index].monomorphizedGenericFunctions.elements = NULL; + structTypeDeclarations[index].monomorphizedGenericFunctions.count = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { @@ -287,6 +314,7 @@ static void AddStructDeclaration( structTypeDeclarationCount += 1; } +/* FIXME: pass the declaration itself */ static void DeclareStructFunction( LLVMTypeRef wStructPointerType, LLVMValueRef function, @@ -318,6 +346,31 @@ static void DeclareStructFunction( fprintf(stderr, "Could not find struct type for function!\n"); } +/* FIXME: pass the declaration itself */ +static void DeclareGenericStructFunction( + LLVMTypeRef wStructPointerType, + Node *functionDeclarationNode, + char *name) +{ + uint32_t i, index; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == wStructPointerType) + { + index = structTypeDeclarations[i].genericFunctionCount; + structTypeDeclarations[i].genericFunctions[index].name = + strdup(name); + structTypeDeclarations[i] + .genericFunctions[index] + .functionDeclarationNode = functionDeclarationNode; + structTypeDeclarations[i].genericFunctionCount += 1; + + return; + } + } +} + static LLVMTypeRef LookupCustomType(char *name) { uint32_t i; @@ -1023,101 +1076,115 @@ static void CompileFunction( } } - if (!isStatic) - { - paramTypes[paramIndex] = wStructPointerType; - paramIndex += 1; - } - - PushScopeFrame(scope); - - /* FIXME: should work for non-primitive types */ - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - paramTypes[paramIndex] = - ResolveType(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.type); - paramIndex += 1; - } - - LLVMTypeRef returnType = - ResolveType(functionSignature->functionSignature.type); - LLVMTypeRef functionType = - LLVMFunctionType(returnType, paramTypes, paramIndex, 0); - char *functionName = strdup(parentStructName); strcat(functionName, "_"); strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); - LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); - free(functionName); - DeclareStructFunction( - wStructPointerType, - function, - returnType, - isStatic, - functionSignature->functionSignature.identifier->identifier.name); - - LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); - LLVMBuilderRef builder = LLVMCreateBuilder(); - LLVMPositionBuilderAtEnd(builder, entry); - - if (!isStatic) + if (functionSignature->functionSignature.genericArguments->genericArguments + .count == 0) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); - AddStructVariablesToScope(builder, wStructPointer); - } + PushScopeFrame(scope); - for (i = 0; i < functionSignature->functionSignature.arguments - ->functionSignatureArguments.count; - i += 1) - { - char *ptrName = strdup(functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); - LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); - LLVMValueRef argumentCopy = - LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); - LLVMBuildStore(builder, argument, argumentCopy); - free(ptrName); - AddLocalVariable( - scope, - argumentCopy, - NULL, - functionSignature->functionSignature.arguments - ->functionSignatureArguments.sequence[i] - ->declaration.identifier->identifier.name); - } + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } - for (i = 0; i < functionBody->statementSequence.count; i += 1) - { - CompileStatement( - builder, + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.type); + paramIndex += 1; + } + + LLVMTypeRef returnType = + ResolveType(functionSignature->functionSignature.type); + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = + LLVMAddFunction(module, functionName, functionType); + + DeclareStructFunction( + wStructPointerType, function, - functionBody->statementSequence.sequence[i]); + returnType, + isStatic, + functionSignature->functionSignature.identifier->identifier.name); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = + strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = LLVMGetBasicBlockTerminator( + LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + + PopScopeFrame(scope); } - - hasReturn = - LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; - - if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + else { - LLVMBuildRetVoid(builder); - } - else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) - { - fprintf(stderr, "Return statement not provided!"); + DeclareGenericStructFunction( + wStructPointerType, + functionDeclaration, + functionName); } - PopScopeFrame(scope); - - LLVMDisposeBuilder(builder); + free(functionName); } static void CompileStruct(