preparing for struct generics

main
cosmonaut 2021-06-02 17:26:26 -07:00
parent ea203e6c3c
commit a571edcf6d
5 changed files with 273 additions and 123 deletions

View File

@ -345,7 +345,10 @@ GenericDeclarationClause : LESS_THAN GenericDeclarations GREATER_THAN
$$ = MakeEmptyGenericDeclarationsNode();
}
GenericArgument : Type;
GenericArgument : Type
{
$$ = MakeGenericArgumentNode($1);
}
GenericArguments : GenericArgument
{

View File

@ -9,6 +9,17 @@ struct Foo {
}
}
struct MemoryBlock<T>
{
start: MemoryAddress;
capacity: uint;
AddressOf(count: uint): MemoryAddress
{
return start + (count * @sizeof<T>());
}
}
struct Program {
static Main(): int {
x: int = 4;

View File

@ -875,6 +875,7 @@ void Recurse(Node *node, void (*func)(Node *))
case FunctionCallExpression:
func(node->functionCallExpression.identifier);
func(node->functionCallExpression.argumentSequence);
func(node->functionCallExpression.genericArguments);
return;
case FunctionDeclaration:
@ -904,6 +905,17 @@ void Recurse(Node *node, void (*func)(Node *))
}
return;
case GenericArgument:
func(node->genericArgument.type);
break;
case GenericArguments:
for (i = 0; i < node->genericArguments.count; i += 1)
{
func(node->genericArguments.arguments[i]);
}
return;
case GenericDeclaration:
func(node->genericDeclaration.identifier);
func(node->genericDeclaration.constraint);
@ -967,7 +979,14 @@ void Recurse(Node *node, void (*func)(Node *))
func(node->structDeclaration.declarationSequence);
return;
case SystemCall:
func(node->systemCall.identifier);
func(node->systemCall.argumentSequence);
func(node->systemCall.genericArguments);
return;
case Type:
func(node->type.typeNode);
return;
case UnaryExpression:
@ -1195,6 +1214,17 @@ void LinkParentPointers(Node *node, Node *prev)
}
return;
case GenericArgument:
LinkParentPointers(node->genericArgument.type, node);
return;
case GenericArguments:
for (i = 0; i < node->genericArguments.count; i += 1)
{
LinkParentPointers(node->genericArguments.arguments[i], node);
}
return;
case GenericDeclaration:
LinkParentPointers(node->genericDeclaration.identifier, node);
LinkParentPointers(node->genericDeclaration.constraint, node);
@ -1258,6 +1288,12 @@ void LinkParentPointers(Node *node, Node *prev)
LinkParentPointers(node->structDeclaration.declarationSequence, node);
return;
case SystemCall:
LinkParentPointers(node->systemCall.identifier, node);
LinkParentPointers(node->systemCall.argumentSequence, node);
LinkParentPointers(node->systemCall.genericArguments, node);
return;
case Type:
return;

View File

