initial generics stuff

pull/4/head
cosmonaut 2021-05-19 15:45:07 -07:00
parent 473b706ad9
commit 24bcef6d87
4 changed files with 240 additions and 87 deletions

View File

@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE
$$ = $2;
}
FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
GenericArgument : Identifier
{
$$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0));
$$ = MakeGenericArgumentNode($1, NULL);
}
| STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
GenericArguments : GenericArgument
{
$$ = StartGenericArgumentsNode($1);
}
| GenericArguments COMMA GenericArgument
{
$$ = AddGenericArgument($1, $3);
}
GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN
{
$$ = $2;
}
|
{
$$ = MakeEmptyGenericArgumentsNode();
}
FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{
$$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2);
}
| STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{
Node *modifier = MakeStaticNode();
$$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1));
$$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3);
}
FunctionDeclaration : FunctionSignature Body

View File

@ -271,7 +271,8 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode,
Node *typeNode,
Node *arguments,
Node *modifiersNode)
Node *modifiersNode,
Node *genericArgumentsNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = FunctionSignature;
@ -279,6 +280,7 @@ Node *MakeFunctionSignatureNode(
node->functionSignature.type = typeNode;
node->functionSignature.arguments = arguments;
node->functionSignature.modifiers = modifiersNode;
node->functionSignature.genericArguments = genericArgumentsNode;
return node;
}
@ -359,6 +361,46 @@ Node *MakeEmptyFunctionArgumentSequenceNode()
return node;
}
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArgument;
node->genericArgument.identifier = identifierNode;
node->genericArgument.constraint = constraintNode;
return node;
}
Node *StartGenericArgumentsNode(Node *genericArgumentNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = (Node **)malloc(sizeof(Node *));
node->genericArguments.arguments[0] = genericArgumentNode;
node->genericArguments.count = 1;
return node;
}
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode)
{
genericArgumentsNode->genericArguments.arguments = (Node **)realloc(
genericArgumentsNode->genericArguments.arguments,
sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1));
genericArgumentsNode->genericArguments
.arguments[genericArgumentsNode->genericArguments.count] =
genericArgumentNode;
genericArgumentsNode->genericArguments.count += 1;
return genericArgumentsNode;
}
Node *MakeEmptyGenericArgumentsNode()
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = NULL;
node->genericArguments.count = 0;
return node;
}
Node *MakeFunctionCallExpressionNode(
Node *identifierNode,
Node *argumentSequenceNode)

View File

@ -30,6 +30,8 @@ typedef enum
FunctionModifiers,
FunctionSignature,
FunctionSignatureArguments,
GenericArgument,
GenericArguments,
Identifier,
IfStatement,
IfElseStatement,
@ -192,6 +194,7 @@ struct Node
Node *type;
Node *arguments;
Node *modifiers;
Node *genericArguments;
} functionSignature;
struct
@ -200,6 +203,18 @@ struct Node
uint32_t count;
} functionSignatureArguments;
struct
{
Node *identifier;
Node *constraint;
} genericArgument;
struct
{
Node **arguments;
uint32_t count;
} genericArguments;
struct
{
char *name;
@ -306,10 +321,15 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode,
Node *typeNode,
Node *argumentsNode,
Node *modifiersNode);
Node *modifiersNode,
Node *genericArgumentsNode);
Node *MakeFunctionDeclarationNode(
Node *functionSignatureNode,
Node *functionBodyNode);
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode);
Node *MakeEmptyGenericArgumentsNode();
Node *StartGenericArgumentsNode(Node *genericArgumentNode);
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode);
Node *MakeStructDeclarationNode(
Node *identifierNode,
Node *declarationSequenceNode);

View File

