diff --git a/compiler.c b/compiler.c index 7531728..9249fa6 100644 --- a/compiler.c +++ b/compiler.c @@ -61,18 +61,20 @@ typedef struct StructTypeFunction uint8_t isStatic; } StructTypeFunction; -typedef struct StructTypeFieldDeclaration +typedef struct StructTypeDeclaration { + char *name; LLVMTypeRef structType; + LLVMTypeRef structPointerType; StructTypeField *fields; uint32_t fieldCount; StructTypeFunction *functions; uint32_t functionCount; -} StructTypeFieldDeclaration; +} StructTypeDeclaration; -StructTypeFieldDeclaration *structTypeFieldDeclarations; -uint32_t structTypeFieldDeclarationCount; +StructTypeDeclaration *structTypeDeclarations; +uint32_t structTypeDeclarationCount; static Scope* CreateScope() { @@ -127,26 +129,41 @@ static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name) scopeFrame->localVariableCount += 1; } +static LLVMTypeRef FindStructType(char *name) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (strcmp(structTypeDeclarations[i].name, name) == 0) + { + return structTypeDeclarations[i].structType; + } + } + + return NULL; +} + static LLVMValueRef FindStructFieldPointer(LLVMBuilderRef builder, LLVMValueRef structPointer, char *name) { int32_t i, j; - LLVMTypeRef structType = LLVMTypeOf(structPointer); + LLVMTypeRef structPointerType = LLVMTypeOf(structPointer); - for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + for (i = 0; i < structTypeDeclarationCount; i += 1) { - if (structTypeFieldDeclarations[i].structType == structType) + if (structTypeDeclarations[i].structPointerType == structPointerType) { - for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1) + for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) { - if (strcmp(structTypeFieldDeclarations[i].fields[j].name, name) == 0) + if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0) { char *ptrName = strdup(name); strcat(ptrName, "_ptr"); return LLVMBuildStructGEP( builder, structPointer, - structTypeFieldDeclarations[i].fields[j].index, + structTypeDeclarations[i].fields[j].index, ptrName ); free(ptrName); @@ -199,31 +216,35 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name) static void AddStructDeclaration( LLVMTypeRef wStructType, + LLVMTypeRef wStructPointerType, + char *name, Node **fieldDeclarations, uint32_t fieldDeclarationCount ) { uint32_t i; - uint32_t index = structTypeFieldDeclarationCount; - structTypeFieldDeclarations = realloc(structTypeFieldDeclarations, sizeof(StructTypeFieldDeclaration) * (structTypeFieldDeclarationCount + 1)); - structTypeFieldDeclarations[index].structType = wStructType; - structTypeFieldDeclarations[index].fields = NULL; - structTypeFieldDeclarations[index].fieldCount = 0; - structTypeFieldDeclarations[index].functions = NULL; - structTypeFieldDeclarations[index].functionCount = 0; + uint32_t index = structTypeDeclarationCount; + structTypeDeclarations = realloc(structTypeDeclarations, sizeof(StructTypeDeclaration) * (structTypeDeclarationCount + 1)); + structTypeDeclarations[index].structType = wStructType; + structTypeDeclarations[index].structPointerType = wStructPointerType; + structTypeDeclarations[index].name = strdup(name); + structTypeDeclarations[index].fields = NULL; + structTypeDeclarations[index].fieldCount = 0; + structTypeDeclarations[index].functions = NULL; + structTypeDeclarations[index].functionCount = 0; for (i = 0; i < fieldDeclarationCount; i += 1) { - structTypeFieldDeclarations[index].fields = realloc(structTypeFieldDeclarations[index].fields, sizeof(StructTypeField) * (structTypeFieldDeclarations[index].fieldCount + 1)); - structTypeFieldDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string); - structTypeFieldDeclarations[index].fields[i].index = i; - structTypeFieldDeclarations[index].fieldCount += 1; + structTypeDeclarations[index].fields = realloc(structTypeDeclarations[index].fields, sizeof(StructTypeField) * (structTypeDeclarations[index].fieldCount + 1)); + structTypeDeclarations[index].fields[i].name = strdup(fieldDeclarations[i]->children[1]->value.string); + structTypeDeclarations[index].fields[i].index = i; + structTypeDeclarations[index].fieldCount += 1; } - structTypeFieldDeclarationCount += 1; + structTypeDeclarationCount += 1; } static void DeclareStructFunction( - LLVMTypeRef wStructType, + LLVMTypeRef wStructPointerType, LLVMValueRef function, LLVMTypeRef returnType, uint8_t isStatic, @@ -231,71 +252,124 @@ static void DeclareStructFunction( ) { uint32_t i, index; - for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + for (i = 0; i < structTypeDeclarationCount; i += 1) { - if (structTypeFieldDeclarations[i].structType == wStructType) + if (structTypeDeclarations[i].structPointerType == wStructPointerType) { - index = structTypeFieldDeclarations[i].functionCount; - structTypeFieldDeclarations[i].functions = realloc(structTypeFieldDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeFieldDeclarations[i].functionCount + 1)); - structTypeFieldDeclarations[i].functions[index].name = strdup(name); - structTypeFieldDeclarations[i].functions[index].function = function; - structTypeFieldDeclarations[i].functions[index].returnType = returnType; - structTypeFieldDeclarations[i].functions[index].isStatic = isStatic; - structTypeFieldDeclarations[i].functionCount += 1; + index = structTypeDeclarations[i].functionCount; + structTypeDeclarations[i].functions = realloc(structTypeDeclarations[i].functions, sizeof(StructTypeFunction) * (structTypeDeclarations[i].functionCount + 1)); + structTypeDeclarations[i].functions[index].name = strdup(name); + structTypeDeclarations[i].functions[index].function = function; + structTypeDeclarations[i].functions[index].returnType = returnType; + structTypeDeclarations[i].functions[index].isStatic = isStatic; + structTypeDeclarations[i].functionCount += 1; return; } } - fprintf(stderr, "Could not find struct type for function!"); + fprintf(stderr, "Could not find struct type for function!\n"); } -static LLVMValueRef LookupFunction( - LLVMValueRef structPointer, +static LLVMTypeRef LookupCustomType(char *name) +{ + uint32_t i; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (strcmp(structTypeDeclarations[i].name, name) == 0) + { + return structTypeDeclarations[i].structType; + } + } + + fprintf(stderr, "Could not find struct type!\n"); + return NULL; +} + +static LLVMValueRef LookupFunctionByType( + LLVMTypeRef structType, char *name, LLVMTypeRef *pReturnType, uint8_t *pStatic ) { uint32_t i, j; - for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + for (i = 0; i < structTypeDeclarationCount; i += 1) { - if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer)) + if (structTypeDeclarations[i].structType == structType) { - for (j = 0; j < structTypeFieldDeclarations[i].functionCount; j += 1) + for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1) { - if (strcmp(structTypeFieldDeclarations[i].functions[j].name, name) == 0) + if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0) { - *pReturnType = structTypeFieldDeclarations[i].functions[j].returnType; - *pStatic = structTypeFieldDeclarations[i].functions[j].isStatic; - return structTypeFieldDeclarations[i].functions[j].function; + *pReturnType = structTypeDeclarations[i].functions[j].returnType; + *pStatic = structTypeDeclarations[i].functions[j].isStatic; + return structTypeDeclarations[i].functions[j].function; } } } } - fprintf(stderr, "Could not find struct function!"); + fprintf(stderr, "Could not find struct function!\n"); return NULL; } +static LLVMValueRef LookupFunctionByPointerType( + LLVMTypeRef structPointerType, + char *name, + LLVMTypeRef *pReturnType, + uint8_t *pStatic +) { + uint32_t i, j; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == structPointerType) + { + for (j = 0; j < structTypeDeclarations[i].functionCount; j += 1) + { + if (strcmp(structTypeDeclarations[i].functions[j].name, name) == 0) + { + *pReturnType = structTypeDeclarations[i].functions[j].returnType; + *pStatic = structTypeDeclarations[i].functions[j].isStatic; + return structTypeDeclarations[i].functions[j].function; + } + } + } + } + + fprintf(stderr, "Could not find struct function!\n"); + return NULL; +} + +static LLVMValueRef LookupFunctionByInstance( + LLVMValueRef structPointer, + char *name, + LLVMTypeRef *pReturnType, + uint8_t *pStatic +) { + return LookupFunctionByPointerType(LLVMTypeOf(structPointer), name, pReturnType, pStatic); +} + static void AddStructVariablesToScope( LLVMBuilderRef builder, LLVMValueRef structPointer ) { uint32_t i, j; - for (i = 0; i < structTypeFieldDeclarationCount; i += 1) + for (i = 0; i < structTypeDeclarationCount; i += 1) { - if (structTypeFieldDeclarations[i].structType == LLVMTypeOf(structPointer)) + if (structTypeDeclarations[i].structPointerType == LLVMTypeOf(structPointer)) { - for (j = 0; j < structTypeFieldDeclarations[i].fieldCount; j += 1) + for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) { - char *ptrName = strdup(structTypeFieldDeclarations[i].fields[j].name); + char *ptrName = strdup(structTypeDeclarations[i].fields[j].name); strcat(ptrName, "_ptr"); LLVMValueRef elementPointer = LLVMBuildStructGEP( builder, structPointer, - structTypeFieldDeclarations[i].fields[j].index, + structTypeDeclarations[i].fields[j].index, ptrName ); free(ptrName); @@ -303,7 +377,7 @@ static void AddStructVariablesToScope( AddLocalVariable( scope, elementPointer, - structTypeFieldDeclarations[i].fields[j].name + structTypeDeclarations[i].fields[j].name ); } } @@ -315,38 +389,6 @@ static LLVMValueRef CompileExpression( Node *binaryExpression ); -typedef struct CustomTypeMap -{ - LLVMTypeRef type; - char *name; -} CustomTypeMap; - -CustomTypeMap *customTypes; -uint32_t customTypeCount; - -static void RegisterCustomType(LLVMTypeRef type, char *name) -{ - customTypes = realloc(customTypes, sizeof(CustomType) * (customTypeCount + 1)); - customTypes[customTypeCount].type = type; - customTypes[customTypeCount].name = strdup(name); - customTypeCount += 1; -} - -static LLVMTypeRef LookupCustomType(char *name) -{ - uint32_t i; - - for (i = 0; i < customTypeCount; i += 1) - { - if (strcmp(customTypes[i].name, name) == 0) - { - return customTypes[i].type; - } - } - - return NULL; -} - static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -404,7 +446,7 @@ static LLVMValueRef CompileFunctionCallExpression( ) { uint32_t i; uint32_t argumentCount = 0; - LLVMValueRef args[argumentCount]; + LLVMValueRef args[expression->children[1]->childCount + 1]; LLVMValueRef function; uint8_t isStatic; LLVMValueRef structInstance; @@ -414,8 +456,24 @@ static LLVMValueRef CompileFunctionCallExpression( /* FIXME: this needs to be recursive on access chains */ if (expression->children[0]->syntaxKind == AccessExpression) { - structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); - function = LookupFunction(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); + LLVMTypeRef typeReference = FindStructType( + expression->children[0]->children[0]->value.string + ); + + if (typeReference != NULL) + { + function = LookupFunctionByType( + typeReference, + expression->children[0]->children[1]->value.string, + &functionReturnType, + &isStatic + ); + } + else + { + structInstance = FindVariablePointer(expression->children[0]->children[0]->value.string); + function = LookupFunctionByInstance(structInstance, expression->children[0]->children[1]->value.string, &functionReturnType, &isStatic); + } } else { @@ -641,7 +699,7 @@ static void CompileFunction( { char *ptrName = strdup(functionSignature->children[2]->children[i]->children[1]->value.string); strcat(ptrName, "_ptr"); - LLVMValueRef argument = LLVMGetParam(function, i + 1); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMBuildStore(builder, argument, argumentCopy); free(ptrName); @@ -680,8 +738,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no PushScopeFrame(scope); - LLVMTypeRef wStruct = LLVMStructCreateNamed(context, structName); - LLVMTypeRef wStructPointerType = LLVMPointerType(wStruct, 0); /* FIXME: is this address space correct? */ + LLVMTypeRef wStructType = LLVMStructCreateNamed(context, structName); + LLVMTypeRef wStructPointerType = LLVMPointerType(wStructType, 0); /* FIXME: is this address space correct? */ /* first, build the structure definition */ for (i = 0; i < declarationCount; i += 1) @@ -698,9 +756,8 @@ static void CompileStruct(LLVMModuleRef module, LLVMContextRef context, Node *no } } - LLVMStructSetBody(wStruct, types, fieldCount, packed); - AddStructDeclaration(wStructPointerType, fieldDeclarations, fieldCount); - RegisterCustomType(wStruct, node->children[0]->value.string); + LLVMStructSetBody(wStructType, types, fieldCount, packed); + AddStructDeclaration(wStructType, wStructPointerType, node->children[0]->value.string, fieldDeclarations, fieldCount); /* now we can wire up the functions */ for (i = 0; i < declarationCount; i += 1) @@ -748,11 +805,8 @@ int main(int argc, char *argv[]) scope = CreateScope(); - structTypeFieldDeclarations = NULL; - structTypeFieldDeclarationCount = 0; - - customTypes = NULL; - customTypeCount = 0; + structTypeDeclarations = NULL; + structTypeDeclarationCount = 0; stack = CreateStack();