initial generics stuff

pull/4/head
cosmonaut 2021-05-19 15:45:07 -07:00
parent 473b706ad9
commit 24bcef6d87
4 changed files with 240 additions and 87 deletions

View File

@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE
$$ = $2; $$ = $2;
} }
FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type GenericArgument : Identifier
{ {
$$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0)); $$ = MakeGenericArgumentNode($1, NULL);
} }
| STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
GenericArguments : GenericArgument
{
$$ = StartGenericArgumentsNode($1);
}
| GenericArguments COMMA GenericArgument
{
$$ = AddGenericArgument($1, $3);
}
GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN
{
$$ = $2;
}
|
{
$$ = MakeEmptyGenericArgumentsNode();
}
FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{
$$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2);
}
| STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{ {
Node *modifier = MakeStaticNode(); Node *modifier = MakeStaticNode();
$$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1)); $$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3);
} }
FunctionDeclaration : FunctionSignature Body FunctionDeclaration : FunctionSignature Body

View File

@ -271,7 +271,8 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode, Node *identifierNode,
Node *typeNode, Node *typeNode,
Node *arguments, Node *arguments,
Node *modifiersNode) Node *modifiersNode,
Node *genericArgumentsNode)
{ {
Node *node = (Node *)malloc(sizeof(Node)); Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = FunctionSignature; node->syntaxKind = FunctionSignature;
@ -279,6 +280,7 @@ Node *MakeFunctionSignatureNode(
node->functionSignature.type = typeNode; node->functionSignature.type = typeNode;
node->functionSignature.arguments = arguments; node->functionSignature.arguments = arguments;
node->functionSignature.modifiers = modifiersNode; node->functionSignature.modifiers = modifiersNode;
node->functionSignature.genericArguments = genericArgumentsNode;
return node; return node;
} }
@ -359,6 +361,46 @@ Node *MakeEmptyFunctionArgumentSequenceNode()
return node; return node;
} }
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArgument;
node->genericArgument.identifier = identifierNode;
node->genericArgument.constraint = constraintNode;
return node;
}
Node *StartGenericArgumentsNode(Node *genericArgumentNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = (Node **)malloc(sizeof(Node *));
node->genericArguments.arguments[0] = genericArgumentNode;
node->genericArguments.count = 1;
return node;
}
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode)
{
genericArgumentsNode->genericArguments.arguments = (Node **)realloc(
genericArgumentsNode->genericArguments.arguments,
sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1));
genericArgumentsNode->genericArguments
.arguments[genericArgumentsNode->genericArguments.count] =
genericArgumentNode;
genericArgumentsNode->genericArguments.count += 1;
return genericArgumentsNode;
}
Node *MakeEmptyGenericArgumentsNode()
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = NULL;
node->genericArguments.count = 0;
return node;
}
Node *MakeFunctionCallExpressionNode( Node *MakeFunctionCallExpressionNode(
Node *identifierNode, Node *identifierNode,
Node *argumentSequenceNode) Node *argumentSequenceNode)

View File

@ -30,6 +30,8 @@ typedef enum
FunctionModifiers, FunctionModifiers,
FunctionSignature, FunctionSignature,
FunctionSignatureArguments, FunctionSignatureArguments,
GenericArgument,
GenericArguments,
Identifier, Identifier,
IfStatement, IfStatement,
IfElseStatement, IfElseStatement,
@ -192,6 +194,7 @@ struct Node
Node *type; Node *type;
Node *arguments; Node *arguments;
Node *modifiers; Node *modifiers;
Node *genericArguments;
} functionSignature; } functionSignature;
struct struct
@ -200,6 +203,18 @@ struct Node
uint32_t count; uint32_t count;
} functionSignatureArguments; } functionSignatureArguments;
struct
{
Node *identifier;
Node *constraint;
} genericArgument;
struct
{
Node **arguments;
uint32_t count;
} genericArguments;
struct struct
{ {
char *name; char *name;
@ -306,10 +321,15 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode, Node *identifierNode,
Node *typeNode, Node *typeNode,
Node *argumentsNode, Node *argumentsNode,
Node *modifiersNode); Node *modifiersNode,
Node *genericArgumentsNode);
Node *MakeFunctionDeclarationNode( Node *MakeFunctionDeclarationNode(
Node *functionSignatureNode, Node *functionSignatureNode,
Node *functionBodyNode); Node *functionBodyNode);
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode);
Node *MakeEmptyGenericArgumentsNode();
Node *StartGenericArgumentsNode(Node *genericArgumentNode);
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode);
Node *MakeStructDeclarationNode( Node *MakeStructDeclarationNode(
Node *identifierNode, Node *identifierNode,
Node *declarationSequenceNode); Node *declarationSequenceNode);

View File