@ -56,6 +56,24 @@ typedef struct StructTypeFunction
uint8_t isStatic;
} StructTypeFunction;
typedef struct StructTypeGenericFunction
{
char *name;
Node *functionDeclarationNode;
} StructTypeGenericFunction;
typedef struct MonomorphizedGenericFunctionHashEntry
{
uint64_t key;
StructTypeFunction function;
} MonomorphizedGenericFunctionHashEntry;
typedef struct MonomorphizedGenericFunctionHashArray
{
MonomorphizedGenericFunctionHashEntry *elements;
uint32_t count;
} MonomorphizedGenericFunctionHashArray;
typedef struct StructTypeDeclaration
{
char *name;
@ -66,6 +84,11 @@ typedef struct StructTypeDeclaration
StructTypeFunction *functions;
uint32_t functionCount;
StructTypeGenericFunction *genericFunctions;
uint32_t genericFunctionCount;
MonomorphizedGenericFunctionHashArray monomorphizedGenericFunctions;
} StructTypeDeclaration;
StructTypeDeclaration *structTypeDeclarations;
@ -271,6 +294,10 @@ static void AddStructDeclaration(
structTypeDeclarations[index].fieldCount = 0;
structTypeDeclarations[index].functions = NULL;
structTypeDeclarations[index].functionCount = 0;
structTypeDeclarations[index].genericFunctions = NULL;
structTypeDeclarations[index].genericFunctionCount = 0;
structTypeDeclarations[index].monomorphizedGenericFunctions.elements = NULL;
structTypeDeclarations[index].monomorphizedGenericFunctions.count = 0;
for (i = 0; i < fieldDeclarationCount; i += 1)
{
@ -287,6 +314,7 @@ static void AddStructDeclaration(
structTypeDeclarationCount += 1;
}
/* FIXME: pass the declaration itself */
static void DeclareStructFunction(
LLVMTypeRef wStructPointerType,
LLVMValueRef function,
@ -318,6 +346,31 @@ static void DeclareStructFunction(
fprintf(stderr, "Could not find struct type for function!\n");
}
/* FIXME: pass the declaration itself */
static void DeclareGenericStructFunction(
LLVMTypeRef wStructPointerType,
Node *functionDeclarationNode,
char *name)
{
uint32_t i, index;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType == wStructPointerType)
{
index = structTypeDeclarations[i].genericFunctionCount;
structTypeDeclarations[i].genericFunctions[index].name =
strdup(name);
structTypeDeclarations[i]
.genericFunctions[index]
.functionDeclarationNode = functionDeclarationNode;
structTypeDeclarations[i].genericFunctionCount += 1;
return;
}
}
}
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
@ -1023,101 +1076,115 @@ static void CompileFunction(
}
}
if (!isStatic)
{
paramTypes[paramIndex] = wStructPointerType;
paramIndex += 1;
}
PushScopeFrame(scope);
/* FIXME: should work for non-primitive types */
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.type);
paramIndex += 1;
}
LLVMTypeRef returnType =
ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
char *functionName = strdup(parentStructName);
strcat(functionName, "_");
strcat(
functionName,
functionSignature->functionSignature.identifier->identifier.name);
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
free(functionName);
DeclareStructFunction(
wStructPointerType,
function,
returnType,
isStatic,
functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
if (!isStatic)
if (functionSignature->functionSignature.genericArguments->genericArguments
.count == 0)
{
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
AddStructVariablesToScope(builder, wStructPointer);
}
PushScopeFrame(scope);
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
char *ptrName = strdup(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy =
LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(
scope,
argumentCopy,
NULL,
functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
}
if (!isStatic)
{
paramTypes[paramIndex] = wStructPointerType;
paramIndex += 1;
}
for (i = 0; i < functionBody->statementSequence.count; i += 1)
{
CompileStatement(
builder,
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.type);
paramIndex += 1;
}
LLVMTypeRef returnType =
ResolveType(functionSignature->functionSignature.type);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
LLVMValueRef function =
LLVMAddFunction(module, functionName, functionType);
DeclareStructFunction(
wStructPointerType,
function,
functionBody->statementSequence.sequence[i]);
returnType,
isStatic,
functionSignature->functionSignature.identifier->identifier.name);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
if (!isStatic)
{
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
AddStructVariablesToScope(builder, wStructPointer);
}
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
char *ptrName =
strdup(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy =
LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(
scope,
argumentCopy,
NULL,
functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
}
for (i = 0; i < functionBody->statementSequence.count; i += 1)
{
CompileStatement(
builder,
function,
functionBody->statementSequence.sequence[i]);
}
hasReturn = LLVMGetBasicBlockTerminator(
LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
}
LLVMDisposeBuilder(builder);
PopScopeFrame(scope);
}
hasReturn =
LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
else
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
DeclareGenericStructFunction(
wStructPointerType,
functionDeclaration,
functionName);
}
PopScopeFrame(scope);
LLVMDisposeBuilder(builder);
free(functionName);
}
static void CompileStruct(