diff --git a/compiler.c b/compiler.c index 56618b3..b15ceae 100644 --- a/compiler.c +++ b/compiler.c @@ -70,6 +70,26 @@ static void AddStructField(LLVMBuilderRef builder, LLVMValueRef wStructPointer, } } +static void MarkStructFieldForWrite(LLVMValueRef wStructPointer, LLVMValueRef value) +{ + uint32_t i, j; + + for (i = 0; i < structFieldMapCount; i += 1) + { + if (structFieldMaps[i].structPointer == wStructPointer) + { + for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) + { + if (structFieldMaps[i].fields[j].value == value) + { + structFieldMaps[i].fields[j].needsWrite = 1; + break; + } + } + } + } +} + static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValueRef value) { uint32_t i, j; @@ -91,14 +111,26 @@ static LLVMValueRef GetStructFieldPointer(LLVMValueRef wStructPointer, LLVMValue return NULL; } -static void RemoveStruct(LLVMValueRef wStructPointer) +static void RemoveStruct(LLVMBuilderRef builder, LLVMValueRef wStructPointer) { - uint32_t i; + uint32_t i, j; for (i = 0; i < structFieldMapCount; i += 1) { if (structFieldMaps[i].structPointer == wStructPointer) { + for (j = 0; j < structFieldMaps[i].fieldCount; j += 1) + { + if (structFieldMaps[i].fields[j].needsWrite) + { + LLVMBuildStore( + builder, + structFieldMaps[i].fields[j].value, + structFieldMaps[i].fields[j].valuePointer + ); + } + } + free(structFieldMaps[i].fields); structFieldMaps[i].fields = NULL; structFieldMaps[i].fieldCount = 0; @@ -281,11 +313,7 @@ static void CompileAssignment(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef result = CompileExpression(wStructValue, builder, function, assignmentStatement->children[1]); LLVMValueRef identifier = CompileExpression(wStructValue, builder, function, assignmentStatement->children[0]); - fieldPointer = GetStructFieldPointer(wStructValue, identifier); - if (fieldPointer != NULL) - { - LLVMBuildStore(builder, result, fieldPointer); - } + MarkStructFieldForWrite(wStructValue, identifier); } static uint8_t CompileStatement(LLVMValueRef wStructValue, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -360,7 +388,7 @@ static void CompileFunction( hasReturn |= CompileStatement(wStructPointer, builder, function, functionBody->children[i]); } - AddNamedVariable(functionSignature->children[0]->value.string, function); + RemoveStruct(builder, wStructPointer); if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) { @@ -370,8 +398,6 @@ static void CompileFunction( { fprintf(stderr, "Return statement not provided!"); } - - RemoveStruct(wStructPointer); } static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *node)