@ -112,6 +112,30 @@ typedef struct StructTypeDeclaration
StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeDeclarationCount;
typedef struct MonomorphizedGenericStructHashEntry
{
uint64_t key;
TypeTag **types;
uint32_t typeCount;
StructTypeDeclaration structDeclaration;
} MonomorphizedGenericStructHashEntry;
typedef struct MonomorphizedGenericStructHashArray
{
MonomorphizedGenericStructHashEntry *elements;
uint32_t count;
} MonomorphizedGenericStructHashArray;
typedef struct GenericStructTypeDeclaration
{
Node *structDeclarationNode;
MonomorphizedGenericStructHashArray
monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS];
} GenericStructTypeDeclaration;
GenericStructTypeDeclaration *genericStructTypeDeclarations;
uint32_t genericStructTypeDeclarationCount;
typedef struct SystemFunction
{
char *name;
@ -234,16 +258,14 @@ static LocalGenericType *LookupGenericType(char *name)
return NULL;
}
static TypeTag *ConcretizeGenericType(char *name)
static TypeTag *ConcretizeType(TypeTag *type)
{
LocalGenericType *type = LookupGenericType(name);
if (type == NULL)
if (type->type == Generic)
{
return NULL;
return LookupGenericType(type->value.genericType)->concreteTypeTag;
}
return type->concreteTypeTag;
return type;
}
static LLVMTypeRef LookupCustomType(char *name)
@ -258,7 +280,7 @@ static LLVMTypeRef LookupCustomType(char *name)
}
}
fprintf(stderr, "Could not find struct type!\n");
fprintf(stderr, "Could not find custom type!\n");
return NULL;
}
@ -301,9 +323,10 @@ static void AddSystemFunction(
systemFunctionCount += 1;
}
static SystemFunction *LookupSystemFunction(char *name)
static SystemFunction *LookupSystemFunction(Node *systemCallExpression)
{
uint32_t i;
char *name = systemCallExpression->systemCall.identifier->identifier.name;
for (i = 0; i < systemFunctionCount; i += 1)
{
@ -518,6 +541,30 @@ static void AddStructDeclaration(
structTypeDeclarationCount += 1;
}
static void AddGenericStructDeclaration(Node *structDeclarationNode)
{
uint32_t i;
genericStructTypeDeclarations = realloc(
genericStructTypeDeclarations,
sizeof(GenericStructTypeDeclaration) *
(genericStructTypeDeclarationCount + 1));
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.structDeclarationNode = structDeclarationNode;
for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1)
{
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.monomorphizedStructs[i]
.elements = NULL;
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.monomorphizedStructs[i]
.count = 0;
}
genericStructTypeDeclarationCount += 1;
}
/* FIXME: pass the declaration itself */
static void DeclareStructFunction(
LLVMTypeRef wStructPointerType,
@ -771,8 +818,7 @@ static StructTypeFunction CompileGenericFunction(
static LLVMValueRef LookupGenericFunction(
LLVMModuleRef module,
StructTypeGenericFunction *genericFunction,
TypeTag **argumentTypes,
uint32_t argumentCount,
Node *functionCallExpression,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
@ -784,45 +830,79 @@ static LLVMValueRef LookupGenericFunction(
.functionSignature->functionSignature.genericDeclarations
->genericDeclarations.count;
TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount];
uint32_t argumentCount = 0;
TypeTag
*argumentTypes[functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count];
for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count;
i += 1)
{
argumentTypes[i] =
functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->typeTag;
argumentCount += 1;
}
if (functionCallExpression->functionCallExpression.genericArguments
->genericArguments.count > 0)
{
for (i = 0; i < functionCallExpression->functionCallExpression
.genericArguments->genericArguments.count;
i += 1)
{
resolvedGenericArgumentTypes[i] =
functionCallExpression->functionCallExpression.genericArguments
->genericArguments.arguments[i]
->genericArgument.type->typeTag;
}
}
else /* we have to infer the generics */
{
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
for (j = 0;
j < genericFunction->functionDeclarationNode->functionDeclaration
j <
genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
j += 1)
{
if (genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
if (genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->type == Generic &&
->declaration.identifier->typeTag->type ==
Generic &&
strcmp(
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->value.genericType,
->declaration.identifier->typeTag->value
.genericType,
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.genericDeclarations
->genericDeclarations.declarations[i]
->genericDeclaration.identifier->identifier.name) == 0)
->genericDeclaration.identifier->identifier.name) ==
0)
{
resolvedGenericArgumentTypes[i] = argumentTypes[j];
break;
}
}
}
}
/* Concretize generics if we are compiling nested generic functions */
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
if (resolvedGenericArgumentTypes[i]->type == Generic)
{
resolvedGenericArgumentTypes[i] = ConcretizeGenericType(
resolvedGenericArgumentTypes[i]->value.genericType);
}
resolvedGenericArgumentTypes[i] =
ConcretizeType(resolvedGenericArgumentTypes[i]);
}
typeHash =
@ -894,13 +974,14 @@ static LLVMValueRef LookupGenericFunction(
static LLVMValueRef LookupFunctionByType(
LLVMModuleRef module,
LLVMTypeRef structType,
char *name,
TypeTag **argumentTypes,
uint32_t argumentCount,
Node *functionCallExpression,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
uint32_t i, j;
/* FIXME: hot garbage */
char *name = functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
@ -928,8 +1009,7 @@ static LLVMValueRef LookupFunctionByType(
return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j],
argumentTypes,
argumentCount,
functionCallExpression,
pReturnType,
pStatic);
}
@ -944,13 +1024,14 @@ static LLVMValueRef LookupFunctionByType(
static LLVMValueRef LookupFunctionByPointerType(
LLVMModuleRef module,
LLVMTypeRef structPointerType,
char *name,
TypeTag **argumentTypes,
uint32_t argumentCount,
Node *functionCallExpression,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
uint32_t i, j;
/* FIXME: hot garbage */
char *name = functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
@ -978,8 +1059,7 @@ static LLVMValueRef LookupFunctionByPointerType(
return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j],
argumentTypes,
argumentCount,
functionCallExpression,
pReturnType,
pStatic);
}
@ -994,18 +1074,14 @@ static LLVMValueRef LookupFunctionByPointerType(
static LLVMValueRef LookupFunctionByInstance(
LLVMModuleRef module,
LLVMValueRef structPointer,
char *name,
TypeTag **argumentTypes,
uint32_t argumentCount,
Node *functionCallExpression,
LLVMTypeRef *pReturnType,
uint8_t *pStatic)
{
return LookupFunctionByPointerType(
module,
LLVMTypeOf(structPointer),
name,
argumentTypes,
argumentCount,
functionCallExpression,
pReturnType,
pStatic);
}
@ -1093,27 +1169,12 @@ static LLVMValueRef CompileFunctionCallExpression(
[functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.count +
1];
TypeTag
*argumentTypes[functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count];
LLVMValueRef function;
uint8_t isStatic;
LLVMValueRef structInstance;
LLVMTypeRef functionReturnType;
char *returnName = "";
for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count;
i += 1)
{
argumentTypes[i] =
functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->typeTag;
argumentCount += 1;
}
/* FIXME: this needs to be recursive on access chains */
/* FIXME: this needs to be able to call same-struct functions implicitly
*/
@ -1129,10 +1190,7 @@ static LLVMValueRef CompileFunctionCallExpression(
function = LookupFunctionByType(
module,
typeReference,
functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name,
argumentTypes,
argumentCount,
functionCallExpression,
&functionReturnType,
&isStatic);
}
@ -1144,10 +1202,7 @@ static LLVMValueRef CompileFunctionCallExpression(
function = LookupFunctionByInstance(
module,
structInstance,
functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name,
argumentTypes,
argumentCount,
functionCallExpression,
&functionReturnType,
&isStatic);
}
@ -1158,8 +1213,6 @@ static LLVMValueRef CompileFunctionCallExpression(
return NULL;
}
argumentCount = 0;
if (!isStatic)
{
args[argumentCount] = structInstance;
@ -1208,14 +1261,28 @@ static LLVMValueRef CompileSystemCallExpression(
->functionArgumentSequence.sequence[i]);
}
SystemFunction *systemFunction = LookupSystemFunction(
systemCallExpression->systemCall.identifier->identifier.name);
SystemFunction *systemFunction = LookupSystemFunction(systemCallExpression);
if (systemFunction == NULL)
{
/* special case for sizeof */
if (strcmp(
systemCallExpression->systemCall.identifier->identifier.name,
"sizeof") == 0)
{
TypeTag *typeTag =
systemCallExpression->systemCall.genericArguments
->genericArguments.arguments[0]
->type.typeNode->typeTag;
return LLVMSizeOf(ResolveType(ConcretizeType(typeTag)));
}
else
{
fprintf(stderr, "System function not found!");
return NULL;
}
}
if (LLVMGetTypeKind(LLVMGetReturnType(systemFunction->functionType)) !=
LLVMVoidTypeKind)
@ -1792,6 +1859,9 @@ static void CompileStruct(
PushScopeFrame(scope);
if (node->structDeclaration.genericDeclarations->genericDeclarations
.count == 0)
{
LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName);
LLVMTypeRef wStructPointerType = LLVMPointerType(
wStructType,
@ -1800,8 +1870,9 @@ static void CompileStruct(
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->structDeclaration.declarationSequence
->declarationSequence.sequence[i];
currentDeclarationNode =
node->structDeclaration.declarationSequence->declarationSequence
.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
@ -1825,8 +1896,9 @@ static void CompileStruct(
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode = node->structDeclaration.declarationSequence
->declarationSequence.sequence[i];
currentDeclarationNode =
node->structDeclaration.declarationSequence->declarationSequence
.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
@ -1839,6 +1911,11 @@ static void CompileStruct(
break;
}
}
}
else
{
AddGenericStructDeclaration(node);
}
PopScopeFrame(scope);
}
@ -1946,6 +2023,9 @@ int Codegen(Node *node, uint32_t optimizationLevel)
structTypeDeclarations = NULL;
structTypeDeclarationCount = 0;
genericStructTypeDeclarations = NULL;
genericStructTypeDeclarationCount = 0;
systemFunctions = NULL;
systemFunctionCount = 0;

View File

@ -313,6 +313,10 @@ void TagIdentifierTypes(Node *node)
node->genericDeclaration.identifier->typeTag = MakeTypeTag(node);
break;
case Type:
node->typeTag = MakeTypeTag(node);
break;
case Identifier:
{
if (node->typeTag != NULL)
@ -467,6 +471,22 @@ void ConvertCustomsToGenerics(Node *node)
}
break;
}
case GenericArgument:
{
Node *typeNode = node->genericArgument.type;
if (typeNode->typeTag->type == Custom)
{
char *target = typeNode->typeTag->value.customType;
Node *typeLookup = LookupType(node, target);
if (typeLookup != NULL &&
typeLookup->syntaxKind == GenericDeclaration)
{
typeNode->typeTag->type = Generic;
}
}
break;
}
}
Recurse(node, *ConvertCustomsToGenerics);