optimize struct field reads

generics
cosmonaut 2021-04-20 19:40:39 -07:00
parent 180583d772
commit 48d049b6c9
1 changed files with 55 additions and 27 deletions

View File

@ -19,7 +19,9 @@ typedef struct StructFieldMapValue
char *name; char *name;
LLVMValueRef value; LLVMValueRef value;
LLVMValueRef valuePointer; LLVMValueRef valuePointer;
uint32_t index;
uint8_t needsWrite; uint8_t needsWrite;
uint8_t needsRead;
} StructFieldMapValue; } StructFieldMapValue;
typedef struct StructFieldMap typedef struct StructFieldMap
@ -41,7 +43,7 @@ static void AddStruct(LLVMValueRef wStructPointer)
structFieldMapCount += 1; structFieldMapCount += 1;
} }
static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index) static void AddStructFieldName(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name, uint32_t index)
{ {
uint32_t i, fieldCount; uint32_t i, fieldCount;
@ -51,18 +53,13 @@ static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer,
{ {
fieldCount = structFieldMaps[i].fieldCount; fieldCount = structFieldMaps[i].fieldCount;
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
wStructPointer,
fieldCount,
"ptr"
);
structFieldMaps[i].fields = realloc(structFieldMaps[i].fields, sizeof(StructFieldMapValue) * (fieldCount + 1)); structFieldMaps[i].fields = realloc(structFieldMaps[i].fields, sizeof(StructFieldMapValue) * (fieldCount + 1));
structFieldMaps[i].fields[fieldCount].name = strdup(name); structFieldMaps[i].fields[fieldCount].name = strdup(name);
structFieldMaps[i].fields[fieldCount].value = LLVMBuildLoad(builder, elementPointer, name); structFieldMaps[i].fields[fieldCount].value = NULL;
structFieldMaps[i].fields[fieldCount].valuePointer = elementPointer; structFieldMaps[i].fields[fieldCount].valuePointer = NULL;
structFieldMaps[i].fields[fieldCount].index = index;
structFieldMaps[i].fields[fieldCount].needsWrite = 0; structFieldMaps[i].fields[fieldCount].needsWrite = 0;
structFieldMaps[i].fields[fieldCount].needsRead = 1;
structFieldMaps[i].fieldCount += 1; structFieldMaps[i].fieldCount += 1;
break; break;
@ -70,6 +67,44 @@ static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer,
} }
} }
static LLVMValueRef CheckStructFieldAndLoad(LLVMBuilderRef builder, LLVMValueRef wStructPointer, char *name)
{
uint32_t i, j;
for (i = 0; i < structFieldMapCount; i += 1)
{
if (structFieldMaps[i].structPointer == wStructPointer)
{
for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
{
if (strcmp(structFieldMaps[i].fields[j].name, name) == 0)
{
if (structFieldMaps[i].fields[j].needsRead)
{
char *ptrName = strdup(name);
strcat(ptrName, "_ptr");
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
wStructPointer,
structFieldMaps[i].fields[j].index,
ptrName
);
free(ptrName);
structFieldMaps[i].fields[j].value = LLVMBuildLoad(builder, elementPointer, name);
structFieldMaps[i].fields[j].valuePointer = elementPointer;
structFieldMaps[i].fields[j].needsRead = 0;
}
return structFieldMaps[i].fields[j].value;
}
}
}
}
return NULL;
}
static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value) static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value)
{ {
uint32_t i, j; uint32_t i, j;
@ -168,9 +203,10 @@ static void AddNamedVariable(char *name, LLVMValueRef variable)
namedVariableCount += 1; namedVariableCount += 1;
} }
static LLVMValueRef FindVariableByName(LLVMValueRef wStructValue, LLVMBuilderRef builder, char *name) static LLVMValueRef FindVariableByName(LLVMBuilderRef builder, LLVMValueRef wStructValue, char *name)
{ {
uint32_t i, j; uint32_t i, j;
LLVMValueRef searchResult;
/* first, search scoped vars */ /* first, search scoped vars */
for (i = 0; i < namedVariableCount; i += 1) for (i = 0; i < namedVariableCount; i += 1)
@ -182,22 +218,14 @@ static LLVMValueRef FindVariableByName(LLVMValueRef wStructValue, LLVMBuilderRef
} }
/* if none exist, search struct vars */ /* if none exist, search struct vars */
for (i = 0; i < structFieldMapCount; i += 1) searchResult = CheckStructFieldAndLoad(builder, wStructValue, name);
if (searchResult == NULL)
{ {
if (structFieldMaps[i].structPointer == wStructValue) fprintf(stderr, "Identifier not found!");
{
for (j = 0; j < structFieldMaps[i].fieldCount; j += 1)
{
if (strcmp(structFieldMaps[i].fields[j].name, name) == 0)
{
return structFieldMaps[i].fields[j].value;
}
}
}
} }
fprintf(stderr, "Identifier not found!"); return searchResult;
return NULL;
} }
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
@ -267,7 +295,7 @@ static LLVMValueRef CompileFunctionCallExpression(
args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]); args[i] = CompileExpression(wStructValue, builder, function, expression->children[1]->children[i]);
} }
return LLVMBuildCall(builder, FindVariableByName(wStructValue, builder, expression->children[0]->value.string), args, argumentCount, "tmp"); return LLVMBuildCall(builder, FindVariableByName(builder, wStructValue, expression->children[0]->value.string), args, argumentCount, "tmp");
} }
static LLVMValueRef CompileExpression( static LLVMValueRef CompileExpression(
@ -287,7 +315,7 @@ static LLVMValueRef CompileExpression(
return CompileFunctionCallExpression(wStructValue, builder, function, expression); return CompileFunctionCallExpression(wStructValue, builder, function, expression);
case Identifier: case Identifier:
return FindVariableByName(wStructValue, builder, expression->value.string); return FindVariableByName(builder, wStructValue, expression->value.string);
case Number: case Number:
return CompileNumber(expression); return CompileNumber(expression);
@ -380,7 +408,7 @@ static void CompileFunction(
for (i = 0; i < fieldDeclarationCount; i += 1) for (i = 0; i < fieldDeclarationCount; i += 1)
{ {
AddStructField(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i); AddStructFieldName(builder, wStructPointer, fieldDeclarations[i]->children[1]->value.string, i);
} }
for (i = 0; i < functionBody->childCount; i += 1) for (i = 0; i < functionBody->childCount; i += 1)