@ -56,6 +56,24 @@ typedef struct StructTypeFunction
uint8_t isStatic; uint8_t isStatic;
} StructTypeFunction; } StructTypeFunction;
typedef struct StructTypeGenericFunction
{
char *name;
Node *functionDeclarationNode;
} StructTypeGenericFunction;
typedef struct MonomorphizedGenericFunctionHashEntry
{
uint64_t key;
StructTypeFunction function;
} MonomorphizedGenericFunctionHashEntry;
typedef struct MonomorphizedGenericFunctionHashArray
{
MonomorphizedGenericFunctionHashEntry *elements;
uint32_t count;
} MonomorphizedGenericFunctionHashArray;
typedef struct StructTypeDeclaration typedef struct StructTypeDeclaration
{ {
char *name; char *name;
@ -66,6 +84,11 @@ typedef struct StructTypeDeclaration
StructTypeFunction *functions; StructTypeFunction *functions;
uint32_t functionCount; uint32_t functionCount;
StructTypeGenericFunction *genericFunctions;
uint32_t genericFunctionCount;
MonomorphizedGenericFunctionHashArray monomorphizedGenericFunctions;
} StructTypeDeclaration; } StructTypeDeclaration;
StructTypeDeclaration *structTypeDeclarations; StructTypeDeclaration *structTypeDeclarations;
@ -271,6 +294,10 @@ static void AddStructDeclaration(
structTypeDeclarations[index].fieldCount = 0; structTypeDeclarations[index].fieldCount = 0;
structTypeDeclarations[index].functions = NULL; structTypeDeclarations[index].functions = NULL;
structTypeDeclarations[index].functionCount = 0; structTypeDeclarations[index].functionCount = 0;
structTypeDeclarations[index].genericFunctions = NULL;
structTypeDeclarations[index].genericFunctionCount = 0;
structTypeDeclarations[index].monomorphizedGenericFunctions.elements = NULL;
structTypeDeclarations[index].monomorphizedGenericFunctions.count = 0;
for (i = 0; i < fieldDeclarationCount; i += 1) for (i = 0; i < fieldDeclarationCount; i += 1)
{ {
@ -287,6 +314,7 @@ static void AddStructDeclaration(
structTypeDeclarationCount += 1; structTypeDeclarationCount += 1;
} }
/* FIXME: pass the declaration itself */
static void DeclareStructFunction( static void DeclareStructFunction(
LLVMTypeRef wStructPointerType, LLVMTypeRef wStructPointerType,
LLVMValueRef function, LLVMValueRef function,
@ -318,6 +346,31 @@ static void DeclareStructFunction(
fprintf(stderr, "Could not find struct type for function!\n"); fprintf(stderr, "Could not find struct type for function!\n");
} }
/* FIXME: pass the declaration itself */
static void DeclareGenericStructFunction(
LLVMTypeRef wStructPointerType,
Node *functionDeclarationNode,
char *name)
{
uint32_t i, index;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType == wStructPointerType)
{
index = structTypeDeclarations[i].genericFunctionCount;
structTypeDeclarations[i].genericFunctions[index].name =
strdup(name);
structTypeDeclarations[i]
.genericFunctions[index]
.functionDeclarationNode = functionDeclarationNode;
structTypeDeclarations[i].genericFunctionCount += 1;
return;
}
}
}
static LLVMTypeRef LookupCustomType(char *name) static LLVMTypeRef LookupCustomType(char *name)
{ {
uint32_t i; uint32_t i;
@ -1023,101 +1076,115 @@ static void CompileFunction(
} }
} }
if (!isStatic)
{
paramTypes[paramIndex] = wStructPointerType;
paramIndex += 1;
}
PushScopeFrame(scope);
/* FIXME: should work for non-primitive types */
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.type);
paramIndex += 1;
}
LLVMTypeRef returnType =
ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
char *functionName = strdup(parentStructName); char *functionName = strdup(parentStructName);
strcat(functionName, "_"); strcat(functionName, "_");
strcat( strcat(
functionName, functionName,
functionSignature->functionSignature.identifier->identifier.name); functionSignature->functionSignature.identifier->identifier.name);
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
free(functionName);
DeclareStructFunction( if (functionSignature->functionSignature.genericArguments->genericArguments
wStructPointerType, .count == 0)
function,
returnType,
isStatic,
functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
if (!isStatic)
{ {
LLVMValueRef wStructPointer = LLVMGetParam(function, 0); PushScopeFrame(scope);
AddStructVariablesToScope(builder, wStructPointer);
}
for (i = 0; i < functionSignature->functionSignature.arguments if (!isStatic)
->functionSignatureArguments.count; {
i += 1) paramTypes[paramIndex] = wStructPointerType;
{ paramIndex += 1;
char *ptrName = strdup(functionSignature->functionSignature.arguments }
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy =
LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(
scope,
argumentCopy,
NULL,
functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
}
for (i = 0; i < functionBody->statementSequence.count; i += 1) for (i = 0; i < functionSignature->functionSignature.arguments
{ ->functionSignatureArguments.count;
CompileStatement( i += 1)
builder, {
paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.type);
paramIndex += 1;
}
LLVMTypeRef returnType =
ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
LLVMValueRef function =
LLVMAddFunction(module, functionName, functionType);
DeclareStructFunction(
wStructPointerType,
function, function,
functionBody->statementSequence.sequence[i]); returnType,
isStatic,
functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
if (!isStatic)
{
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
AddStructVariablesToScope(builder, wStructPointer);
}
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
char *ptrName =
strdup(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy =
LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(
scope,
argumentCopy,
NULL,
functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
}
for (i = 0; i < functionBody->statementSequence.count; i += 1)
{
CompileStatement(
builder,
function,
functionBody->statementSequence.sequence[i]);
}
hasReturn = LLVMGetBasicBlockTerminator(
LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
}
LLVMDisposeBuilder(builder);
PopScopeFrame(scope);
} }
else
hasReturn =
LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{ {
LLVMBuildRetVoid(builder); DeclareGenericStructFunction(
} wStructPointerType,
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) functionDeclaration,
{ functionName);
fprintf(stderr, "Return statement not provided!");
} }
PopScopeFrame(scope); free(functionName);
LLVMDisposeBuilder(builder);
} }
static void CompileStruct( static void CompileStruct(