static function lookup

generics
cosmonaut 2021-04-22 22:37:23 -07:00
parent b344635c8d
commit c4c916a2de
1 changed files with 149 additions and 95 deletions

View File

@ -61,18 +61,20 @@ typedef struct StructTypeFunction
uint8_t isStatic; uint8_t isStatic;
} StructTypeFunction; } StructTypeFunction;
typedef struct StructTypeFieldDeclaration typedef struct StructTypeDeclaration
{ {
char *name;
LLVMTypeRef structType; LLVMTypeRef structType;
LLVMTypeRef structPointerType;
StructTypeField *fields; StructTypeField *fields;
uint32_t fieldCount; uint32_t fieldCount;
StructTypeFunction *functions; StructTypeFunction *functions;
uint32_t functionCount; uint32_t functionCount;
} StructTypeFieldDeclaration; } StructTypeDeclaration;
StructTypeFieldDeclaration *structTypeFieldDeclarations; StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeFieldDeclarationCount; uint32_t structTypeDeclarationCount;
static Scope* CreateScope() static Scope* CreateScope()
{ {
@ -127,26 +129,41 @@ static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
scopeFrame->localVariableCount += 1; scopeFrame->localVariableCount += 1;
} }
static LLVMTypeRef FindStructType(char *name)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
return NULL;
}
static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name) static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name)
{ {
int32_t i, j; int32_t i, j;
LLVMTypeRef structType = LLVMTypeOf(structPointer); LLVMTypeRef structPointerType = LLVMTypeOf(structPointer);
for (i = 0; i < structTypeFieldDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
if (structTypeFieldDeclarations[i].structType == structType) if (structTypeDeclarations[i].structPointerType == structPointerType)
{ {
for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1) for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{ {
if (strcmp(structTypeFieldDeclarations[i].fields[j].name, name) == 0) if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0)
{ {
char *ptrName = strdup(name); char *ptrName = strdup(name);
strcat(ptrName, "_ptr"); strcat(ptrName, "_ptr");
return LLVMBuildStructGEP( return LLVMBuildStructGEP(
builder, builder,
structPointer, structPointer,
structTypeFieldDeclarations[i].fields[j].index, structTypeDeclarations[i].fields[j].index,
ptrName ptrName
); );
free(ptrName); free(ptrName);
@ -199,31 +216,35 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
static void AddStructDeclaration( static void AddStructDeclaration(
LLVMTypeRef wStructType, LLVMTypeRef wStructType,
LLVMTypeRef wStructPointerType,
char *name,
Node **fieldDeclarations, Node **fieldDeclarations,
uint32_t fieldDeclarationCount uint32_t fieldDeclarationCount
) { ) {
uint32_t i; uint32_t i;
uint32_t index = structTypeFieldDeclarationCount; uint32_t index = structTypeDeclarationCount;
structTypeFieldDeclarations = realloc(structTypeFieldDeclarations, sizeof(StructTypeFieldDeclaration) * (structTypeFieldDeclarationCount + 1)); structTypeDeclarations = realloc(structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1));
structTypeFieldDeclarations[index].structType = wStructType; structTypeDeclarations[index].structType = wStructType;
structTypeFieldDeclarations[index].fields = NULL; structTypeDeclarations[index].structPointerType = wStructPointerType;
structTypeFieldDeclarations[index].fieldCount = 0; structTypeDeclarations[index].name = strdup(name);
structTypeFieldDeclarations[index].functions = NULL; structTypeDeclarations[index].fields = NULL;
structTypeFieldDeclarations[index].functionCount = 0; structTypeDeclarations[index].fieldCount = 0;
structTypeDeclarations[index].functions = NULL;
structTypeDeclarations[index].functionCount = 0;
for (i = 0; i < fieldDeclarationCount; i += 1) for (i = 0; i < fieldDeclarationCount; i += 1)
{ {
structTypeFieldDeclarations[index].fields = realloc(structTypeFieldDeclarations[index].fields, sizeof(StructTypeField) * (structTypeFieldDeclarations[index].fieldCount + 1)); structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1));
structTypeFieldDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string); structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string);
structTypeFieldDeclarations[index].fields[i].index = i; structTypeDeclarations[index].fields[i].index = i;
structTypeFieldDeclarations[index].fieldCount += 1; structTypeDeclarations[index].fieldCount += 1;
} }
structTypeFieldDeclarationCount += 1; structTypeDeclarationCount += 1;
} }
static void DeclareStructFunction( static void DeclareStructFunction(
LLVMTypeRef wStructType, LLVMTypeRef wStructPointerType,
LLVMValueRef function, LLVMValueRef function,
LLVMTypeRef returnType, LLVMTypeRef returnType,
uint8_t isStatic, uint8_t isStatic,
@ -231,71 +252,124 @@ static void DeclareStructFunction(
) { ) {
uint32_t i, index; uint32_t i, index;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
if (structTypeFieldDeclarations[i].structType == wStructType) if (structTypeDeclarations[i].structPointerType == wStructPointerType)
{ {
index = structTypeFieldDeclarations[i].functionCount; index = structTypeDeclarations[i].functionCount;
structTypeFieldDeclarations[i].functions = realloc(structTypeFieldDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeFieldDeclarations[i].functionCount + 1)); structTypeDeclarations[i].functions = realloc(structTypeDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeDeclarations[i].functionCount + 1));
structTypeFieldDeclarations[i].functions[index].name = strdup(name); structTypeDeclarations[i].functions[index].name = strdup(name);
structTypeFieldDeclarations[i].functions[index].function = function; structTypeDeclarations[i].functions[index].function = function;
structTypeFieldDeclarations[i].functions[index].returnType = returnType; structTypeDeclarations[i].functions[index].returnType = returnType;
structTypeFieldDeclarations[i].functions[index].isStatic = isStatic; structTypeDeclarations[i].functions[index].isStatic = isStatic;
structTypeFieldDeclarations[i].functionCount += 1; structTypeDeclarations[i].functionCount += 1;
return; return;
} }
} }
fprintf(stderr, "Could not find struct type for function!"); fprintf(stderr, "Could not find struct type for function!\n");
} }
static LLVMValueRef LookupFunction( static LLVMTypeRef LookupCustomType(char *name)
LLVMValueRef structPointer, {
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
fprintf(stderr, "Could not find struct type!\n");
return NULL;
}
static LLVMValueRef LookupFunctionByType(
LLVMTypeRef structType,
char *name, char *name,
LLVMTypeRef *pReturnType, LLVMTypeRef *pReturnType,
uint8_t *pStatic uint8_t *pStatic
) { ) {
uint32_t i, j; uint32_t i, j;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer)) if (structTypeDeclarations[i].structType == structType)
{ {
for (j = 0; j < structTypeFieldDeclarations[i].functionCount; j += 1) for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1)
{ {
if (strcmp(structTypeFieldDeclarations[i].functions[j].name, name) == 0) if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0)
{ {
*pReturnType = structTypeFieldDeclarations[i].functions[j].returnType; *pReturnType = structTypeDeclarations[i].functions[j].returnType;
*pStatic = structTypeFieldDeclarations[i].functions[j].isStatic; *pStatic = structTypeDeclarations[i].functions[j].isStatic;
return structTypeFieldDeclarations[i].functions[j].function; return structTypeDeclarations[i].functions[j].function;
} }
} }
} }
} }
fprintf(stderr, "Could not find struct function!"); fprintf(stderr, "Could not find struct function!\n");
return NULL; return NULL;
} }
static LLVMValueRef LookupFunctionByPointerType(
LLVMTypeRef structPointerType,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
uint32_t i, j;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType == structPointerType)
{
for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1)
{
if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0)
{
*pReturnType = structTypeDeclarations[i].functions[j].returnType;
*pStatic = structTypeDeclarations[i].functions[j].isStatic;
return structTypeDeclarations[i].functions[j].function;
}
}
}
}
fprintf(stderr, "Could not find struct function!\n");
return NULL;
}
static LLVMValueRef LookupFunctionByInstance(
LLVMValueRef structPointer,
char *name,
LLVMTypeRef *pReturnType,
uint8_t *pStatic
) {
return LookupFunctionByPointerType(LLVMTypeOf(structPointer), name, pReturnType, pStatic);
}
static void AddStructVariablesToScope( static void AddStructVariablesToScope(
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef structPointer LLVMValueRef structPointer
) { ) {
uint32_t i, j; uint32_t i, j;
for (i = 0; i < structTypeFieldDeclarationCount; i += 1) for (i = 0; i < structTypeDeclarationCount; i += 1)
{ {
if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer)) if (structTypeDeclarations[i].structPointerType == LLVMTypeOf(structPointer))
{ {
for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1) for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{ {
char *ptrName = strdup(structTypeFieldDeclarations[i].fields[j].name); char *ptrName = strdup(structTypeDeclarations[i].fields[j].name);
strcat(ptrName, "_ptr"); strcat(ptrName, "_ptr");
LLVMValueRef elementPointer = LLVMBuildStructGEP( LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder, builder,
structPointer, structPointer,
structTypeFieldDeclarations[i].fields[j].index, structTypeDeclarations[i].fields[j].index,
ptrName ptrName
); );
free(ptrName); free(ptrName);
@ -303,7 +377,7 @@ static void AddStructVariablesToScope(
AddLocalVariable( AddLocalVariable(
scope, scope,
elementPointer, elementPointer,
structTypeFieldDeclarations[i].fields[j].name structTypeDeclarations[i].fields[j].name
); );
} }
} }
@ -315,38 +389,6 @@ static LLVMValueRef CompileExpression(
Node *binaryExpression Node *binaryExpression
); );
typedef struct CustomTypeMap
{
LLVMTypeRef type;
char *name;
} CustomTypeMap;
CustomTypeMap *customTypes;
uint32_t customTypeCount;
static void RegisterCustomType(LLVMTypeRef type, char *name)
{
customTypes = realloc(customTypes, sizeof(CustomType) * (customTypeCount + 1));
customTypes[customTypeCount].type = type;
customTypes[customTypeCount].name = strdup(name);
customTypeCount += 1;
}
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
for (i = 0; i < customTypeCount; i += 1)
{
if (strcmp(customTypes[i].name, name) == 0)
{
return customTypes[i].type;
}
}
return NULL;
}
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{ {
switch (type) switch (type)
@ -404,7 +446,7 @@ static LLVMValueRef CompileFunctionCallExpression(
) { ) {
uint32_t i; uint32_t i;
uint32_t argumentCount = 0; uint32_t argumentCount = 0;
LLVMValueRef args[argumentCount]; LLVMValueRef args[expression->children[1]->childCount + 1];
LLVMValueRef function; LLVMValueRef function;
uint8_t isStatic; uint8_t isStatic;
LLVMValueRef structInstance; LLVMValueRef structInstance;
@ -414,8 +456,24 @@ static LLVMValueRef CompileFunctionCallExpression(
/* FIXME: this needs to be recursive on access chains */ /* FIXME: this needs to be recursive on access chains */
if (expression->children[0]->syntaxKind == AccessExpression) if (expression->children[0]->syntaxKind == AccessExpression)
{ {
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); LLVMTypeRef typeReference = FindStructType(
function = LookupFunction(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); expression->children[0]->children[0]->value.string
);
if (typeReference != NULL)
{
function = LookupFunctionByType(
typeReference,
expression->children[0]->children[1]->value.string,
&functionReturnType,
&isStatic
);
}
else
{
structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string);
function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic);
}
} }
else else
{ {
@ -641,7 +699,7 @@ static void CompileFunction(
{ {
char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string); char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string);
strcat(ptrName, "_ptr"); strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + 1); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy); LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName); free(ptrName);
@ -680,8 +738,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
PushScopeFrame(scope); PushScopeFrame(scope);
LLVMTypeRef wStruct = LLVMStructCreateNamed(context, structName); LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName);
LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */ LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); /* FIXME: is this address space correct? */
/* first, build the structure definition */ /* first, build the structure definition */
for (i = 0; i < declarationCount; i += 1) for (i = 0; i < declarationCount; i += 1)
@ -698,9 +756,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no
} }
} }
LLVMStructSetBody(wStruct, types, fieldCount, packed); LLVMStructSetBody(wStructType, types, fieldCount, packed);
AddStructDeclaration(wStructPointerType, fieldDeclarations, fieldCount); AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount);
RegisterCustomType(wStruct, node->children[0]->value.string);
/* now we can wire up the functions */ /* now we can wire up the functions */
for (i = 0; i < declarationCount; i += 1) for (i = 0; i < declarationCount; i += 1)
@ -748,11 +805,8 @@ int main(int argc, char *argv[])
scope = CreateScope(); scope = CreateScope();
structTypeFieldDeclarations = NULL; structTypeDeclarations = NULL;
structTypeFieldDeclarationCount = 0; structTypeDeclarationCount = 0;
customTypes = NULL;
customTypeCount = 0;
stack = CreateStack(); stack = CreateStack();