static function lookup

generics
cosmonaut 2021-04-22 22:37:23 -07:00
parent b344635c8d
commit c4c916a2de
1 changed files with 149 additions and 95 deletions

View File

@ -61,18 +61,20 @@ typedef struct StructTypeFunction
uint8_t isStatic;
} StructTypeFunction;
typedef struct StructTypeFieldDeclaration
typedef struct StructTypeDeclaration
{
char *name;
LLVMTypeRef structType;
LLVMTypeRef structPointerType;
StructTypeField *fields;
uint32_t fieldCount;
StructTypeFunction *functions;
uint32_t functionCount;
} StructTypeFieldDeclaration;
} StructTypeDeclaration;
StructTypeFieldDeclaration *structTypeFieldDeclarations;
uint32_t structTypeFieldDeclarationCount;
StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeDeclarationCount;
static Scope* CreateScope()
{
@ -127,26 +129,41 @@ static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
scopeFrame->localVariableCount += 1;
}
static LLVMTypeRef FindStructType(char *name)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
return NULL;
}
static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name)
{
int32_t i, j;
LLVMTypeRef structType = LLVMTypeOf(structPointer);
LLVMTypeRef structPointerType = LLVMTypeOf(structPointer);
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == structType)
if (structTypeDeclarations[i].structPointerType == structPointerType)
{
for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1)
for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{
if (strcmp(structTypeFieldDeclarations[i].fields[j].name, name) == 0)
if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0)
{
char *ptrName = strdup(name);
strcat(ptrName, "_ptr");
return LLVMBuildStructGEP(
builder,
structPointer,
structTypeFieldDeclarations[i].fields[j].index,
structTypeDeclarations[i].fields[j].index,
ptrName
);
free(ptrName);
@ -199,31 +216,35 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
static void AddStructDeclaration(
LLVMTypeRef wStructType,
LLVMTypeRef wStructPointerType,
char *name,
Node **fieldDeclarations,
uint32_t fieldDeclarationCount
) {
uint32_t i;
uint32_t index = structTypeFieldDeclarationCount;
structTypeFieldDeclarations = realloc(structTypeFieldDeclarations, sizeof(StructTypeFieldDeclaration) * (structTypeFieldDeclarationCount + 1));
structTypeFieldDeclarations[index].structType = wStructType;
structTypeFieldDeclarations[index].fields = NULL;
structTypeFieldDeclarations[index].fieldCount = 0;
structTypeFieldDeclarations[index].functions = NULL;
structTypeFieldDeclarations[index].functionCount = 0;
uint32_t index = structTypeDeclarationCount;
structTypeDeclarations = realloc(structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1));
structTypeDeclarations[index].structType = wStructType;
structTypeDeclarations[index].structPointerType = wStructPointerType;
structTypeDeclarations[index].name = strdup(name);
structTypeDeclarations[index].fields = NULL;
structTypeDeclarations[index].fieldCount = 0;
structTypeDeclarations[index].functions = NULL;
structTypeDeclarations[index].functionCount = 0;
for (i = 0; i < fieldDeclarationCount; i += 1)
{
structTypeFieldDeclarations[index].fields = realloc(structTypeFieldDeclarations[index].fields, sizeof(StructTypeField) * (structTypeFieldDeclarations[index].fieldCount + 1));
structTypeFieldDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string);
structTypeFieldDeclarations[index].fields[i].index = i;
structTypeFieldDeclarations[index].fieldCount += 1;
structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1));
structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string);
structTypeDeclarations[index].fields[i].index = i;
structTypeDeclarations[index].fieldCount += 1;
}
structTypeFieldDeclarationCount += 1;
structTypeDeclarationCount += 1;
}
static void DeclareStructFunction(
LLVMTypeRef wStructType,
LLVMTypeRef wStructPointerType,
LLVMValueRef function,
LLVMTypeRef returnType,
uint8_t isStatic,
@ -231,71 +252,124 @@ static void DeclareStructFunction(
) {
uint32_t i, index;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == wStructType)
if (structTypeDeclarations[i].structPointerType == wStructPointerType)
{
index = structTypeFieldDeclarations[i].functionCount;
structTypeFieldDeclarations[i].functions = realloc(structTypeFieldDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeFieldDeclarations[i].functionCount + 1));
structTypeFieldDeclarations[i].functions[index].name = strdup(name);
structTypeFieldDeclarations[i].functions[index].function = function;
structTypeFieldDeclarations[i].functions[index].returnType = returnType;
structTypeFieldDeclarations[i].functions[index].isStatic = isStatic;
structTypeFieldDeclarations[i].functionCount += 1;
index = structTypeDeclarations[i].functionCount;
structTypeDeclarations[i].functions = realloc(structTypeDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeDeclarations[i].functionCount + 1));
structTypeDeclarations[i].functions[index].name = strdup(name);
structTypeDeclarations[i].functions[index].function = function;
structTypeDeclarations[i].functions[index].returnType = returnType;
structTypeDeclarations[i].functions[index].isStatic = isStatic;
structTypeDeclarations[i].functionCount += 1;
return;
}
}
fprintf(stderr, "Could not find struct type for function!");
fprintf(stderr, "Could not find struct type for function!\n");
}
static LLVMValueRef LookupFunction(
LLVMValueRef structPointer,
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
fprintf(stderr, "Could not find struct type!\n");
return NULL;
}
static LLVMValueRef LookupFunctionByType(
LLVMTypeRef structType,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
uint32_t i, j;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer))
if (structTypeDeclarations[i].structType == structType)
{
for (j = 0; j < structTypeFieldDeclarations[i].functionCount; j += 1)
for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1)
{
if (strcmp(structTypeFieldDeclarations[i].functions[j].name, name) == 0)
if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0)
{
*pReturnType = structTypeFieldDeclarations[i].functions[j].returnType;
*pStatic = structTypeFieldDeclarations[i].functions[j].isStatic;
return structTypeFieldDeclarations[i].functions[j].function;
*pReturnType = structTypeDeclarations[i].functions[j].returnType;
*pStatic = structTypeDeclarations[i].functions[j].isStatic;
return structTypeDeclarations[i].functions[j].function;
}
}
}
}
fprintf(stderr, "Could not find struct function!");
fprintf(stderr, "Could not find struct function!\n");
return NULL;
}
static LLVMValueRef LookupFunctionByPointerType(
LLVMTypeRef structPointerType,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
uint32_t i, j;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType == structPointerType)
{
for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1)
{
if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0)
{
*pReturnType = structTypeDeclarations[i].functions[j].returnType;
*pStatic = structTypeDeclarations[i].functions[j].isStatic;
return structTypeDeclarations[i].functions[j].function;
}
}
}
}
fprintf(stderr, "Could not find struct function!\n");
return NULL;
}
static LLVMValueRef LookupFunctionByInstance(
LLVMValueRef structPointer,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
return LookupFunctionByPointerType(LLVMTypeOf(structPointer), name, pReturnType, pStatic);
}
static void AddStructVariablesToScope(
LLVMBuilderRef builder,
LLVMValueRef structPointer
) {
uint32_t i, j;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1)
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer))
if (structTypeDeclarations[i].structPointerType == LLVMTypeOf(structPointer))
{
for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1)
for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{
char *ptrName = strdup(structTypeFieldDeclarations[i].fields[j].name);
char *ptrName = strdup(structTypeDeclarations[i].fields[j].name);
strcat(ptrName, "_ptr");
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
structPointer,
structTypeFieldDeclarations[i].fields[j].index,
structTypeDeclarations[i].fields[j].index,
ptrName
);
free(ptrName);
@ -303,7 +377,7 @@ static void AddStructVariablesToScope(
AddLocalVariable(
scope,
elementPointer,
structTypeFieldDeclarations[i].fields[j].name
structTypeDeclarations[i].fields[j].name
);
}
}
@ -315,38 +389,6 @@ static LLVMValueRef CompileExpression(
Node *binaryExpression
);
typedef struct CustomTypeMap
{
LLVMTypeRef type;
char *name;
} CustomTypeMap;
CustomTypeMap *customTypes;
uint32_t customTypeCount;
static void RegisterCustomType(LLVMTypeRef type, char *name)
{
customTypes = realloc(customTypes, sizeof(CustomType) * (customTypeCount + 1));
customTypes[customTypeCount].type = type;
customTypes[customTypeCount].name = strdup(name);
customTypeCount += 1;
}
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
for (i = 0; i < customTypeCount; i += 1)
{
if (strcmp(customTypes[i].name, name) == 0)
{
return customTypes[i].type;
}
}
return NULL;
}
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{
switch (type)
@ -404,7 +446,7 @@ static LLVMValueRef CompileFunctionCallExpression(
) {
uint32_t i;
uint32_t argumentCount = 0;
LLVMValueRef args[argumentCount];
LLVMValueRef args[expression->children[1]->childCount + 1];
LLVMValueRef function;
uint8_t isStatic;
LLVMValueRef structInstance;
@ -414,8 +456,24 @@ static LLVMValueRef CompileFunctionCallExpression(
/* FIXME: this needs to be recursive on access chains */
if (expression->children[0]->syntaxKind == AccessExpression)
{
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string);
function = LookupFunction(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic);
LLVMTypeRef typeReference = FindStructType(
expression->children[0]->children[0]->value.string
);
if (typeReference != NULL)
{
function = LookupFunctionByType(
typeReference,
expression->children[0]->children[1]->value.string,
&functionReturnType,
&isStatic
);
}
else
{
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string);
function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic);
}
}
else
{
@ -641,7 +699,7 @@ static void CompileFunction(
{
char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + 1);
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
@ -680,8 +738,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
PushScopeFrame(scope);
LLVMTypeRef wStruct = LLVMStructCreateNamed(context, structName);
LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */
LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName);
LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); /* FIXME: is this address space correct? */
/* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1)
@ -698,9 +756,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
}
}
LLVMStructSetBody(wStruct, types, fieldCount, packed);
AddStructDeclaration(wStructPointerType, fieldDeclarations, fieldCount);
RegisterCustomType(wStruct, node->children[0]->value.string);
LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount);
/* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1)
@ -748,11 +805,8 @@ int main(int argc, char *argv[])
scope = CreateScope();
structTypeFieldDeclarations = NULL;
structTypeFieldDeclarationCount = 0;
customTypes = NULL;
customTypeCount = 0;
structTypeDeclarations = NULL;
structTypeDeclarationCount = 0;
stack = CreateStack();