generics
cosmonaut 2021-04-30 12:17:44 -07:00
parent 3b43d44f35
commit cbeb8d3ce2
5 changed files with 65 additions and 15 deletions

View File

@ -4,7 +4,7 @@ struct Program
{
sum: int = 0;
for (i: int in [1..1000])
for (i: int in [1..999])
{
if ((i % 3 == 0) || (i % 5 == 0))
{
@ -12,6 +12,7 @@ struct Program
}
}
return sum;
Console.PrintLine("%i", sum);
return 0;
}
}

View File

@ -23,7 +23,7 @@
"for" return FOR;
[0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;
\".*\" return STRING_LITERAL;
"+" return PLUS;
"-" return MINUS;
"*" return STAR;
@ -32,7 +32,6 @@
"<" return LESS_THAN;
">" return GREATER_THAN;
"=" return EQUAL;
"\"" return QUOTE;
"!" return BANG;
"|" return BAR;
"&" return AMPERSAND;

View File

@ -147,7 +147,7 @@ Number : NUMBER
}
PrimaryExpression : Number
| STRING
| STRING_LITERAL
{
$$ = MakeStringNode(yytext);
}

View File

@ -6,15 +6,16 @@
char* strdup (const char* s)
{
size_t slen = strlen(s);
char* result = malloc(slen + 1);
if(result == NULL)
{
return NULL;
}
size_t slen = strlen(s);
char* result = malloc(slen + 1);
memcpy(result, s, slen+1);
return result;
if(result == NULL)
{
return NULL;
}
memcpy(result, s, slen+1);
return result;
}
const char* SyntaxKindString(SyntaxKind syntaxKind)
@ -124,9 +125,10 @@ Node* MakeNumberNode(
Node* MakeStringNode(
const char *string
) {
size_t slen = strlen(string);
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = StringLiteral;
node->value.string = strdup(string);
node->value.string = strndup(string + 1, slen - 2);
node->childCount = 0;
return node;
}

View File

@ -446,6 +446,13 @@ static LLVMValueRef CompileNumber(
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
}
static LLVMValueRef CompileString(
LLVMBuilderRef builder,
Node *stringExpression
) {
return LLVMBuildGlobalStringPtr(builder, stringExpression->value.string, "stringConstant");
}
static LLVMValueRef CompileBinaryExpression(
LLVMBuilderRef builder,
Node *binaryExpression
@ -598,6 +605,10 @@ static LLVMValueRef CompileExpression(
case Number:
return CompileNumber(expression);
case StringLiteral:
return CompileString(builder, expression);
}
fprintf(stderr, "Unknown expression kind!\n");
@ -748,7 +759,12 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal
AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next");
LLVMValueRef nextValue = LLVMBuildAdd(
builder,
iteratorValue,
LLVMConstInt(iteratorVariableType, forLoopStatement->children[1]->value.number, 0),
"next"
);
LLVMPositionBuilderAtEnd(builder, checkBlock);
@ -985,6 +1001,36 @@ static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node)
}
}
/* TODO: move this to some kind of standard library file? */
static void RegisterLibraryFunctions(LLVMModuleRef module, LLVMContextRef context)
{
LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console");
LLVMTypeRef structPointerType = LLVMPointerType(structType, 0);
AddStructDeclaration(structType, structPointerType, "Console", NULL, 0);
LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0);
LLVMTypeRef printfFunctionType = LLVMFunctionType(LLVMInt32Type(), &printfArg, 1, 1);
LLVMValueRef printfFunction = LLVMAddFunction(module, "printf", printfFunctionType);
LLVMSetLinkage(printfFunction, LLVMExternalLinkage);
LLVMTypeRef printLineFunctionType = LLVMFunctionType(LLVMInt32Type(), &printfArg, 1, 1);
LLVMValueRef printLineFunction = LLVMAddFunction(module, "printLine", printLineFunctionType);
LLVMBuilderRef builder = LLVMCreateBuilder();
LLVMBasicBlockRef entry = LLVMAppendBasicBlock(printLineFunction, "entry");
LLVMPositionBuilderAtEnd(builder, entry);
LLVMValueRef newLine = LLVMBuildGlobalStringPtr(builder, "\n", "newline");
LLVMValueRef printParams[LLVMCountParams(printLineFunction)];
LLVMGetParams(printLineFunction, printParams);
LLVMValueRef stringPrint = LLVMBuildCall(builder, printfFunction, printParams, LLVMCountParams(printLineFunction), "printfCall");
LLVMValueRef newlinePrint = LLVMBuildCall(builder, printfFunction, &newLine, 1, "printNewLine");
LLVMBuildRet(builder, LLVMBuildAnd(builder, stringPrint, newlinePrint, "and"));
DeclareStructFunction(structPointerType, printLineFunction, LLVMInt8Type(), 1, "PrintLine");
}
int Codegen(Node *node, uint32_t optimizationLevel)
{
scope = CreateScope();
@ -995,6 +1041,8 @@ int Codegen(Node *node, uint32_t optimizationLevel)
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMContextRef context = LLVMGetGlobalContext();
RegisterLibraryFunctions(module, context);
Compile(module, context, node);
/* add main call */