preparing for struct generics

main
cosmonaut 2021-06-02 17:26:26 -07:00
parent ea203e6c3c
commit a571edcf6d
5 changed files with 273 additions and 123 deletions

View File

@ -345,7 +345,10 @@ GenericDeclarationClause : LESS_THAN GenericDeclarations GREATER_THAN
$$ = MakeEmptyGenericDeclarationsNode(); $$ = MakeEmptyGenericDeclarationsNode();
} }
GenericArgument : Type; GenericArgument : Type
{
$$ = MakeGenericArgumentNode($1);
}
GenericArguments : GenericArgument GenericArguments : GenericArgument
{ {

View File

@ -9,6 +9,17 @@ struct Foo {
} }
} }
struct MemoryBlock<T>
{
start: MemoryAddress;
capacity: uint;
AddressOf(count: uint): MemoryAddress
{
return start + (count * @sizeof<T>());
}
}
struct Program { struct Program {
static Main(): int { static Main(): int {
x: int = 4; x: int = 4;

View File

@ -875,6 +875,7 @@ void Recurse(Node *node, void (*func)(Node *))
case FunctionCallExpression: case FunctionCallExpression:
func(node->functionCallExpression.identifier); func(node->functionCallExpression.identifier);
func(node->functionCallExpression.argumentSequence); func(node->functionCallExpression.argumentSequence);
func(node->functionCallExpression.genericArguments);
return; return;
case FunctionDeclaration: case FunctionDeclaration:
@ -904,6 +905,17 @@ void Recurse(Node *node, void (*func)(Node *))
} }
return; return;
case GenericArgument:
func(node->genericArgument.type);
break;
case GenericArguments:
for (i = 0; i < node->genericArguments.count; i += 1)
{
func(node->genericArguments.arguments[i]);
}
return;
case GenericDeclaration: case GenericDeclaration:
func(node->genericDeclaration.identifier); func(node->genericDeclaration.identifier);
func(node->genericDeclaration.constraint); func(node->genericDeclaration.constraint);
@ -967,7 +979,14 @@ void Recurse(Node *node, void (*func)(Node *))
func(node->structDeclaration.declarationSequence); func(node->structDeclaration.declarationSequence);
return; return;
case SystemCall:
func(node->systemCall.identifier);
func(node->systemCall.argumentSequence);
func(node->systemCall.genericArguments);
return;
case Type: case Type:
func(node->type.typeNode);
return; return;
case UnaryExpression: case UnaryExpression:
@ -1195,6 +1214,17 @@ void LinkParentPointers(Node *node, Node *prev)
} }
return; return;
case GenericArgument:
LinkParentPointers(node->genericArgument.type, node);
return;
case GenericArguments:
for (i = 0; i < node->genericArguments.count; i += 1)
{
LinkParentPointers(node->genericArguments.arguments[i], node);
}
return;
case GenericDeclaration: case GenericDeclaration:
LinkParentPointers(node->genericDeclaration.identifier, node); LinkParentPointers(node->genericDeclaration.identifier, node);
LinkParentPointers(node->genericDeclaration.constraint, node); LinkParentPointers(node->genericDeclaration.constraint, node);
@ -1258,6 +1288,12 @@ void LinkParentPointers(Node *node, Node *prev)
LinkParentPointers(node->structDeclaration.declarationSequence, node); LinkParentPointers(node->structDeclaration.declarationSequence, node);
return; return;
case SystemCall:
LinkParentPointers(node->systemCall.identifier, node);
LinkParentPointers(node->systemCall.argumentSequence, node);
LinkParentPointers(node->systemCall.genericArguments, node);
return;
case Type: case Type:
return; return;

View File

