diff --git a/euler001.w b/euler001.w index a5f5158..219be6e 100644 --- a/euler001.w +++ b/euler001.w @@ -4,7 +4,7 @@ struct Program { sum: int = 0; - for (i: int in [1..1000]) + for (i: int in [1..999]) { if ((i % 3 == 0) || (i % 5 == 0)) { @@ -12,6 +12,7 @@ struct Program } } - return sum; + Console.PrintLine("%i", sum); + return 0; } } diff --git a/generators/wraith.lex b/generators/wraith.lex index ee84f7b..f288c6e 100644 --- a/generators/wraith.lex +++ b/generators/wraith.lex @@ -23,7 +23,7 @@ "for" return FOR; [0-9]+ return NUMBER; [a-zA-Z][a-zA-Z0-9]* return ID; -\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; +\".*\" return STRING_LITERAL; "+" return PLUS; "-" return MINUS; "*" return STAR; @@ -32,7 +32,6 @@ "<" return LESS_THAN; ">" return GREATER_THAN; "=" return EQUAL; -"\"" return QUOTE; "!" return BANG; "|" return BAR; "&" return AMPERSAND; diff --git a/generators/wraith.y b/generators/wraith.y index 84a1b65..7aaf70a 100644 --- a/generators/wraith.y +++ b/generators/wraith.y @@ -147,7 +147,7 @@ Number : NUMBER } PrimaryExpression : Number - | STRING + | STRING_LITERAL { $$ = MakeStringNode(yytext); } diff --git a/src/ast.c b/src/ast.c index 24bdeb9..4f33a63 100644 --- a/src/ast.c +++ b/src/ast.c @@ -6,15 +6,16 @@ char* strdup (const char* s) { - size_t slen = strlen(s); - char* result = malloc(slen + 1); - if(result == NULL) - { - return NULL; - } + size_t slen = strlen(s); + char* result = malloc(slen + 1); - memcpy(result, s, slen+1); - return result; + if(result == NULL) + { + return NULL; + } + + memcpy(result, s, slen+1); + return result; } const char* SyntaxKindString(SyntaxKind syntaxKind) @@ -124,9 +125,10 @@ Node* MakeNumberNode( Node* MakeStringNode( const char *string ) { + size_t slen = strlen(string); Node* node = (Node*) malloc(sizeof(Node)); node->syntaxKind = StringLiteral; - node->value.string = strdup(string); + node->value.string = strndup(string + 1, slen - 2); node->childCount = 0; return node; } diff --git a/src/codegen.c b/src/codegen.c index 0608916..3e75d1b 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -446,6 +446,13 @@ static LLVMValueRef CompileNumber( return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0); } +static LLVMValueRef CompileString( + LLVMBuilderRef builder, + Node *stringExpression +) { + return LLVMBuildGlobalStringPtr(builder, stringExpression->value.string, "stringConstant"); +} + static LLVMValueRef CompileBinaryExpression( LLVMBuilderRef builder, Node *binaryExpression @@ -598,6 +605,10 @@ static LLVMValueRef CompileExpression( case Number: return CompileNumber(expression); + + case StringLiteral: + return CompileString(builder, expression); + } fprintf(stderr, "Unknown expression kind!\n"); @@ -748,7 +759,12 @@ static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMVal AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName); LLVMPositionBuilderAtEnd(builder, bodyBlock); - LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next"); + LLVMValueRef nextValue = LLVMBuildAdd( + builder, + iteratorValue, + LLVMConstInt(iteratorVariableType, forLoopStatement->children[1]->value.number, 0), + "next" + ); LLVMPositionBuilderAtEnd(builder, checkBlock); @@ -985,6 +1001,36 @@ static void Compile(LLVMModuleRef module, LLVMContextRef context, Node *node) } } +/* TODO: move this to some kind of standard library file? */ +static void RegisterLibraryFunctions(LLVMModuleRef module, LLVMContextRef context) +{ + LLVMTypeRef structType = LLVMStructCreateNamed(context, "Console"); + LLVMTypeRef structPointerType = LLVMPointerType(structType, 0); + AddStructDeclaration(structType, structPointerType, "Console", NULL, 0); + + LLVMTypeRef printfArg = LLVMPointerType(LLVMInt8Type(), 0); + LLVMTypeRef printfFunctionType = LLVMFunctionType(LLVMInt32Type(), &printfArg, 1, 1); + LLVMValueRef printfFunction = LLVMAddFunction(module, "printf", printfFunctionType); + LLVMSetLinkage(printfFunction, LLVMExternalLinkage); + + LLVMTypeRef printLineFunctionType = LLVMFunctionType(LLVMInt32Type(), &printfArg, 1, 1); + LLVMValueRef printLineFunction = LLVMAddFunction(module, "printLine", printLineFunctionType); + + LLVMBuilderRef builder = LLVMCreateBuilder(); + LLVMBasicBlockRef entry = LLVMAppendBasicBlock(printLineFunction, "entry"); + LLVMPositionBuilderAtEnd(builder, entry); + + LLVMValueRef newLine = LLVMBuildGlobalStringPtr(builder, "\n", "newline"); + + LLVMValueRef printParams[LLVMCountParams(printLineFunction)]; + LLVMGetParams(printLineFunction, printParams); + LLVMValueRef stringPrint = LLVMBuildCall(builder, printfFunction, printParams, LLVMCountParams(printLineFunction), "printfCall"); + LLVMValueRef newlinePrint = LLVMBuildCall(builder, printfFunction, &newLine, 1, "printNewLine"); + LLVMBuildRet(builder, LLVMBuildAnd(builder, stringPrint, newlinePrint, "and")); + + DeclareStructFunction(structPointerType, printLineFunction, LLVMInt8Type(), 1, "PrintLine"); +} + int Codegen(Node *node, uint32_t optimizationLevel) { scope = CreateScope(); @@ -995,6 +1041,8 @@ int Codegen(Node *node, uint32_t optimizationLevel) LLVMModuleRef module = LLVMModuleCreateWithName("my_module"); LLVMContextRef context = LLVMGetGlobalContext(); + RegisterLibraryFunctions(module, context); + Compile(module, context, node); /* add main call */