From 12ac9cc9808de6ac0e65d2fff835fe38e8c829f4 Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Thu, 3 Jun 2021 17:48:16 -0700 Subject: [PATCH] allow function self reference --- generic.w | 19 +++++-- src/ast.c | 5 +- src/codegen.c | 151 ++++++++++++++++++++++++++++++++++++++------------ src/util.c | 11 +++- src/util.h | 1 + 5 files changed, 146 insertions(+), 41 deletions(-) diff --git a/generic.w b/generic.w index 51fb64f..c127e17 100644 --- a/generic.w +++ b/generic.w @@ -14,9 +14,19 @@ struct MemoryBlock start: MemoryAddress; capacity: uint; - AddressOf(count: uint): MemoryAddress + AddressOf(index: uint): MemoryAddress { - return start + (count * @sizeof()); + return start + (index * @sizeof()); + } + + Get(index: uint): T + { + return @bitcast(AddressOf(index)); + } + + Free(): void + { + @free(start); } } @@ -24,13 +34,14 @@ struct Program { static Main(): int { x: int = 4; y: int = Foo.Func(x); - block: MemoryBlock; + block: MemoryBlock; block.capacity = y; block.start = @malloc(y * @sizeof()); z: MemoryAddress = block.AddressOf(2); + Console.PrintLine("%u", block.Get(0)); Console.PrintLine("%p", block.start); Console.PrintLine("%p", z); - @free(block.start); + block.Free(); return 0; } } diff --git a/src/ast.c b/src/ast.c index d761850..313a51d 100644 --- a/src/ast.c +++ b/src/ast.c @@ -1169,7 +1169,8 @@ char *TypeTagToString(TypeTag *tag) { char *result = strdup(tag->value.concreteGenericType.name); uint32_t len = strlen(result); - result = realloc(result, len + 2); + len += 2; + result = realloc(result, sizeof(char) * len); strcat(result, "<"); for (i = 0; i < tag->value.concreteGenericType.genericArgumentCount; @@ -1185,7 +1186,7 @@ char *TypeTagToString(TypeTag *tag) } strcat(result, inner); } - result = realloc(result, sizeof(char) * (len + 2)); + result = realloc(result, sizeof(char) * (len + 1)); strcat(result, ">"); return result; } diff --git a/src/codegen.c b/src/codegen.c index fad7286..991098f 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -152,12 +152,14 @@ uint32_t systemFunctionCount; /* FUNCTION FORWARD DECLARATIONS */ static LLVMBasicBlockRef CompileStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression); @@ -282,7 +284,7 @@ static void AddStructVariablesToScope( for (i = 0; i < structTypeDeclaration->fieldCount; i += 1) { char *ptrName = strdup(structTypeDeclaration->fields[i].name); - strcat(ptrName, "_ptr"); /* FIXME: needs to be realloc'd */ + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef elementPointer = LLVMBuildStructGEP( builder, structPointer, @@ -441,7 +443,6 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( uint32_t genericArgumentTypeCount) { uint32_t i = 0; - uint32_t nameLen; uint32_t fieldCount = 0; Node *structDeclarationNode = genericStructTypeDeclaration->structDeclarationNode; @@ -464,15 +465,12 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( char *structName = strdup( structDeclarationNode->structDeclaration.identifier->identifier.name); - nameLen = strlen(structName); for (i = 0; i < genericArgumentTypeCount; i += 1) { char *inner = TypeTagToString(genericArgumentTypes[i]); - nameLen += 2 + strlen(inner); - structName = realloc(structName, sizeof(char) * nameLen); - strcat(structName, "_"); - strcat(structName, inner); + structName = w_strcat(structName, "_"); + structName = w_strcat(structName, inner); } LLVMContextRef context = @@ -486,6 +484,8 @@ static StructTypeDeclaration *CompileMonomorphizedGenericStruct( wStructPointerType, structName); + free(structName); + /* first build the structure def */ for (i = 0; i < declarationCount; i += 1) { @@ -712,7 +712,7 @@ static LLVMValueRef FindStructFieldPointer( if (strcmp(structTypeDeclarations[i].fields[j].name, name) == 0) { char *ptrName = strdup(name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); return LLVMBuildStructGEP( builder, structPointer, @@ -847,6 +847,7 @@ static StructTypeFunction CompileGenericFunction( LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; LLVMTypeRef returnType; + LLVMValueRef wStructPointer = NULL; PushScopeFrame(scope); @@ -878,16 +879,17 @@ static StructTypeFunction CompileGenericFunction( } } - /* FIXME: these cats need to be realloc'd */ char *functionName = strdup(structTypeDeclaration->name); - strcat(functionName, "_"); - strcat( + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); for (i = 0; i < genericArgumentTypeCount; i += 1) { - strcat(functionName, TypeTagToString(resolvedGenericArgumentTypes[i])); + functionName = w_strcat( + functionName, + TypeTagToString(resolvedGenericArgumentTypes[i])); } if (!isStatic) @@ -925,7 +927,7 @@ static StructTypeFunction CompileGenericFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + wStructPointer = LLVMGetParam(function, 0); AddStructVariablesToScope( structTypeDeclaration, builder, @@ -939,7 +941,7 @@ static StructTypeFunction CompileGenericFunction( char *ptrName = strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -958,6 +960,7 @@ static StructTypeFunction CompileGenericFunction( { CompileStatement( structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); @@ -1143,14 +1146,12 @@ static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupFunctionByType( LLVMTypeRef structType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1191,14 +1192,12 @@ static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByPointerType( LLVMTypeRef structPointerType, + char *name, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { uint32_t i, j; - /* FIXME: hot garbage */ - char *name = functionCallExpression->functionCallExpression.identifier - ->accessExpression.accessor->identifier.name; for (i = 0; i < structTypeDeclarationCount; i += 1) { @@ -1239,12 +1238,14 @@ static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByInstance( LLVMValueRef structPointer, + char *functionName, Node *functionCallExpression, LLVMTypeRef *pReturnType, uint8_t *pStatic) { return LookupFunctionByPointerType( LLVMTypeOf(structPointer), + functionName, functionCallExpression, pReturnType, pStatic); @@ -1267,16 +1268,19 @@ static LLVMValueRef CompileString( static LLVMValueRef CompileBinaryExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *binaryExpression) { LLVMValueRef left = CompileExpression( structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.left); LLVMValueRef right = CompileExpression( structTypeDeclaration, + selfParam, builder, binaryExpression->binaryExpression.right); @@ -1324,6 +1328,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -1335,7 +1340,7 @@ static LLVMValueRef CompileFunctionCallExpression( 1]; LLVMValueRef function; uint8_t isStatic; - LLVMValueRef structInstance; + LLVMValueRef structInstance = NULL; LLVMTypeRef functionReturnType; char *returnName = ""; @@ -1349,10 +1354,15 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + char *functionName = + functionCallExpression->functionCallExpression.identifier + ->accessExpression.accessor->identifier.name; + if (typeReference != NULL) { function = LookupFunctionByType( typeReference, + functionName, functionCallExpression, &functionReturnType, &isStatic); @@ -1362,16 +1372,38 @@ static LLVMValueRef CompileFunctionCallExpression( structInstance = FindVariablePointer( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); + function = LookupFunctionByInstance( structInstance, + functionName, functionCallExpression, &functionReturnType, &isStatic); } } + else if ( + functionCallExpression->functionCallExpression.identifier->syntaxKind == + Identifier) + { + LLVMTypeRef structType = structTypeDeclaration->structType; + char *functionName = functionCallExpression->functionCallExpression + .identifier->identifier.name; + + function = LookupFunctionByType( + structType, + functionName, + functionCallExpression, + &functionReturnType, + &isStatic); + + if (!isStatic) + { + structInstance = selfParam; + } + } else { - fprintf(stderr, "Failed to find function!\n"); + fprintf(stderr, "Function identifier syntax kind not recognized!\n"); return NULL; } @@ -1387,6 +1419,7 @@ static LLVMValueRef CompileFunctionCallExpression( { args[argumentCount] = CompileExpression( structTypeDeclaration, + selfParam, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1403,6 +1436,7 @@ static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileSystemCallExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *systemCallExpression) { @@ -1418,6 +1452,7 @@ static LLVMValueRef CompileSystemCallExpression( { args[i] = CompileExpression( structTypeDeclaration, + selfParam, builder, systemCallExpression->systemCall.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -1439,6 +1474,29 @@ static LLVMValueRef CompileSystemCallExpression( return LLVMSizeOf(ResolveType(ConcretizeType(typeTag))); } + else if ( + strcmp( + systemCallExpression->systemCall.identifier->identifier.name, + "bitcast") == 0) + { + TypeTag *typeTag = ConcretizeType( + systemCallExpression->systemCall.genericArguments + ->genericArguments.arguments[0] + ->type.typeNode->typeTag); + + LLVMValueRef expression = CompileExpression( + structTypeDeclaration, + selfParam, + builder, + systemCallExpression->systemCall.argumentSequence + ->functionArgumentSequence.sequence[0]); + + return LLVMBuildBitCast( + builder, + expression, + ResolveType(typeTag), + "castResult"); + } else { fprintf(stderr, "System function not found!"); @@ -1498,6 +1556,7 @@ static LLVMValueRef CompileAllocExpression( static LLVMValueRef CompileExpression( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, Node *expression) { @@ -1512,12 +1571,14 @@ static LLVMValueRef CompileExpression( case BinaryExpression: return CompileBinaryExpression( structTypeDeclaration, + selfParam, builder, expression); case FunctionCallExpression: return CompileFunctionCallExpression( structTypeDeclaration, + selfParam, builder, expression); @@ -1533,6 +1594,7 @@ static LLVMValueRef CompileExpression( case SystemCall: return CompileSystemCallExpression( structTypeDeclaration, + selfParam, builder, expression); } @@ -1543,12 +1605,14 @@ static LLVMValueRef CompileExpression( static LLVMBasicBlockRef CompileReturn( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( structTypeDeclaration, + selfParam, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -1573,7 +1637,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( char *variableName = variableDeclaration->declaration.identifier->identifier.name; char *ptrName = strdup(variableName); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); variable = LLVMBuildAlloca( builder, @@ -1589,12 +1653,14 @@ static LLVMValueRef CompileFunctionVariableDeclaration( static LLVMBasicBlockRef CompileAssignment( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( structTypeDeclaration, + selfParam, builder, assignmentStatement->assignmentStatement.right); @@ -1635,6 +1701,7 @@ static LLVMBasicBlockRef CompileAssignment( static LLVMBasicBlockRef CompileIfStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) @@ -1642,6 +1709,7 @@ static LLVMBasicBlockRef CompileIfStatement( uint32_t i; LLVMValueRef conditional = CompileExpression( structTypeDeclaration, + selfParam, builder, ifStatement->ifStatement.expression); @@ -1659,6 +1727,7 @@ static LLVMBasicBlockRef CompileIfStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1673,6 +1742,7 @@ static LLVMBasicBlockRef CompileIfStatement( static LLVMBasicBlockRef CompileIfElseStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) @@ -1680,6 +1750,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( uint32_t i; LLVMValueRef conditional = CompileExpression( structTypeDeclaration, + selfParam, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1697,6 +1768,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1716,6 +1788,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1726,6 +1799,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( { CompileStatement( structTypeDeclaration, + selfParam, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1739,6 +1813,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( static LLVMBasicBlockRef CompileForLoopStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1799,6 +1874,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( { lastBlock = CompileStatement( structTypeDeclaration, + selfParam, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1828,6 +1904,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( static LLVMBasicBlockRef CompileStatement( StructTypeDeclaration *structTypeDeclaration, + LLVMValueRef selfParam, /* can be NULL for statics */ LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1837,6 +1914,7 @@ static LLVMBasicBlockRef CompileStatement( case Assignment: return CompileAssignment( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1848,6 +1926,7 @@ static LLVMBasicBlockRef CompileStatement( case ForLoop: return CompileForLoopStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1855,6 +1934,7 @@ static LLVMBasicBlockRef CompileStatement( case FunctionCallExpression: CompileFunctionCallExpression( structTypeDeclaration, + selfParam, builder, statement); return LLVMGetLastBasicBlock(function); @@ -1862,6 +1942,7 @@ static LLVMBasicBlockRef CompileStatement( case IfStatement: return CompileIfStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1869,6 +1950,7 @@ static LLVMBasicBlockRef CompileStatement( case IfElseStatement: return CompileIfElseStatement( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1876,6 +1958,7 @@ static LLVMBasicBlockRef CompileStatement( case Return: return CompileReturn( structTypeDeclaration, + selfParam, builder, function, statement); @@ -1884,7 +1967,11 @@ static LLVMBasicBlockRef CompileStatement( return CompileReturnVoid(builder, function); case SystemCall: - CompileSystemCallExpression(structTypeDeclaration, builder, statement); + CompileSystemCallExpression( + structTypeDeclaration, + selfParam, + builder, + statement); return LLVMGetLastBasicBlock(function); } @@ -1906,6 +1993,7 @@ static void CompileFunction( ->functionSignatureArguments.count; LLVMTypeRef paramTypes[argumentCount + 1]; uint32_t paramIndex = 0; + LLVMValueRef wStructPointer = NULL; if (functionSignature->functionSignature.modifiers->functionModifiers .count > 0) @@ -1925,14 +2013,8 @@ static void CompileFunction( } char *functionName = strdup(structTypeDeclaration->name); - uint32_t nameLen = strlen(functionName); - nameLen += - 2 + - strlen( - functionSignature->functionSignature.identifier->identifier.name); - functionName = realloc(functionName, sizeof(char) * nameLen); - strcat(functionName, "_"); - strcat( + functionName = w_strcat(functionName, "_"); + functionName = w_strcat( functionName, functionSignature->functionSignature.identifier->identifier.name); @@ -1981,7 +2063,7 @@ static void CompileFunction( if (!isStatic) { - LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + wStructPointer = LLVMGetParam(function, 0); AddStructVariablesToScope( structTypeDeclaration, builder, @@ -1996,7 +2078,7 @@ static void CompileFunction( strdup(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] ->declaration.identifier->identifier.name); - strcat(ptrName, "_ptr"); + ptrName = w_strcat(ptrName, "_ptr"); LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); @@ -2015,6 +2097,7 @@ static void CompileFunction( { CompileStatement( structTypeDeclaration, + wStructPointer, builder, function, functionBody->statementSequence.sequence[i]); diff --git a/src/util.c b/src/util.c index 42911e7..491bb81 100644 --- a/src/util.c +++ b/src/util.c @@ -5,7 +5,7 @@ char *strdup(const char *s) { size_t slen = strlen(s); - char *result = (char *)malloc(slen + 1); + char *result = (char *)malloc(sizeof(char) * (slen + 1)); if (result == NULL) { return NULL; @@ -15,6 +15,15 @@ char *strdup(const char *s) return result; } +char *w_strcat(char *s, char *s2) +{ + size_t slen = strlen(s); + size_t slen2 = strlen(s2); + s = realloc(s, sizeof(char) * (slen + slen2 + 1)); + strcat(s, s2); + return s; +} + uint64_t str_hash(char *str) { uint64_t hash = 5381; diff --git a/src/util.h b/src/util.h index 108211b..94a64e2 100644 --- a/src/util.h +++ b/src/util.h @@ -5,6 +5,7 @@ #include char *strdup(const char *s); +char *w_strcat(char *s, char *s2); uint64_t str_hash(char *str); #endif /* WRAITH_UTIL_H */