@ -112,6 +112,30 @@ typedef struct StructTypeDeclaration
StructTypeDeclaration *structTypeDeclarations; StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeDeclarationCount; uint32_t structTypeDeclarationCount;
typedef struct MonomorphizedGenericStructHashEntry
{
uint64_t key;
TypeTag **types;
uint32_t typeCount;
StructTypeDeclaration structDeclaration;
} MonomorphizedGenericStructHashEntry;
typedef struct MonomorphizedGenericStructHashArray
{
MonomorphizedGenericStructHashEntry *elements;
uint32_t count;
} MonomorphizedGenericStructHashArray;
typedef struct GenericStructTypeDeclaration
{
Node *structDeclarationNode;
MonomorphizedGenericStructHashArray
monomorphizedStructs[NUM_MONOMORPHIZED_HASH_BUCKETS];
} GenericStructTypeDeclaration;
GenericStructTypeDeclaration *genericStructTypeDeclarations;
uint32_t genericStructTypeDeclarationCount;
typedef struct SystemFunction typedef struct SystemFunction
{ {
char *name; char *name;
@ -234,16 +258,14 @@ static LocalGenericType *LookupGenericType(char *name)
return NULL; return NULL;
} }
static TypeTag *ConcretizeGenericType(char *name) static TypeTag *ConcretizeType(TypeTag *type)
{ {
LocalGenericType *type = LookupGenericType(name); if (type->type == Generic)
if (type == NULL)
{ {
return NULL; return LookupGenericType(type->value.genericType)->concreteTypeTag;
} }
return type->concreteTypeTag; return type;
} }
static LLVMTypeRef LookupCustomType(char *name) static LLVMTypeRef LookupCustomType(char *name)
@ -258,7 +280,7 @@ static LLVMTypeRef LookupCustomType(char *name)
} }
} }
fprintf(stderr, "Could not find struct type!\n"); fprintf(stderr, "Could not find custom type!\n");
return NULL; return NULL;
} }
@ -301,9 +323,10 @@ static void AddSystemFunction(
systemFunctionCount += 1; systemFunctionCount += 1;
} }
static SystemFunction *LookupSystemFunction(char *name) static SystemFunction *LookupSystemFunction(Node *systemCallExpression)
{ {
uint32_t i; uint32_t i;
char *name = systemCallExpression->systemCall.identifier->identifier.name;
for (i = 0; i < systemFunctionCount; i += 1) for (i = 0; i < systemFunctionCount; i += 1)
{ {
@ -518,6 +541,30 @@ static void AddStructDeclaration(
structTypeDeclarationCount += 1; structTypeDeclarationCount += 1;
} }
static void AddGenericStructDeclaration(Node *structDeclarationNode)
{
uint32_t i;
genericStructTypeDeclarations = realloc(
genericStructTypeDeclarations,
sizeof(GenericStructTypeDeclaration) *
(genericStructTypeDeclarationCount + 1));
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.structDeclarationNode = structDeclarationNode;
for (i = 0; i < NUM_MONOMORPHIZED_HASH_BUCKETS; i += 1)
{
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.monomorphizedStructs[i]
.elements = NULL;
genericStructTypeDeclarations[genericStructTypeDeclarationCount]
.monomorphizedStructs[i]
.count = 0;
}
genericStructTypeDeclarationCount += 1;
}
/* FIXME: pass the declaration itself */ /* FIXME: pass the declaration itself */
static void DeclareStructFunction( static void DeclareStructFunction(
LLVMTypeRef wStructPointerType, LLVMTypeRef wStructPointerType,
@ -771,8 +818,7 @@ static StructTypeFunction CompileGenericFunction(
static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupGenericFunction(
LLVMModuleRef module, LLVMModuleRef module,
StructTypeGenericFunction *genericFunction, StructTypeGenericFunction *genericFunction,
TypeTag **argumentTypes, Node *functionCallExpression,
uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
@ -784,33 +830,70 @@ static LLVMValueRef LookupGenericFunction(
.functionSignature->functionSignature.genericDeclarations .functionSignature->functionSignature.genericDeclarations
->genericDeclarations.count; ->genericDeclarations.count;
TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount]; TypeTag *resolvedGenericArgumentTypes[genericArgumentTypeCount];
uint32_t argumentCount = 0;
TypeTag
*argumentTypes[functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count];
for (i = 0; i < genericArgumentTypeCount; i += 1) for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count;
i += 1)
{ {
for (j = 0; argumentTypes[i] =
j < genericFunction->functionDeclarationNode->functionDeclaration functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->typeTag;
argumentCount += 1;
}
if (functionCallExpression->functionCallExpression.genericArguments
->genericArguments.count > 0)
{
for (i = 0; i < functionCallExpression->functionCallExpression
.genericArguments->genericArguments.count;
i += 1)
{
resolvedGenericArgumentTypes[i] =
functionCallExpression->functionCallExpression.genericArguments
->genericArguments.arguments[i]
->genericArgument.type->typeTag;
}
}
else /* we have to infer the generics */
{
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
for (j = 0;
j <
genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments .functionSignature->functionSignature.arguments
->functionSignatureArguments.count; ->functionSignatureArguments.count;
j += 1) j += 1)
{
if (genericFunction->functionDeclarationNode->functionDeclaration
.functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->type == Generic &&
strcmp(
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->value.genericType,
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.genericDeclarations
->genericDeclarations.declarations[i]
->genericDeclaration.identifier->identifier.name) == 0)
{ {
resolvedGenericArgumentTypes[i] = argumentTypes[j]; if (genericFunction->functionDeclarationNode
break; ->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->type ==
Generic &&
strcmp(
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.arguments
->functionSignatureArguments.sequence[j]
->declaration.identifier->typeTag->value
.genericType,
genericFunction->functionDeclarationNode
->functionDeclaration.functionSignature
->functionSignature.genericDeclarations
->genericDeclarations.declarations[i]
->genericDeclaration.identifier->identifier.name) ==
0)
{
resolvedGenericArgumentTypes[i] = argumentTypes[j];
break;
}
} }
} }
} }
@ -818,11 +901,8 @@ static LLVMValueRef LookupGenericFunction(
/* Concretize generics if we are compiling nested generic functions */ /* Concretize generics if we are compiling nested generic functions */
for (i = 0; i < genericArgumentTypeCount; i += 1) for (i = 0; i < genericArgumentTypeCount; i += 1)
{ {
if (resolvedGenericArgumentTypes[i]->type == Generic) resolvedGenericArgumentTypes[i] =
{ ConcretizeType(resolvedGenericArgumentTypes[i]);
resolvedGenericArgumentTypes[i] = ConcretizeGenericType(
resolvedGenericArgumentTypes[i]->value.genericType);
}
} }
typeHash = typeHash =
@ -894,13 +974,14 @@ static LLVMValueRef LookupGenericFunction(
static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByType(
LLVMModuleRef module, LLVMModuleRef module,
LLVMTypeRef structType, LLVMTypeRef structType,
char *name, Node *functionCallExpression,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
uint32_t i, j; uint32_t i, j;
/* FIXME: hot garbage */
char *name = functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name;
for (i = 0; i < structTypeDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
@ -928,8 +1009,7 @@ static LLVMValueRef LookupFunctionByType(
return LookupGenericFunction( return LookupGenericFunction(
module, module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
argumentTypes, functionCallExpression,
argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -944,13 +1024,14 @@ static LLVMValueRef LookupFunctionByType(
static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByPointerType(
LLVMModuleRef module, LLVMModuleRef module,
LLVMTypeRef structPointerType, LLVMTypeRef structPointerType,
char *name, Node *functionCallExpression,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
uint32_t i, j; uint32_t i, j;
/* FIXME: hot garbage */
char *name = functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name;
for (i = 0; i < structTypeDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
@ -978,8 +1059,7 @@ static LLVMValueRef LookupFunctionByPointerType(
return LookupGenericFunction( return LookupGenericFunction(
module, module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
argumentTypes, functionCallExpression,
argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -994,18 +1074,14 @@ static LLVMValueRef LookupFunctionByPointerType(
static LLVMValueRef LookupFunctionByInstance( static LLVMValueRef LookupFunctionByInstance(
LLVMModuleRef module, LLVMModuleRef module,
LLVMValueRef structPointer, LLVMValueRef structPointer,
char *name, Node *functionCallExpression,
TypeTag **argumentTypes,
uint32_t argumentCount,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic) uint8_t *pStatic)
{ {
return LookupFunctionByPointerType( return LookupFunctionByPointerType(
module, module,
LLVMTypeOf(structPointer), LLVMTypeOf(structPointer),
name, functionCallExpression,
argumentTypes,
argumentCount,
pReturnType, pReturnType,
pStatic); pStatic);
} }
@ -1093,27 +1169,12 @@ static LLVMValueRef CompileFunctionCallExpression(
[functionCallExpression->functionCallExpression.argumentSequence [functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.count + ->functionArgumentSequence.count +
1]; 1];
TypeTag
*argumentTypes[functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count];
LLVMValueRef function; LLVMValueRef function;
uint8_t isStatic; uint8_t isStatic;
LLVMValueRef structInstance; LLVMValueRef structInstance;
LLVMTypeRef functionReturnType; LLVMTypeRef functionReturnType;
char *returnName = ""; char *returnName = "";
for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count;
i += 1)
{
argumentTypes[i] =
functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]
->typeTag;
argumentCount += 1;
}
/* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be recursive on access chains */
/* FIXME: this needs to be able to call same-struct functions implicitly /* FIXME: this needs to be able to call same-struct functions implicitly
*/ */
@ -1129,10 +1190,7 @@ static LLVMValueRef CompileFunctionCallExpression(
function = LookupFunctionByType( function = LookupFunctionByType(
module, module,
typeReference, typeReference,
functionCallExpression->functionCallExpression.identifier functionCallExpression,
->accessExpression.accessor->identifier.name,
argumentTypes,
argumentCount,
&functionReturnType, &functionReturnType,
&isStatic); &isStatic);
} }
@ -1144,10 +1202,7 @@ static LLVMValueRef CompileFunctionCallExpression(
function = LookupFunctionByInstance( function = LookupFunctionByInstance(
module, module,
structInstance, structInstance,
functionCallExpression->functionCallExpression.identifier functionCallExpression,
->accessExpression.accessor->identifier.name,
argumentTypes,
argumentCount,
&functionReturnType, &functionReturnType,
&isStatic); &isStatic);
} }
@ -1158,8 +1213,6 @@ static LLVMValueRef CompileFunctionCallExpression(
return NULL; return NULL;
} }
argumentCount = 0;
if (!isStatic) if (!isStatic)
{ {
args[argumentCount] = structInstance; args[argumentCount] = structInstance;
@ -1208,13 +1261,27 @@ static LLVMValueRef CompileSystemCallExpression(
->functionArgumentSequence.sequence[i]); ->functionArgumentSequence.sequence[i]);
} }
SystemFunction *systemFunction = LookupSystemFunction( SystemFunction *systemFunction = LookupSystemFunction(systemCallExpression);
systemCallExpression->systemCall.identifier->identifier.name);
if (systemFunction == NULL) if (systemFunction == NULL)
{ {
fprintf(stderr, "System function not found!"); /* special case for sizeof */
return NULL; if (strcmp(
systemCallExpression->systemCall.identifier->identifier.name,
"sizeof") == 0)
{
TypeTag *typeTag =
systemCallExpression->systemCall.genericArguments
->genericArguments.arguments[0]
->type.typeNode->typeTag;
return LLVMSizeOf(ResolveType(ConcretizeType(typeTag)));
}
else
{
fprintf(stderr, "System function not found!");
return NULL;
}
} }
if (LLVMGetTypeKind(LLVMGetReturnType(systemFunction->functionType)) != if (LLVMGetTypeKind(LLVMGetReturnType(systemFunction->functionType)) !=
@ -1792,52 +1859,62 @@ static void CompileStruct(
PushScopeFrame(scope); PushScopeFrame(scope);
LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); if (node->structDeclaration.genericDeclarations->genericDeclarations
LLVMTypeRef wStructPointerType = LLVMPointerType( .count == 0)
wStructType,
0); /* FIXME: is this address space correct? */
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{ {
currentDeclarationNode = node->structDeclaration.declarationSequence LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName);
->declarationSequence.sequence[i]; LLVMTypeRef wStructPointerType = LLVMPointerType(
wStructType,
0); /* FIXME: is this address space correct? */
switch (currentDeclarationNode->syntaxKind) /* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
{ {
case Declaration: /* this is badly named */ currentDeclarationNode =
types[fieldCount] = ResolveType( node->structDeclaration.declarationSequence->declarationSequence
currentDeclarationNode->declaration.identifier->typeTag); .sequence[i];
fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1; switch (currentDeclarationNode->syntaxKind)
break; {
case Declaration: /* this is badly named */
types[fieldCount] = ResolveType(
currentDeclarationNode->declaration.identifier->typeTag);
fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1;
break;
}
}
LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(
wStructType,
wStructPointerType,
structName,
fieldDeclarations,
fieldCount);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{
currentDeclarationNode =
node->structDeclaration.declarationSequence->declarationSequence
.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
case FunctionDeclaration:
CompileFunction(
module,
structName,
wStructPointerType,
currentDeclarationNode);
break;
}
} }
} }
else
LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(
wStructType,
wStructPointerType,
structName,
fieldDeclarations,
fieldCount);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
{ {
currentDeclarationNode = node->structDeclaration.declarationSequence AddGenericStructDeclaration(node);
->declarationSequence.sequence[i];
switch (currentDeclarationNode->syntaxKind)
{
case FunctionDeclaration:
CompileFunction(
module,
structName,
wStructPointerType,
currentDeclarationNode);
break;
}
} }
PopScopeFrame(scope); PopScopeFrame(scope);
@ -1946,6 +2023,9 @@ int Codegen(Node *node, uint32_t optimizationLevel)
structTypeDeclarations = NULL; structTypeDeclarations = NULL;
structTypeDeclarationCount = 0; structTypeDeclarationCount = 0;
genericStructTypeDeclarations = NULL;
genericStructTypeDeclarationCount = 0;
systemFunctions = NULL; systemFunctions = NULL;
systemFunctionCount = 0; systemFunctionCount = 0;

View File

@ -313,6 +313,10 @@ void TagIdentifierTypes(Node *node)
node->genericDeclaration.identifier->typeTag = MakeTypeTag(node); node->genericDeclaration.identifier->typeTag = MakeTypeTag(node);
break; break;
case Type:
node->typeTag = MakeTypeTag(node);
break;
case Identifier: case Identifier:
{ {
if (node->typeTag != NULL) if (node->typeTag != NULL)
@ -467,6 +471,22 @@ void ConvertCustomsToGenerics(Node *node)
} }
break; break;
} }
case GenericArgument:
{
Node *typeNode = node->genericArgument.type;
if (typeNode->typeTag->type == Custom)
{
char *target = typeNode->typeTag->value.customType;
Node *typeLookup = LookupType(node, target);
if (typeLookup != NULL &&
typeLookup->syntaxKind == GenericDeclaration)
{
typeNode->typeTag->type = Generic;
}
}
break;
}
} }
Recurse(node, *ConvertCustomsToGenerics); Recurse(node, *ConvertCustomsToGenerics);