progress on generics

traversal
cosmonaut 2021-05-21 19:52:13 -07:00
parent 8a3920918c
commit d641f713de
5 changed files with 429 additions and 135 deletions

View File

@ -740,6 +740,10 @@ TypeTag *MakeTypeTag(Node *node)
->functionSignature.type); ->functionSignature.type);
break; break;
case AllocExpression:
tag = MakeTypeTag(node->allocExpression.type);
break;
default: default:
fprintf( fprintf(
stderr, stderr,

View File

@ -23,6 +23,12 @@ typedef struct LocalVariable
LLVMValueRef value; LLVMValueRef value;
} LocalVariable; } LocalVariable;
typedef struct LocalGenericType
{
char *name;
LLVMTypeRef type;
} LocalGenericType;
typedef struct FunctionArgument typedef struct FunctionArgument
{ {
char *name; char *name;
@ -33,6 +39,9 @@ typedef struct ScopeFrame
{ {
LocalVariable *localVariables; LocalVariable *localVariables;
uint32_t localVariableCount; uint32_t localVariableCount;
LocalGenericType *genericTypes;
uint32_t genericTypeCount;
} ScopeFrame; } ScopeFrame;
typedef struct Scope typedef struct Scope
@ -75,6 +84,8 @@ typedef struct MonomorphizedGenericFunctionHashArray
typedef struct StructTypeGenericFunction typedef struct StructTypeGenericFunction
{ {
char *parentStructName;
LLVMTypeRef parentStructPointerType;
char *name; char *name;
Node *functionDeclarationNode; Node *functionDeclarationNode;
uint8_t isStatic; uint8_t isStatic;
@ -100,6 +111,18 @@ typedef struct StructTypeDeclaration
StructTypeDeclaration *structTypeDeclarations; StructTypeDeclaration *structTypeDeclarations;
uint32_t structTypeDeclarationCount; uint32_t structTypeDeclarationCount;
/* FUNCTION FORWARD DECLARATIONS */
static LLVMBasicBlockRef CompileStatement(
LLVMModuleRef module,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *statement);
static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMBuilderRef builder,
Node *expression);
static Scope *CreateScope() static Scope *CreateScope()
{ {
Scope *scope = malloc(sizeof(Scope)); Scope *scope = malloc(sizeof(Scope));
@ -107,6 +130,8 @@ static Scope *CreateScope()
scope->scopeStack = malloc(sizeof(ScopeFrame)); scope->scopeStack = malloc(sizeof(ScopeFrame));
scope->scopeStack[0].localVariableCount = 0; scope->scopeStack[0].localVariableCount = 0;
scope->scopeStack[0].localVariables = NULL; scope->scopeStack[0].localVariables = NULL;
scope->scopeStack[0].genericTypeCount = 0;
scope->scopeStack[0].genericTypes = NULL;
scope->scopeStackCount = 1; scope->scopeStackCount = 1;
return scope; return scope;
@ -120,6 +145,8 @@ static void PushScopeFrame(Scope *scope)
sizeof(ScopeFrame) * (scope->scopeStackCount + 1)); sizeof(ScopeFrame) * (scope->scopeStackCount + 1));
scope->scopeStack[index].localVariableCount = 0; scope->scopeStack[index].localVariableCount = 0;
scope->scopeStack[index].localVariables = NULL; scope->scopeStack[index].localVariables = NULL;
scope->scopeStack[index].genericTypeCount = 0;
scope->scopeStack[index].genericTypes = NULL;
scope->scopeStackCount += 1; scope->scopeStackCount += 1;
} }
@ -138,31 +165,21 @@ static void PopScopeFrame(Scope *scope)
free(scope->scopeStack[index].localVariables); free(scope->scopeStack[index].localVariables);
} }
if (scope->scopeStack[index].genericTypes != NULL)
{
for (i = 0; i < scope->scopeStack[index].genericTypeCount; i += 1)
{
free(scope->scopeStack[index].genericTypes[i].name);
}
free(scope->scopeStack[index].localVariables);
}
scope->scopeStackCount -= 1; scope->scopeStackCount -= 1;
scope->scopeStack = scope->scopeStack =
realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
} }
static void AddLocalVariable(
Scope *scope,
LLVMValueRef pointer, /* can be NULL */
LLVMValueRef value, /* can be NULL */
char *name)
{
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
uint32_t index = scopeFrame->localVariableCount;
scopeFrame->localVariables = realloc(
scopeFrame->localVariables,
sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
scopeFrame->localVariables[index].name = strdup(name);
scopeFrame->localVariables[index].pointer = pointer;
scopeFrame->localVariables[index].value = value;
scopeFrame->localVariableCount += 1;
}
static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type) static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
{ {
switch (type) switch (type)
@ -184,6 +201,120 @@ static LLVMTypeRef WraithTypeToLLVMType(PrimitiveType type)
return NULL; return NULL;
} }
static LLVMTypeRef LookupCustomType(char *name)
{
int32_t i, j;
for (i = scope->scopeStackCount - 1; i >= 0; i -= 1)
{
for (j = 0; j < scope->scopeStack[i].genericTypeCount; j += 1)
{
if (strcmp(scope->scopeStack[i].genericTypes[j].name, name) == 0)
{
return scope->scopeStack[i].genericTypes[j].type;
}
}
}
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
fprintf(stderr, "Could not find struct type!\n");
return NULL;
}
static LLVMTypeRef ResolveType(TypeTag *typeTag)
{
if (typeTag->type == Primitive)
{
return WraithTypeToLLVMType(typeTag->value.primitiveType);
}
else if (typeTag->type == Custom)
{
return LookupCustomType(typeTag->value.customType);
}
else if (typeTag->type == Reference)
{
return LLVMPointerType(ResolveType(typeTag->value.referenceType), 0);
}
else
{
fprintf(stderr, "Unknown type node!\n");
return NULL;
}
}
static void AddLocalVariable(
Scope *scope,
LLVMValueRef pointer, /* can be NULL */
LLVMValueRef value, /* can be NULL */
char *name)
{
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
uint32_t index = scopeFrame->localVariableCount;
scopeFrame->localVariables = realloc(
scopeFrame->localVariables,
sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
scopeFrame->localVariables[index].name = strdup(name);
scopeFrame->localVariables[index].pointer = pointer;
scopeFrame->localVariables[index].value = value;
scopeFrame->localVariableCount += 1;
}
static void AddGenericVariable(Scope *scope, TypeTag *typeTag, char *name)
{
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
uint32_t index = scopeFrame->genericTypeCount;
scopeFrame->genericTypes = realloc(
scopeFrame->genericTypes,
sizeof(LocalGenericType) * (scopeFrame->genericTypeCount + 1));
scopeFrame->genericTypes[index].name = strdup(name);
scopeFrame->genericTypes[index].type = ResolveType(typeTag);
scopeFrame->genericTypeCount += 1;
}
static void AddStructVariablesToScope(
LLVMBuilderRef builder,
LLVMValueRef structPointer)
{
uint32_t i, j;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType ==
LLVMTypeOf(structPointer))
{
for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{
char *ptrName =
strdup(structTypeDeclarations[i].fields[j].name);
strcat(ptrName, "_ptr");
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
structPointer,
structTypeDeclarations[i].fields[j].index,
ptrName);
free(ptrName);
AddLocalVariable(
scope,
elementPointer,
NULL,
structTypeDeclarations[i].fields[j].name);
}
}
}
}
static LLVMTypeRef FindStructType(char *name) static LLVMTypeRef FindStructType(char *name)
{ {
uint32_t i; uint32_t i;
@ -355,6 +486,7 @@ static void DeclareGenericStructFunction(
LLVMTypeRef wStructPointerType, LLVMTypeRef wStructPointerType,
Node *functionDeclarationNode, Node *functionDeclarationNode,
uint8_t isStatic, uint8_t isStatic,
char *parentStructName,
char *name) char *name)
{ {
uint32_t i, j, index; uint32_t i, j, index;
@ -364,8 +496,15 @@ static void DeclareGenericStructFunction(
if (structTypeDeclarations[i].structPointerType == wStructPointerType) if (structTypeDeclarations[i].structPointerType == wStructPointerType)
{ {
index = structTypeDeclarations[i].genericFunctionCount; index = structTypeDeclarations[i].genericFunctionCount;
structTypeDeclarations[i].genericFunctions = realloc(
structTypeDeclarations[i].genericFunctions,
sizeof(StructTypeGenericFunction) *
(structTypeDeclarations[i].genericFunctionCount + 1));
structTypeDeclarations[i].genericFunctions[index].name = structTypeDeclarations[i].genericFunctions[index].name =
strdup(name); strdup(name);
structTypeDeclarations[i].genericFunctions[index].parentStructName =
parentStructName;
structTypeDeclarations[i].structPointerType = wStructPointerType;
structTypeDeclarations[i] structTypeDeclarations[i]
.genericFunctions[index] .genericFunctions[index]
.functionDeclarationNode = functionDeclarationNode; .functionDeclarationNode = functionDeclarationNode;
@ -391,46 +530,6 @@ static void DeclareGenericStructFunction(
} }
} }
static LLVMTypeRef LookupCustomType(char *name)
{
uint32_t i;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (strcmp(structTypeDeclarations[i].name, name) == 0)
{
return structTypeDeclarations[i].structType;
}
}
fprintf(stderr, "Could not find struct type!\n");
return NULL;
}
static LLVMTypeRef ResolveType(Node *typeNode)
{
if (IsPrimitiveType(typeNode))
{
return WraithTypeToLLVMType(
typeNode->type.typeNode->primitiveType.type);
}
else if (typeNode->type.typeNode->syntaxKind == CustomTypeNode)
{
return LookupCustomType(typeNode->type.typeNode->customType.name);
}
else if (typeNode->type.typeNode->syntaxKind == ReferenceTypeNode)
{
return LLVMPointerType(
ResolveType(typeNode->type.typeNode->referenceType.type),
0);
}
else
{
fprintf(stderr, "Unknown type node!\n");
return NULL;
}
}
static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count) static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count)
{ {
const uint64_t HASH_FACTOR = 97; const uint64_t HASH_FACTOR = 97;
@ -445,7 +544,159 @@ static inline uint64_t HashTypeTags(TypeTag **tags, uint32_t count)
return result; return result;
} }
static StructTypeFunction CompileGenericFunction(
LLVMModuleRef module,
char *parentStructName,
LLVMTypeRef wStructPointerType,
TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount,
Node *functionDeclaration)
{
uint32_t i;
uint8_t hasReturn = 0;
uint8_t isStatic = 0;
Node *functionSignature =
functionDeclaration->functionDeclaration.functionSignature;
Node *functionBody = functionDeclaration->functionDeclaration.functionBody;
uint32_t argumentCount = functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
LLVMTypeRef paramTypes[argumentCount + 1];
uint32_t paramIndex = 0;
PushScopeFrame(scope);
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
AddGenericVariable(
scope,
genericArgumentTypes[i],
functionDeclaration->functionDeclaration.functionSignature
->functionSignature.genericArguments->genericArguments
.arguments[i]
->genericArgument.identifier->identifier.name);
}
if (functionSignature->functionSignature.modifiers->functionModifiers
.count > 0)
{
for (i = 0; i < functionSignature->functionSignature.modifiers
->functionModifiers.count;
i += 1)
{
if (functionSignature->functionSignature.modifiers
->functionModifiers.sequence[i]
->syntaxKind == StaticModifier)
{
isStatic = 1;
break;
}
}
}
char *functionName = strdup(parentStructName);
strcat(functionName, "_");
strcat(
functionName,
functionSignature->functionSignature.identifier->identifier.name);
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
strcat(functionName, TypeTagToString(genericArgumentTypes[i]));
}
if (!isStatic)
{
paramTypes[paramIndex] = wStructPointerType;
paramIndex += 1;
}
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->typeTag);
paramIndex += 1;
}
LLVMTypeRef returnType =
ResolveType(functionSignature->functionSignature.identifier->typeTag);
LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
LLVMValueRef function = LLVMAddFunction(module, functionName, functionType);
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(function, "entry");
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(builder, entry);
if (!isStatic)
{
LLVMValueRef wStructPointer = LLVMGetParam(function, 0);
AddStructVariablesToScope(builder, wStructPointer);
}
for (i = 0; i < functionSignature->functionSignature.arguments
->functionSignatureArguments.count;
i += 1)
{
char *ptrName = strdup(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
strcat(ptrName, "_ptr");
LLVMValueRef argument = LLVMGetParam(function, i + !isStatic);
LLVMValueRef argumentCopy =
LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(
scope,
argumentCopy,
NULL,
functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i]
->declaration.identifier->identifier.name);
}
for (i = 0; i < functionBody->statementSequence.count; i += 1)
{
CompileStatement(
module,
builder,
function,
functionBody->statementSequence.sequence[i]);
}
hasReturn =
LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);
}
else if (LLVMGetTypeKind(returnType) != LLVMVoidTypeKind && !hasReturn)
{
fprintf(stderr, "Return statement not provided!");
}
LLVMDisposeBuilder(builder);
PopScopeFrame(scope);
free(functionName);
StructTypeFunction structTypeFunction;
structTypeFunction.name = strdup(
functionSignature->functionSignature.identifier->identifier.name);
structTypeFunction.function = function;
structTypeFunction.returnType = returnType;
structTypeFunction.isStatic = isStatic;
return structTypeFunction;
}
static LLVMValueRef LookupGenericFunction( static LLVMValueRef LookupGenericFunction(
LLVMModuleRef module,
StructTypeGenericFunction *genericFunction, StructTypeGenericFunction *genericFunction,
TypeTag **genericArgumentTypes, TypeTag **genericArgumentTypes,
uint32_t genericArgumentTypeCount, uint32_t genericArgumentTypeCount,
@ -484,17 +735,41 @@ static LLVMValueRef LookupGenericFunction(
if (hashEntry == NULL) if (hashEntry == NULL)
{ {
StructTypeFunction function = CompileGenericFunction(
module,
genericFunction->parentStructName,
genericFunction->parentStructPointerType,
genericArgumentTypes,
genericArgumentTypeCount,
genericFunction->functionDeclarationNode);
/* TODO: compile */ /* TODO: add to hash */
hashArray->elements = realloc(
hashArray->elements,
sizeof(MonomorphizedGenericFunctionHashEntry) *
(hashArray->count + 1));
hashArray->elements[hashArray->count].key = typeHash;
hashArray->elements[hashArray->count].types =
malloc(sizeof(TypeTag *) * genericArgumentTypeCount);
hashArray->elements[hashArray->count].typeCount =
genericArgumentTypeCount;
hashArray->elements[hashArray->count].function = function;
for (i = 0; i < genericArgumentTypeCount; i += 1)
{
hashArray->elements[hashArray->count].types[i] =
genericArgumentTypes[i];
}
hashArray->count += 1;
} }
*pReturnType = hashEntry->function.returnType; *pReturnType = hashEntry->function.returnType;
*pStatic = hashEntry->function.isStatic; *pStatic = genericFunction->isStatic;
return hashEntry->function.function; return hashEntry->function.function;
} }
static LLVMValueRef LookupFunctionByType( static LLVMValueRef LookupFunctionByType(
LLVMModuleRef module,
LLVMTypeRef structType, LLVMTypeRef structType,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **genericArgumentTypes,
@ -528,6 +803,7 @@ static LLVMValueRef LookupFunctionByType(
name) == 0) name) == 0)
{ {
return LookupGenericFunction( return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes, genericArgumentTypes,
genericArgumentTypeCount, genericArgumentTypeCount,
@ -543,6 +819,7 @@ static LLVMValueRef LookupFunctionByType(
} }
static LLVMValueRef LookupFunctionByPointerType( static LLVMValueRef LookupFunctionByPointerType(
LLVMModuleRef module,
LLVMTypeRef structPointerType, LLVMTypeRef structPointerType,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **genericArgumentTypes,
@ -576,6 +853,7 @@ static LLVMValueRef LookupFunctionByPointerType(
name) == 0) name) == 0)
{ {
return LookupGenericFunction( return LookupGenericFunction(
module,
&structTypeDeclarations[i].genericFunctions[j], &structTypeDeclarations[i].genericFunctions[j],
genericArgumentTypes, genericArgumentTypes,
genericArgumentTypeCount, genericArgumentTypeCount,
@ -591,6 +869,7 @@ static LLVMValueRef LookupFunctionByPointerType(
} }
static LLVMValueRef LookupFunctionByInstance( static LLVMValueRef LookupFunctionByInstance(
LLVMModuleRef module,
LLVMValueRef structPointer, LLVMValueRef structPointer,
char *name, char *name,
TypeTag **genericArgumentTypes, TypeTag **genericArgumentTypes,
@ -599,6 +878,7 @@ static LLVMValueRef LookupFunctionByInstance(
uint8_t *pStatic) uint8_t *pStatic)
{ {
return LookupFunctionByPointerType( return LookupFunctionByPointerType(
module,
LLVMTypeOf(structPointer), LLVMTypeOf(structPointer),
name, name,
genericArgumentTypes, genericArgumentTypes,
@ -607,41 +887,6 @@ static LLVMValueRef LookupFunctionByInstance(
pStatic); pStatic);
} }
static void AddStructVariablesToScope(
LLVMBuilderRef builder,
LLVMValueRef structPointer)
{
uint32_t i, j;
for (i = 0; i < structTypeDeclarationCount; i += 1)
{
if (structTypeDeclarations[i].structPointerType ==
LLVMTypeOf(structPointer))
{
for (j = 0; j < structTypeDeclarations[i].fieldCount; j += 1)
{
char *ptrName =
strdup(structTypeDeclarations[i].fields[j].name);
strcat(ptrName, "_ptr");
LLVMValueRef elementPointer = LLVMBuildStructGEP(
builder,
structPointer,
structTypeDeclarations[i].fields[j].index,
ptrName);
free(ptrName);
AddLocalVariable(
scope,
elementPointer,
NULL,
structTypeDeclarations[i].fields[j].name);
}
}
}
}
static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression);
static LLVMValueRef CompileNumber(Node *numberExpression) static LLVMValueRef CompileNumber(Node *numberExpression)
{ {
return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0); return LLVMConstInt(LLVMInt64Type(), numberExpression->number.value, 0);
@ -658,13 +903,19 @@ static LLVMValueRef CompileString(
} }
static LLVMValueRef CompileBinaryExpression( static LLVMValueRef CompileBinaryExpression(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *binaryExpression) Node *binaryExpression)
{ {
LLVMValueRef left = LLVMValueRef left = CompileExpression(
CompileExpression(builder, binaryExpression->binaryExpression.left); module,
LLVMValueRef right = builder,
CompileExpression(builder, binaryExpression->binaryExpression.right); binaryExpression->binaryExpression.left);
LLVMValueRef right = CompileExpression(
module,
builder,
binaryExpression->binaryExpression.right);
switch (binaryExpression->binaryExpression.operator) switch (binaryExpression->binaryExpression.operator)
{ {
@ -709,6 +960,7 @@ static LLVMValueRef CompileBinaryExpression(
/* FIXME THIS IS ALL BROKEN */ /* FIXME THIS IS ALL BROKEN */
static LLVMValueRef CompileFunctionCallExpression( static LLVMValueRef CompileFunctionCallExpression(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *functionCallExpression) Node *functionCallExpression)
{ {
@ -728,6 +980,7 @@ static LLVMValueRef CompileFunctionCallExpression(
LLVMTypeRef functionReturnType; LLVMTypeRef functionReturnType;
char *returnName = ""; char *returnName = "";
/* FIXME: this is completely wrong and not how we get generic args */
for (i = 0; i < functionCallExpression->functionCallExpression for (i = 0; i < functionCallExpression->functionCallExpression
.argumentSequence->functionArgumentSequence.count; .argumentSequence->functionArgumentSequence.count;
i += 1) i += 1)
@ -739,7 +992,7 @@ static LLVMValueRef CompileFunctionCallExpression(
genericArgumentTypes[genericArgumentCount] = genericArgumentTypes[genericArgumentCount] =
functionCallExpression->functionCallExpression.argumentSequence functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i] ->functionArgumentSequence.sequence[i]
->typeTag; ->declaration.identifier->typeTag;
genericArgumentCount += 1; genericArgumentCount += 1;
} }
@ -757,6 +1010,7 @@ static LLVMValueRef CompileFunctionCallExpression(
if (typeReference != NULL) if (typeReference != NULL)
{ {
function = LookupFunctionByType( function = LookupFunctionByType(
module,
typeReference, typeReference,
functionCallExpression->functionCallExpression.identifier functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name, ->accessExpression.accessor->identifier.name,
@ -771,6 +1025,7 @@ static LLVMValueRef CompileFunctionCallExpression(
functionCallExpression->functionCallExpression.identifier functionCallExpression->functionCallExpression.identifier
->accessExpression.accessee->identifier.name); ->accessExpression.accessee->identifier.name);
function = LookupFunctionByInstance( function = LookupFunctionByInstance(
module,
structInstance, structInstance,
functionCallExpression->functionCallExpression.identifier functionCallExpression->functionCallExpression.identifier
->accessExpression.accessor->identifier.name, ->accessExpression.accessor->identifier.name,
@ -797,6 +1052,7 @@ static LLVMValueRef CompileFunctionCallExpression(
i += 1) i += 1)
{ {
args[argumentCount] = CompileExpression( args[argumentCount] = CompileExpression(
module,
builder, builder,
functionCallExpression->functionCallExpression.argumentSequence functionCallExpression->functionCallExpression.argumentSequence
->functionArgumentSequence.sequence[i]); ->functionArgumentSequence.sequence[i]);
@ -843,11 +1099,14 @@ static LLVMValueRef CompileAllocExpression(
LLVMBuilderRef builder, LLVMBuilderRef builder,
Node *allocExpression) Node *allocExpression)
{ {
LLVMTypeRef type = ResolveType(allocExpression->allocExpression.type); LLVMTypeRef type = ResolveType(allocExpression->typeTag);
return LLVMBuildMalloc(builder, type, "allocation"); return LLVMBuildMalloc(builder, type, "allocation");
} }
static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression) static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMBuilderRef builder,
Node *expression)
{ {
switch (expression->syntaxKind) switch (expression->syntaxKind)
{ {
@ -858,10 +1117,10 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression)
return CompileAllocExpression(builder, expression); return CompileAllocExpression(builder, expression);
case BinaryExpression: case BinaryExpression:
return CompileBinaryExpression(builder, expression); return CompileBinaryExpression(module, builder, expression);
case FunctionCallExpression: case FunctionCallExpression:
return CompileFunctionCallExpression(builder, expression); return CompileFunctionCallExpression(module, builder, expression);
case Identifier: case Identifier:
return FindVariableValue(builder, expression->identifier.name); return FindVariableValue(builder, expression->identifier.name);
@ -877,17 +1136,14 @@ static LLVMValueRef CompileExpression(LLVMBuilderRef builder, Node *expression)
return NULL; return NULL;
} }
static LLVMBasicBlockRef CompileStatement(
LLVMBuilderRef builder,
LLVMValueRef function,
Node *statement);
static LLVMBasicBlockRef CompileReturn( static LLVMBasicBlockRef CompileReturn(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *returnStatemement) Node *returnStatemement)
{ {
LLVMValueRef expression = CompileExpression( LLVMValueRef expression = CompileExpression(
module,
builder, builder,
returnStatemement->returnStatement.expression); returnStatemement->returnStatement.expression);
LLVMBuildRet(builder, expression); LLVMBuildRet(builder, expression);
@ -916,7 +1172,7 @@ static LLVMValueRef CompileFunctionVariableDeclaration(
variable = LLVMBuildAlloca( variable = LLVMBuildAlloca(
builder, builder,
ResolveType(variableDeclaration->declaration.type), ResolveType(variableDeclaration->declaration.identifier->typeTag),
ptrName); ptrName);
free(ptrName); free(ptrName);
@ -927,11 +1183,13 @@ static LLVMValueRef CompileFunctionVariableDeclaration(
} }
static LLVMBasicBlockRef CompileAssignment( static LLVMBasicBlockRef CompileAssignment(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *assignmentStatement) Node *assignmentStatement)
{ {
LLVMValueRef result = CompileExpression( LLVMValueRef result = CompileExpression(
module,
builder, builder,
assignmentStatement->assignmentStatement.right); assignmentStatement->assignmentStatement.right);
LLVMValueRef identifier; LLVMValueRef identifier;
@ -969,13 +1227,14 @@ static LLVMBasicBlockRef CompileAssignment(
} }
static LLVMBasicBlockRef CompileIfStatement( static LLVMBasicBlockRef CompileIfStatement(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *ifStatement) Node *ifStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = LLVMValueRef conditional =
CompileExpression(builder, ifStatement->ifStatement.expression); CompileExpression(module, builder, ifStatement->ifStatement.expression);
LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock"); LLVMBasicBlockRef block = LLVMAppendBasicBlock(function, "ifBlock");
LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond"); LLVMBasicBlockRef afterCond = LLVMAppendBasicBlock(function, "afterCond");
@ -990,6 +1249,7 @@ static LLVMBasicBlockRef CompileIfStatement(
i += 1) i += 1)
{ {
CompileStatement( CompileStatement(
module,
builder, builder,
function, function,
ifStatement->ifStatement.statementSequence->statementSequence ifStatement->ifStatement.statementSequence->statementSequence
@ -1003,12 +1263,14 @@ static LLVMBasicBlockRef CompileIfStatement(
} }
static LLVMBasicBlockRef CompileIfElseStatement( static LLVMBasicBlockRef CompileIfElseStatement(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *ifElseStatement) Node *ifElseStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = CompileExpression( LLVMValueRef conditional = CompileExpression(
module,
builder, builder,
ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression); ifElseStatement->ifElseStatement.ifStatement->ifStatement.expression);
@ -1025,6 +1287,7 @@ static LLVMBasicBlockRef CompileIfElseStatement(
i += 1) i += 1)
{ {
CompileStatement( CompileStatement(
module,
builder, builder,
function, function,
ifElseStatement->ifElseStatement.ifStatement->ifStatement ifElseStatement->ifElseStatement.ifStatement->ifStatement
@ -1043,6 +1306,7 @@ static LLVMBasicBlockRef CompileIfElseStatement(
i += 1) i += 1)
{ {
CompileStatement( CompileStatement(
module,
builder, builder,
function, function,
ifElseStatement->ifElseStatement.elseStatement ifElseStatement->ifElseStatement.elseStatement
@ -1052,6 +1316,7 @@ static LLVMBasicBlockRef CompileIfElseStatement(
else else
{ {
CompileStatement( CompileStatement(
module,
builder, builder,
function, function,
ifElseStatement->ifElseStatement.elseStatement); ifElseStatement->ifElseStatement.elseStatement);
@ -1064,6 +1329,7 @@ static LLVMBasicBlockRef CompileIfElseStatement(
} }
static LLVMBasicBlockRef CompileForLoopStatement( static LLVMBasicBlockRef CompileForLoopStatement(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *forLoopStatement) Node *forLoopStatement)
@ -1076,8 +1342,8 @@ static LLVMBasicBlockRef CompileForLoopStatement(
LLVMAppendBasicBlock(function, "afterLoop"); LLVMAppendBasicBlock(function, "afterLoop");
char *iteratorVariableName = forLoopStatement->forLoop.declaration char *iteratorVariableName = forLoopStatement->forLoop.declaration
->declaration.identifier->identifier.name; ->declaration.identifier->identifier.name;
LLVMTypeRef iteratorVariableType = LLVMTypeRef iteratorVariableType = ResolveType(
ResolveType(forLoopStatement->forLoop.declaration->declaration.type); forLoopStatement->forLoop.declaration->declaration.identifier->typeTag);
PushScopeFrame(scope); PushScopeFrame(scope);
@ -1123,6 +1389,7 @@ static LLVMBasicBlockRef CompileForLoopStatement(
i += 1) i += 1)
{ {
lastBlock = CompileStatement( lastBlock = CompileStatement(
module,
builder, builder,
function, function,
forLoopStatement->forLoop.statementSequence->statementSequence forLoopStatement->forLoop.statementSequence->statementSequence
@ -1151,6 +1418,7 @@ static LLVMBasicBlockRef CompileForLoopStatement(
} }
static LLVMBasicBlockRef CompileStatement( static LLVMBasicBlockRef CompileStatement(
LLVMModuleRef module,
LLVMBuilderRef builder, LLVMBuilderRef builder,
LLVMValueRef function, LLVMValueRef function,
Node *statement) Node *statement)
@ -1158,27 +1426,27 @@ static LLVMBasicBlockRef CompileStatement(
switch (statement->syntaxKind) switch (statement->syntaxKind)
{ {
case Assignment: case Assignment:
return CompileAssignment(builder, function, statement); return CompileAssignment(module, builder, function, statement);
case Declaration: case Declaration:
CompileFunctionVariableDeclaration(builder, function, statement); CompileFunctionVariableDeclaration(builder, function, statement);
return LLVMGetLastBasicBlock(function); return LLVMGetLastBasicBlock(function);
case ForLoop: case ForLoop:
return CompileForLoopStatement(builder, function, statement); return CompileForLoopStatement(module, builder, function, statement);
case FunctionCallExpression: case FunctionCallExpression:
CompileFunctionCallExpression(builder, statement); CompileFunctionCallExpression(module, builder, statement);
return LLVMGetLastBasicBlock(function); return LLVMGetLastBasicBlock(function);
case IfStatement: case IfStatement:
return CompileIfStatement(builder, function, statement); return CompileIfStatement(module, builder, function, statement);
case IfElseStatement: case IfElseStatement:
return CompileIfElseStatement(builder, function, statement); return CompileIfElseStatement(module, builder, function, statement);
case Return: case Return:
return CompileReturn(builder, function, statement); return CompileReturn(module, builder, function, statement);
case ReturnVoid: case ReturnVoid:
return CompileReturnVoid(builder, function); return CompileReturnVoid(builder, function);
@ -1192,8 +1460,6 @@ static void CompileFunction(
LLVMModuleRef module, LLVMModuleRef module,
char *parentStructName, char *parentStructName,
LLVMTypeRef wStructPointerType, LLVMTypeRef wStructPointerType,
Node **fieldDeclarations,
uint32_t fieldDeclarationCount,
Node *functionDeclaration) Node *functionDeclaration)
{ {
uint32_t i; uint32_t i;
@ -1248,12 +1514,12 @@ static void CompileFunction(
paramTypes[paramIndex] = paramTypes[paramIndex] =
ResolveType(functionSignature->functionSignature.arguments ResolveType(functionSignature->functionSignature.arguments
->functionSignatureArguments.sequence[i] ->functionSignatureArguments.sequence[i]
->declaration.type); ->declaration.identifier->typeTag);
paramIndex += 1; paramIndex += 1;
} }
LLVMTypeRef returnType = LLVMTypeRef returnType = ResolveType(
ResolveType(functionSignature->functionSignature.type); functionSignature->functionSignature.identifier->typeTag);
LLVMTypeRef functionType = LLVMTypeRef functionType =
LLVMFunctionType(returnType, paramTypes, paramIndex, 0); LLVMFunctionType(returnType, paramTypes, paramIndex, 0);
@ -1303,6 +1569,7 @@ static void CompileFunction(
for (i = 0; i < functionBody->statementSequence.count; i += 1) for (i = 0; i < functionBody->statementSequence.count; i += 1)
{ {
CompileStatement( CompileStatement(
module,
builder, builder,
function, function,
functionBody->statementSequence.sequence[i]); functionBody->statementSequence.sequence[i]);
@ -1330,7 +1597,8 @@ static void CompileFunction(
wStructPointerType, wStructPointerType,
functionDeclaration, functionDeclaration,
isStatic, isStatic,
functionName); parentStructName,
functionSignature->functionSignature.identifier->identifier.name);
} }
free(functionName); free(functionName);
@ -1367,8 +1635,8 @@ static void CompileStruct(
switch (currentDeclarationNode->syntaxKind) switch (currentDeclarationNode->syntaxKind)
{ {
case Declaration: /* this is badly named */ case Declaration: /* this is badly named */
types[fieldCount] = types[fieldCount] = ResolveType(
ResolveType(currentDeclarationNode->declaration.type); currentDeclarationNode->declaration.identifier->typeTag);
fieldDeclarations[fieldCount] = currentDeclarationNode; fieldDeclarations[fieldCount] = currentDeclarationNode;
fieldCount += 1; fieldCount += 1;
break; break;
@ -1396,8 +1664,6 @@ static void CompileStruct(
module, module,
structName, structName,
wStructPointerType, wStructPointerType,
fieldDeclarations,
fieldCount,
currentDeclarationNode); currentDeclarationNode);
break; break;
} }

View File

@ -59,9 +59,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
return NULL; return NULL;
case AllocExpression: case AllocExpression:
AddChildToNode( astNode->typeTag = MakeTypeTag(astNode);
parent,
MakeIdTree(astNode->allocExpression.type, parent));
return NULL; return NULL;
case Assignment: case Assignment:
@ -154,6 +152,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
idNode->typeTag = mainNode->typeTag; idNode->typeTag = mainNode->typeTag;
MakeIdTree(sigNode->functionSignature.arguments, mainNode); MakeIdTree(sigNode->functionSignature.arguments, mainNode);
MakeIdTree(astNode->functionDeclaration.functionBody, mainNode); MakeIdTree(astNode->functionDeclaration.functionBody, mainNode);
MakeIdTree(sigNode->functionSignature.genericArguments, mainNode);
break; break;
} }
@ -167,6 +166,23 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
return NULL; return NULL;
} }
case GenericArgument:
{
char *name = astNode->genericArgument.identifier->identifier.name;
mainNode = MakeIdNode(GenericType, name, parent);
break;
}
case GenericArguments:
{
for (i = 0; i < astNode->genericArguments.count; i += 1)
{
Node *argNode = astNode->genericArguments.arguments[i];
AddChildToNode(parent, MakeIdTree(argNode, parent));
}
return NULL;
}
case Identifier: case Identifier:
{ {
char *name = astNode->identifier.name; char *name = astNode->identifier.name;
@ -302,6 +318,12 @@ void PrintIdNode(IdNode *node)
case Variable: case Variable:
printf("%s : %s\n", node->name, TypeTagToString(node->typeTag)); printf("%s : %s\n", node->name, TypeTagToString(node->typeTag));
break; break;
case GenericType:
printf("Generic type: %s\n", node->name);
break;
case Alloc:
printf("Alloc: %s\n", TypeTagToString(node->typeTag));
break;
} }
} }

View File

@ -17,7 +17,9 @@ typedef enum NodeType
OrderedScope, OrderedScope,
Struct, Struct,
Function, Function,
Variable Variable,
GenericType,
Alloc
} NodeType; } NodeType;
typedef struct IdNode typedef struct IdNode

View File

@ -20,7 +20,7 @@ uint64_t str_hash(char *str)
uint64_t hash = 5381; uint64_t hash = 5381;
size_t c; size_t c;
while (c = *str++) while ((c = *str++))
{ {
hash = ((hash << 5) + hash) + c; /* hash * 33 + c */ hash = ((hash << 5) + hash) + c; /* hash * 33 + c */
} }