From d641f713de7a9980d67b99e5334360c0f25967cb Mon Sep 17 00:00:00 2001 From: cosmonaut Date: Fri, 21 May 2021 19:52:13 -0700 Subject: [PATCH] progress on generics --- src/ast.c | 4 + src/codegen.c | 526 +++++++++++++++++++++++++++++++++++------------ src/identcheck.c | 28 ++- src/identcheck.h | 4 +- src/util.c | 2 +- 5 files changed, 429 insertions(+), 135 deletions(-) diff --git a/src/ast.c b/src/ast.c index 4c578e9..dde4693 100644 --- a/src/ast.c +++ b/src/ast.c @@ -740,6 +740,10 @@ TypeTag *MakeTypeTag(Node *node) ->functionSignature.type); break; + case AllocExpression: + tag = MakeTypeTag(node->allocExpression.type); + break; + default: fprintf( stderr, diff --git a/src/codegen.c b/src/codegen.c index f94e192..8f82dff 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -23,6 +23,12 @@ typedef struct LocalVariable LLVMValueRef value; } LocalVariable; +typedef struct LocalGenericType +{ + char *name; + LLVMTypeRef type; +} LocalGenericType; + typedef struct FunctionArgument { char *name; @@ -33,6 +39,9 @@ typedef struct ScopeFrame { LocalVariable *localVariables; uint32_t localVariableCount; + + LocalGenericType *genericTypes; + uint32_t genericTypeCount; } ScopeFrame; typedef struct Scope @@ -75,6 +84,8 @@ typedef struct MonomorphizedGenericFunctionHashArray typedef struct StructTypeGenericFunction { + char *parentStructName; + LLVMTypeRef parentStructPointerType; char *name; Node *functionDeclarationNode; uint8_t isStatic; @@ -100,6 +111,18 @@ typedef struct StructTypeDeclaration StructTypeDeclaration *structTypeDeclarations; uint32_t structTypeDeclarationCount; +/* FUNCTION FORWARD DECLARATIONS */ +static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, + LLVMBuilderRef builder, + LLVMValueRef function, + Node *statement); + +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression); + static Scope *CreateScope() { Scope *scope = malloc(sizeof(Scope)); @@ -107,6 +130,8 @@ static Scope *CreateScope() scope->scopeStack = malloc(sizeof(ScopeFrame)); scope->scopeStack[0].localVariableCount = 0; scope->scopeStack[0].localVariables = NULL; + scope->scopeStack[0].genericTypeCount = 0; + scope->scopeStack[0].genericTypes = NULL; scope->scopeStackCount = 1; return scope; @@ -120,6 +145,8 @@ static void PushScopeFrame(Scope *scope) sizeof(ScopeFrame) * (scope->scopeStackCount + 1)); scope->scopeStack[index].localVariableCount = 0; scope->scopeStack[index].localVariables = NULL; + scope->scopeStack[index].genericTypeCount = 0; + scope->scopeStack[index].genericTypes = NULL; scope->scopeStackCount += 1; } @@ -138,31 +165,21 @@ static void PopScopeFrame(Scope *scope) free(scope->scopeStack[index].localVariables); } + if (scope->scopeStack[index].genericTypes != NULL) + { + for (i = 0; i < scope->scopeStack[index].genericTypeCount; i += 1) + { + free(scope->scopeStack[index].genericTypes[i].name); + } + free(scope->scopeStack[index].localVariables); + } + scope->scopeStackCount -= 1; scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); } -static void AddLocalVariable( - Scope *scope, - LLVMValueRef pointer, /* can be NULL */ - LLVMValueRef value, /* can be NULL */ - char *name) -{ - ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; - uint32_t index = scopeFrame->localVariableCount; - - scopeFrame->localVariables = realloc( - scopeFrame->localVariables, - sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); - scopeFrame->localVariables[index].name = strdup(name); - scopeFrame->localVariables[index].pointer = pointer; - scopeFrame->localVariables[index].value = value; - - scopeFrame->localVariableCount += 1; -} - static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) { switch (type) @@ -184,6 +201,120 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) return NULL; } +static LLVMTypeRef LookupCustomType(char *name) +{ + int32_t i, j; + + for (i = scope->scopeStackCount - 1; i >= 0; i -= 1) + { + for (j = 0; j < scope->scopeStack[i].genericTypeCount; j += 1) + { + if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0) + { + return scope->scopeStack[i].genericTypes[j].type; + } + } + } + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (strcmp(structTypeDeclarations[i].name, name) == 0) + { + return structTypeDeclarations[i].structType; + } + } + + fprintf(stderr, "Could not find struct type!\n"); + return NULL; +} + +static LLVMTypeRef ResolveType(TypeTag *typeTag) +{ + if (typeTag->type == Primitive) + { + return WraithTypeToLLVMType(typeTag->value.primitiveType); + } + else if (typeTag->type == Custom) + { + return LookupCustomType(typeTag->value.customType); + } + else if (typeTag->type == Reference) + { + return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0); + } + else + { + fprintf(stderr, "Unknown type node!\n"); + return NULL; + } +} + +static void AddLocalVariable( + Scope *scope, + LLVMValueRef pointer, /* can be NULL */ + LLVMValueRef value, /* can be NULL */ + char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->localVariableCount; + + scopeFrame->localVariables = realloc( + scopeFrame->localVariables, + sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); + scopeFrame->localVariables[index].name = strdup(name); + scopeFrame->localVariables[index].pointer = pointer; + scopeFrame->localVariables[index].value = value; + + scopeFrame->localVariableCount += 1; +} + +static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name) +{ + ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; + uint32_t index = scopeFrame->genericTypeCount; + + scopeFrame->genericTypes = realloc( + scopeFrame->genericTypes, + sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1)); + scopeFrame->genericTypes[index].name = strdup(name); + scopeFrame->genericTypes[index].type = ResolveType(typeTag); + + scopeFrame->genericTypeCount += 1; +} + +static void AddStructVariablesToScope( + LLVMBuilderRef builder, + LLVMValueRef structPointer) +{ + uint32_t i, j; + + for (i = 0; i < structTypeDeclarationCount; i += 1) + { + if (structTypeDeclarations[i].structPointerType == + LLVMTypeOf(structPointer)) + { + for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) + { + char *ptrName = + strdup(structTypeDeclarations[i].fields[j].name); + strcat(ptrName, "_ptr"); + LLVMValueRef elementPointer = LLVMBuildStructGEP( + builder, + structPointer, + structTypeDeclarations[i].fields[j].index, + ptrName); + free(ptrName); + + AddLocalVariable( + scope, + elementPointer, + NULL, + structTypeDeclarations[i].fields[j].name); + } + } + } +} + static LLVMTypeRef FindStructType(char *name) { uint32_t i; @@ -355,6 +486,7 @@ static void DeclareGenericStructFunction( LLVMTypeRef wStructPointerType, Node *functionDeclarationNode, uint8_t isStatic, + char *parentStructName, char *name) { uint32_t i, j, index; @@ -364,8 +496,15 @@ static void DeclareGenericStructFunction( if (structTypeDeclarations[i].structPointerType == wStructPointerType) { index = structTypeDeclarations[i].genericFunctionCount; + structTypeDeclarations[i].genericFunctions = realloc( + structTypeDeclarations[i].genericFunctions, + sizeof(StructTypeGenericFunction) * + (structTypeDeclarations[i].genericFunctionCount + 1)); structTypeDeclarations[i].genericFunctions[index].name = strdup(name); + structTypeDeclarations[i].genericFunctions[index].parentStructName = + parentStructName; + structTypeDeclarations[i].structPointerType = wStructPointerType; structTypeDeclarations[i] .genericFunctions[index] .functionDeclarationNode = functionDeclarationNode; @@ -391,46 +530,6 @@ static void DeclareGenericStructFunction( } } -static LLVMTypeRef LookupCustomType(char *name) -{ - uint32_t i; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (strcmp(structTypeDeclarations[i].name, name) == 0) - { - return structTypeDeclarations[i].structType; - } - } - - fprintf(stderr, "Could not find struct type!\n"); - return NULL; -} - -static LLVMTypeRef ResolveType(Node *typeNode) -{ - if (IsPrimitiveType(typeNode)) - { - return WraithTypeToLLVMType( - typeNode->type.typeNode->primitiveType.type); - } - else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode) - { - return LookupCustomType(typeNode->type.typeNode->customType.name); - } - else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode) - { - return LLVMPointerType( - ResolveType(typeNode->type.typeNode->referenceType.type), - 0); - } - else - { - fprintf(stderr, "Unknown type node!\n"); - return NULL; - } -} - static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) { const uint64_t HASH_FACTOR = 97; @@ -445,7 +544,159 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) return result; } +static StructTypeFunction CompileGenericFunction( + LLVMModuleRef module, + char *parentStructName, + LLVMTypeRef wStructPointerType, + TypeTag **genericArgumentTypes, + uint32_t genericArgumentTypeCount, + Node *functionDeclaration) +{ + uint32_t i; + uint8_t hasReturn = 0; + uint8_t isStatic = 0; + Node *functionSignature = + functionDeclaration->functionDeclaration.functionSignature; + Node *functionBody = functionDeclaration->functionDeclaration.functionBody; + uint32_t argumentCount = functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + LLVMTypeRef paramTypes[argumentCount + 1]; + uint32_t paramIndex = 0; + + PushScopeFrame(scope); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + AddGenericVariable( + scope, + genericArgumentTypes[i], + functionDeclaration->functionDeclaration.functionSignature + ->functionSignature.genericArguments->genericArguments + .arguments[i] + ->genericArgument.identifier->identifier.name); + } + + if (functionSignature->functionSignature.modifiers->functionModifiers + .count > 0) + { + for (i = 0; i < functionSignature->functionSignature.modifiers + ->functionModifiers.count; + i += 1) + { + if (functionSignature->functionSignature.modifiers + ->functionModifiers.sequence[i] + ->syntaxKind == StaticModifier) + { + isStatic = 1; + break; + } + } + } + + char *functionName = strdup(parentStructName); + strcat(functionName, "_"); + strcat( + functionName, + functionSignature->functionSignature.identifier->identifier.name); + + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + strcat(functionName, TypeTagToString(genericArgumentTypes[i])); + } + + if (!isStatic) + { + paramTypes[paramIndex] = wStructPointerType; + paramIndex += 1; + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + paramTypes[paramIndex] = + ResolveType(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->typeTag); + paramIndex += 1; + } + + LLVMTypeRef returnType = + ResolveType(functionSignature->functionSignature.identifier->typeTag); + LLVMTypeRef functionType = + LLVMFunctionType(returnType, paramTypes, paramIndex, 0); + + LLVMValueRef function = LLVMAddFunction(module, functionName, functionType); + + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry"); + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMPositionBuilderAtEnd(builder, entry); + + if (!isStatic) + { + LLVMValueRef wStructPointer = LLVMGetParam(function, 0); + AddStructVariablesToScope(builder, wStructPointer); + } + + for (i = 0; i < functionSignature->functionSignature.arguments + ->functionSignatureArguments.count; + i += 1) + { + char *ptrName = strdup(functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + strcat(ptrName, "_ptr"); + LLVMValueRef argument = LLVMGetParam(function, i + !isStatic); + LLVMValueRef argumentCopy = + LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); + LLVMBuildStore(builder, argument, argumentCopy); + free(ptrName); + AddLocalVariable( + scope, + argumentCopy, + NULL, + functionSignature->functionSignature.arguments + ->functionSignatureArguments.sequence[i] + ->declaration.identifier->identifier.name); + } + + for (i = 0; i < functionBody->statementSequence.count; i += 1) + { + CompileStatement( + module, + builder, + function, + functionBody->statementSequence.sequence[i]); + } + + hasReturn = + LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL; + + if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) + { + LLVMBuildRetVoid(builder); + } + else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn) + { + fprintf(stderr, "Return statement not provided!"); + } + + LLVMDisposeBuilder(builder); + PopScopeFrame(scope); + free(functionName); + + StructTypeFunction structTypeFunction; + structTypeFunction.name = strdup( + functionSignature->functionSignature.identifier->identifier.name); + structTypeFunction.function = function; + structTypeFunction.returnType = returnType; + structTypeFunction.isStatic = isStatic; + + return structTypeFunction; +} + static LLVMValueRef LookupGenericFunction( + LLVMModuleRef module, StructTypeGenericFunction *genericFunction, TypeTag **genericArgumentTypes, uint32_t genericArgumentTypeCount, @@ -484,17 +735,41 @@ static LLVMValueRef LookupGenericFunction( if (hashEntry == NULL) { + StructTypeFunction function = CompileGenericFunction( + module, + genericFunction->parentStructName, + genericFunction->parentStructPointerType, + genericArgumentTypes, + genericArgumentTypeCount, + genericFunction->functionDeclarationNode); - /* TODO: compile */ + /* TODO: add to hash */ + hashArray->elements = realloc( + hashArray->elements, + sizeof(MonomorphizedGenericFunctionHashEntry) * + (hashArray->count + 1)); + hashArray->elements[hashArray->count].key = typeHash; + hashArray->elements[hashArray->count].types = + malloc(sizeof(TypeTag *) * genericArgumentTypeCount); + hashArray->elements[hashArray->count].typeCount = + genericArgumentTypeCount; + hashArray->elements[hashArray->count].function = function; + for (i = 0; i < genericArgumentTypeCount; i += 1) + { + hashArray->elements[hashArray->count].types[i] = + genericArgumentTypes[i]; + } + hashArray->count += 1; } *pReturnType = hashEntry->function.returnType; - *pStatic = hashEntry->function.isStatic; + *pStatic = genericFunction->isStatic; return hashEntry->function.function; } static LLVMValueRef LookupFunctionByType( + LLVMModuleRef module, LLVMTypeRef structType, char *name, TypeTag **genericArgumentTypes, @@ -528,6 +803,7 @@ static LLVMValueRef LookupFunctionByType( name) == 0) { return LookupGenericFunction( + module, &structTypeDeclarations[i].genericFunctions[j], genericArgumentTypes, genericArgumentTypeCount, @@ -543,6 +819,7 @@ static LLVMValueRef LookupFunctionByType( } static LLVMValueRef LookupFunctionByPointerType( + LLVMModuleRef module, LLVMTypeRef structPointerType, char *name, TypeTag **genericArgumentTypes, @@ -576,6 +853,7 @@ static LLVMValueRef LookupFunctionByPointerType( name) == 0) { return LookupGenericFunction( + module, &structTypeDeclarations[i].genericFunctions[j], genericArgumentTypes, genericArgumentTypeCount, @@ -591,6 +869,7 @@ static LLVMValueRef LookupFunctionByPointerType( } static LLVMValueRef LookupFunctionByInstance( + LLVMModuleRef module, LLVMValueRef structPointer, char *name, TypeTag **genericArgumentTypes, @@ -599,6 +878,7 @@ static LLVMValueRef LookupFunctionByInstance( uint8_t *pStatic) { return LookupFunctionByPointerType( + module, LLVMTypeOf(structPointer), name, genericArgumentTypes, @@ -607,41 +887,6 @@ static LLVMValueRef LookupFunctionByInstance( pStatic); } -static void AddStructVariablesToScope( - LLVMBuilderRef builder, - LLVMValueRef structPointer) -{ - uint32_t i, j; - - for (i = 0; i < structTypeDeclarationCount; i += 1) - { - if (structTypeDeclarations[i].structPointerType == - LLVMTypeOf(structPointer)) - { - for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1) - { - char *ptrName = - strdup(structTypeDeclarations[i].fields[j].name); - strcat(ptrName, "_ptr"); - LLVMValueRef elementPointer = LLVMBuildStructGEP( - builder, - structPointer, - structTypeDeclarations[i].fields[j].index, - ptrName); - free(ptrName); - - AddLocalVariable( - scope, - elementPointer, - NULL, - structTypeDeclarations[i].fields[j].name); - } - } - } -} - -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression); - static LLVMValueRef CompileNumber(Node *numberExpression) { return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0); @@ -658,13 +903,19 @@ static LLVMValueRef CompileString( } static LLVMValueRef CompileBinaryExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *binaryExpression) { - LLVMValueRef left = - CompileExpression(builder, binaryExpression->binaryExpression.left); - LLVMValueRef right = - CompileExpression(builder, binaryExpression->binaryExpression.right); + LLVMValueRef left = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.left); + + LLVMValueRef right = CompileExpression( + module, + builder, + binaryExpression->binaryExpression.right); switch (binaryExpression->binaryExpression.operator) { @@ -709,6 +960,7 @@ static LLVMValueRef CompileBinaryExpression( /* FIXME THIS IS ALL BROKEN */ static LLVMValueRef CompileFunctionCallExpression( + LLVMModuleRef module, LLVMBuilderRef builder, Node *functionCallExpression) { @@ -728,6 +980,7 @@ static LLVMValueRef CompileFunctionCallExpression( LLVMTypeRef functionReturnType; char *returnName = ""; + /* FIXME: this is completely wrong and not how we get generic args */ for (i = 0; i < functionCallExpression->functionCallExpression .argumentSequence->functionArgumentSequence.count; i += 1) @@ -739,7 +992,7 @@ static LLVMValueRef CompileFunctionCallExpression( genericArgumentTypes[genericArgumentCount] = functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i] - ->typeTag; + ->declaration.identifier->typeTag; genericArgumentCount += 1; } @@ -757,6 +1010,7 @@ static LLVMValueRef CompileFunctionCallExpression( if (typeReference != NULL) { function = LookupFunctionByType( + module, typeReference, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, @@ -771,6 +1025,7 @@ static LLVMValueRef CompileFunctionCallExpression( functionCallExpression->functionCallExpression.identifier ->accessExpression.accessee->identifier.name); function = LookupFunctionByInstance( + module, structInstance, functionCallExpression->functionCallExpression.identifier ->accessExpression.accessor->identifier.name, @@ -797,6 +1052,7 @@ static LLVMValueRef CompileFunctionCallExpression( i += 1) { args[argumentCount] = CompileExpression( + module, builder, functionCallExpression->functionCallExpression.argumentSequence ->functionArgumentSequence.sequence[i]); @@ -843,11 +1099,14 @@ static LLVMValueRef CompileAllocExpression( LLVMBuilderRef builder, Node *allocExpression) { - LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type); + LLVMTypeRef type = ResolveType(allocExpression->typeTag); return LLVMBuildMalloc(builder, type, "allocation"); } -static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) +static LLVMValueRef CompileExpression( + LLVMModuleRef module, + LLVMBuilderRef builder, + Node *expression) { switch (expression->syntaxKind) { @@ -858,10 +1117,10 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return CompileAllocExpression(builder, expression); case BinaryExpression: - return CompileBinaryExpression(builder, expression); + return CompileBinaryExpression(module, builder, expression); case FunctionCallExpression: - return CompileFunctionCallExpression(builder, expression); + return CompileFunctionCallExpression(module, builder, expression); case Identifier: return FindVariableValue(builder, expression->identifier.name); @@ -877,17 +1136,14 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) return NULL; } -static LLVMBasicBlockRef CompileStatement( - LLVMBuilderRef builder, - LLVMValueRef function, - Node *statement); - static LLVMBasicBlockRef CompileReturn( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement) { LLVMValueRef expression = CompileExpression( + module, builder, returnStatemement->returnStatement.expression); LLVMBuildRet(builder, expression); @@ -916,7 +1172,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration( variable = LLVMBuildAlloca( builder, - ResolveType(variableDeclaration->declaration.type), + ResolveType(variableDeclaration->declaration.identifier->typeTag), ptrName); free(ptrName); @@ -927,11 +1183,13 @@ static LLVMValueRef CompileFunctionVariableDeclaration( } static LLVMBasicBlockRef CompileAssignment( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement) { LLVMValueRef result = CompileExpression( + module, builder, assignmentStatement->assignmentStatement.right); LLVMValueRef identifier; @@ -969,13 +1227,14 @@ static LLVMBasicBlockRef CompileAssignment( } static LLVMBasicBlockRef CompileIfStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) { uint32_t i; LLVMValueRef conditional = - CompileExpression(builder, ifStatement->ifStatement.expression); + CompileExpression(module, builder, ifStatement->ifStatement.expression); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); @@ -990,6 +1249,7 @@ static LLVMBasicBlockRef CompileIfStatement( i += 1) { CompileStatement( + module, builder, function, ifStatement->ifStatement.statementSequence->statementSequence @@ -1003,12 +1263,14 @@ static LLVMBasicBlockRef CompileIfStatement( } static LLVMBasicBlockRef CompileIfElseStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) { uint32_t i; LLVMValueRef conditional = CompileExpression( + module, builder, ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); @@ -1025,6 +1287,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.ifStatement->ifStatement @@ -1043,6 +1306,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( i += 1) { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement @@ -1052,6 +1316,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( else { CompileStatement( + module, builder, function, ifElseStatement->ifElseStatement.elseStatement); @@ -1064,6 +1329,7 @@ static LLVMBasicBlockRef CompileIfElseStatement( } static LLVMBasicBlockRef CompileForLoopStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement) @@ -1076,8 +1342,8 @@ static LLVMBasicBlockRef CompileForLoopStatement( LLVMAppendBasicBlock(function, "afterLoop"); char *iteratorVariableName = forLoopStatement->forLoop.declaration ->declaration.identifier->identifier.name; - LLVMTypeRef iteratorVariableType = - ResolveType(forLoopStatement->forLoop.declaration->declaration.type); + LLVMTypeRef iteratorVariableType = ResolveType( + forLoopStatement->forLoop.declaration->declaration.identifier->typeTag); PushScopeFrame(scope); @@ -1123,6 +1389,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( i += 1) { lastBlock = CompileStatement( + module, builder, function, forLoopStatement->forLoop.statementSequence->statementSequence @@ -1151,6 +1418,7 @@ static LLVMBasicBlockRef CompileForLoopStatement( } static LLVMBasicBlockRef CompileStatement( + LLVMModuleRef module, LLVMBuilderRef builder, LLVMValueRef function, Node *statement) @@ -1158,27 +1426,27 @@ static LLVMBasicBlockRef CompileStatement( switch (statement->syntaxKind) { case Assignment: - return CompileAssignment(builder, function, statement); + return CompileAssignment(module, builder, function, statement); case Declaration: CompileFunctionVariableDeclaration(builder, function, statement); return LLVMGetLastBasicBlock(function); case ForLoop: - return CompileForLoopStatement(builder, function, statement); + return CompileForLoopStatement(module, builder, function, statement); case FunctionCallExpression: - CompileFunctionCallExpression(builder, statement); + CompileFunctionCallExpression(module, builder, statement); return LLVMGetLastBasicBlock(function); case IfStatement: - return CompileIfStatement(builder, function, statement); + return CompileIfStatement(module, builder, function, statement); case IfElseStatement: - return CompileIfElseStatement(builder, function, statement); + return CompileIfElseStatement(module, builder, function, statement); case Return: - return CompileReturn(builder, function, statement); + return CompileReturn(module, builder, function, statement); case ReturnVoid: return CompileReturnVoid(builder, function); @@ -1192,8 +1460,6 @@ static void CompileFunction( LLVMModuleRef module, char *parentStructName, LLVMTypeRef wStructPointerType, - Node **fieldDeclarations, - uint32_t fieldDeclarationCount, Node *functionDeclaration) { uint32_t i; @@ -1248,12 +1514,12 @@ static void CompileFunction( paramTypes[paramIndex] = ResolveType(functionSignature->functionSignature.arguments ->functionSignatureArguments.sequence[i] - ->declaration.type); + ->declaration.identifier->typeTag); paramIndex += 1; } - LLVMTypeRef returnType = - ResolveType(functionSignature->functionSignature.type); + LLVMTypeRef returnType = ResolveType( + functionSignature->functionSignature.identifier->typeTag); LLVMTypeRef functionType = LLVMFunctionType(returnType, paramTypes, paramIndex, 0); @@ -1303,6 +1569,7 @@ static void CompileFunction( for (i = 0; i < functionBody->statementSequence.count; i += 1) { CompileStatement( + module, builder, function, functionBody->statementSequence.sequence[i]); @@ -1330,7 +1597,8 @@ static void CompileFunction( wStructPointerType, functionDeclaration, isStatic, - functionName); + parentStructName, + functionSignature->functionSignature.identifier->identifier.name); } free(functionName); @@ -1367,8 +1635,8 @@ static void CompileStruct( switch (currentDeclarationNode->syntaxKind) { case Declaration: /* this is badly named */ - types[fieldCount] = - ResolveType(currentDeclarationNode->declaration.type); + types[fieldCount] = ResolveType( + currentDeclarationNode->declaration.identifier->typeTag); fieldDeclarations[fieldCount] = currentDeclarationNode; fieldCount += 1; break; @@ -1396,8 +1664,6 @@ static void CompileStruct( module, structName, wStructPointerType, - fieldDeclarations, - fieldCount, currentDeclarationNode); break; } diff --git a/src/identcheck.c b/src/identcheck.c index 571a29e..2d040d0 100644 --- a/src/identcheck.c +++ b/src/identcheck.c @@ -59,9 +59,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) return NULL; case AllocExpression: - AddChildToNode( - parent, - MakeIdTree(astNode->allocExpression.type, parent)); + astNode->typeTag = MakeTypeTag(astNode); return NULL; case Assignment: @@ -154,6 +152,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) idNode->typeTag = mainNode->typeTag; MakeIdTree(sigNode->functionSignature.arguments, mainNode); MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); + MakeIdTree(sigNode->functionSignature.genericArguments, mainNode); break; } @@ -167,6 +166,23 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent) return NULL; } + case GenericArgument: + { + char *name = astNode->genericArgument.identifier->identifier.name; + mainNode = MakeIdNode(GenericType, name, parent); + break; + } + + case GenericArguments: + { + for (i = 0; i < astNode->genericArguments.count; i += 1) + { + Node *argNode = astNode->genericArguments.arguments[i]; + AddChildToNode(parent, MakeIdTree(argNode, parent)); + } + return NULL; + } + case Identifier: { char *name = astNode->identifier.name; @@ -302,6 +318,12 @@ void PrintIdNode(IdNode *node) case Variable: printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); break; + case GenericType: + printf("Generic type: %s\n", node->name); + break; + case Alloc: + printf("Alloc: %s\n", TypeTagToString(node->typeTag)); + break; } } diff --git a/src/identcheck.h b/src/identcheck.h index c0ccca6..8b287dd 100644 --- a/src/identcheck.h +++ b/src/identcheck.h @@ -17,7 +17,9 @@ typedef enum NodeType OrderedScope, Struct, Function, - Variable + Variable, + GenericType, + Alloc } NodeType; typedef struct IdNode diff --git a/src/util.c b/src/util.c index aa5b4fb..42911e7 100644 --- a/src/util.c +++ b/src/util.c @@ -20,7 +20,7 @@ uint64_t str_hash(char *str) uint64_t hash = 5381; size_t c; - while (c = *str++) + while ((c = *str++)) { hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ }