initial for loop range implementation

generics
cosmonaut 2021-04-29 23:49:35 -07:00
parent c728dd6b8c
commit 62f42e47b9
6 changed files with 196 additions and 37 deletions

19
euler001.w Normal file
View File

@ -0,0 +1,19 @@
struct Program
{
static Main(): int
{
sum: int;
sum = 0;
for (i in [1..1000])
{
if ((i % 3 == 0) || (i % 5 == 0))
{
sum = sum + i;
}
}
return sum;
}
}

View File

@ -19,6 +19,8 @@
"alloc" return ALLOC; "alloc" return ALLOC;
"if" return IF; "if" return IF;
"else" return ELSE; "else" return ELSE;
"in" return IN;
"for" return FOR;
[0-9]+ return NUMBER; [0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID; [a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL; \"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;

View File

@ -32,6 +32,8 @@ extern FILE *yyin;
%token ALLOC %token ALLOC
%token IF %token IF
%token ELSE %token ELSE
%token IN
%token FOR
%token NUMBER %token NUMBER
%token ID %token ID
%token STRING_LITERAL %token STRING_LITERAL
@ -65,10 +67,10 @@ extern FILE *yyin;
%define parse.error verbose %define parse.error verbose
%left GREATER_THAN LESS_THAN %left GREATER_THAN LESS_THAN EQUAL
%left PLUS MINUS %left PLUS MINUS
%left STAR %left STAR PERCENT
%left BANG %left BANG BAR
%left LEFT_PAREN RIGHT_PAREN %left LEFT_PAREN RIGHT_PAREN
%% %%
@ -139,10 +141,12 @@ AccessExpression : Identifier POINT AccessExpression
$$ = $1; $$ = $1;
} }
PrimaryExpression : NUMBER Number : NUMBER
{ {
$$ = MakeNumberNode(yytext); $$ = MakeNumberNode(yytext);
} }
PrimaryExpression : Number
| STRING | STRING
{ {
$$ = MakeStringNode(yytext); $$ = MakeStringNode(yytext);
@ -172,6 +176,10 @@ BinaryExpression : Expression PLUS Expression
{ {
$$ = MakeBinaryNode(Multiply, $1, $3); $$ = MakeBinaryNode(Multiply, $1, $3);
} }
| Expression PERCENT Expression
{
$$ = MakeBinaryNode(Mod, $1, $3);
}
| Expression LESS_THAN Expression | Expression LESS_THAN Expression
{ {
$$ = MakeBinaryNode(LessThan, $1, $3); $$ = MakeBinaryNode(LessThan, $1, $3);
@ -180,6 +188,14 @@ BinaryExpression : Expression PLUS Expression
{ {
$$ = MakeBinaryNode(GreaterThan, $1, $3); $$ = MakeBinaryNode(GreaterThan, $1, $3);
} }
| Expression EQUAL EQUAL Expression
{
$$ = MakeBinaryNode(Equal, $1, $4);
}
| Expression BAR BAR Expression
{
$$ = MakeBinaryNode(LogicalOr, $1, $4);
}
Expression : BinaryExpression Expression : BinaryExpression
| UnaryExpression | UnaryExpression
@ -240,8 +256,14 @@ Conditional : IfStatement
$$ = MakeIfElseNode($1, $3); $$ = MakeIfElseNode($1, $3);
} }
ForStatement : FOR LEFT_PAREN Identifier IN LEFT_BRACKET Number POINT POINT Number RIGHT_BRACKET RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
{
$$ = MakeForLoopNode($3, $6, $9, $13);
}
Statement : PartialStatement SEMICOLON Statement : PartialStatement SEMICOLON
| Conditional | Conditional
| ForStatement
; ;
Statements : Statement Statements : Statement

View File

@ -440,6 +440,23 @@ Node* MakeIfElseNode(
return node; return node;
} }
Node* MakeForLoopNode(
Node *identifierNode,
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ForLoop;
node->childCount = 4;
node->children = (Node**) malloc(sizeof(Node*) * 4);
node->children[0] = identifierNode;
node->children[1] = startNumberNode;
node->children[2] = endNumberNode;
node->children[3] = statementSequenceNode;
return node;
}
static const char* PrimitiveTypeToString(PrimitiveType type) static const char* PrimitiveTypeToString(PrimitiveType type)
{ {
switch (type) switch (type)

View File

@ -47,8 +47,11 @@ typedef enum
Add, Add,
Subtract, Subtract,
Multiply, Multiply,
Mod,
Equal,
LessThan, LessThan,
GreaterThan GreaterThan,
LogicalOr
} BinaryOperator; } BinaryOperator;
typedef enum typedef enum
@ -200,6 +203,12 @@ Node* MakeIfElseNode(
Node *ifNode, Node *ifNode,
Node *statementSequenceNode Node *statementSequenceNode
); );
Node* MakeForLoopNode(
Node *identifierNode,
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount); void PrintTree(Node *node, uint32_t tabCount);

View File

@ -19,6 +19,7 @@ typedef struct LocalVariable
{ {
char *name; char *name;
LLVMValueRef pointer; LLVMValueRef pointer;
LLVMValueRef value;
} LocalVariable; } LocalVariable;
typedef struct FunctionArgument typedef struct FunctionArgument
@ -111,14 +112,19 @@ static void PopScopeFrame(Scope *scope)
scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount); scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
} }
static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name) static void AddLocalVariable(
{ Scope *scope,
LLVMValueRef pointer, /* can be NULL */
LLVMValueRef value, /* can be NULL */
char *name
) {
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1]; ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
uint32_t index = scopeFrame->localVariableCount; uint32_t index = scopeFrame->localVariableCount;
scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1)); scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
scopeFrame->localVariables[index].name = strdup(name); scopeFrame->localVariables[index].name = strdup(name);
scopeFrame->localVariables[index].pointer = pointer; scopeFrame->localVariables[index].pointer = pointer;
scopeFrame->localVariables[index].value = value;
scopeFrame->localVariableCount += 1; scopeFrame->localVariableCount += 1;
} }
@ -220,7 +226,14 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
{ {
if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0) if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
{ {
return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name); if (scope->scopeStack[i].localVariables[j].value != NULL)
{
return scope->scopeStack[i].localVariables[j].value;
}
else
{
return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
}
} }
} }
} }
@ -414,6 +427,7 @@ static void AddStructVariablesToScope(
AddLocalVariable( AddLocalVariable(
scope, scope,
elementPointer, elementPointer,
NULL,
structTypeDeclarations[i].fields[j].name structTypeDeclarations[i].fields[j].name
); );
} }
@ -456,6 +470,15 @@ static LLVMValueRef CompileBinaryExpression(
case GreaterThan: case GreaterThan:
return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult"); return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult");
case Mod:
return LLVMBuildSRem(builder, left, right, "modResult");
case Equal:
return LLVMBuildICmp(builder, LLVMIntEQ, left, right, "equalResult");
case LogicalOr:
return LLVMBuildOr(builder, left, right, "orResult");
} }
return NULL; return NULL;
@ -581,20 +604,22 @@ static LLVMValueRef CompileExpression(
return NULL; return NULL;
} }
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement); static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement) static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{ {
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]); LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
LLVMBuildRet(builder, expression); LLVMBuildRet(builder, expression);
return LLVMGetLastBasicBlock(function);
} }
static void CompileReturnVoid(LLVMBuilderRef builder) static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function)
{ {
LLVMBuildRetVoid(builder); LLVMBuildRetVoid(builder);
return LLVMGetLastBasicBlock(function);
} }
static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement) static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{ {
LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]); LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
LLVMValueRef identifier; LLVMValueRef identifier;
@ -609,14 +634,16 @@ static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
else else
{ {
printf("Identifier not found!"); printf("Identifier not found!");
return; return LLVMGetLastBasicBlock(function);
} }
LLVMBuildStore(builder, result, identifier); LLVMBuildStore(builder, result, identifier);
return LLVMGetLastBasicBlock(function);
} }
/* FIXME: path for reference types */ /* FIXME: path for reference types */
static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration) static LLVMBasicBlockRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
{ {
LLVMValueRef variable; LLVMValueRef variable;
char *variableName = variableDeclaration->children[1]->value.string; char *variableName = variableDeclaration->children[1]->value.string;
@ -627,10 +654,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
free(ptrName); free(ptrName);
AddLocalVariable(scope, variable, variableName); AddLocalVariable(scope, variable, NULL, variableName);
return LLVMGetLastBasicBlock(function);
} }
static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement) static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]); LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
@ -649,9 +678,11 @@ static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, No
LLVMBuildBr(builder, afterCond); LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, afterCond); LLVMPositionBuilderAtEnd(builder, afterCond);
return afterCond;
} }
static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement) static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
{ {
uint32_t i; uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]); LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]);
@ -686,45 +717,102 @@ static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function
} }
LLVMBuildBr(builder, afterCond); LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, afterCond); LLVMPositionBuilderAtEnd(builder, afterCond);
return afterCond;
} }
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement) static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement)
{
uint32_t i;
LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry");
LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
char *iteratorVariableName = forLoopStatement->children[0]->value.string;
PushScopeFrame(scope);
LLVMBuildBr(builder, entryBlock);
LLVMPositionBuilderAtEnd(builder, entryBlock);
LLVMBuildBr(builder, checkBlock);
LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorValue = LLVMBuildPhi(builder, LLVMInt64Type(), iteratorVariableName);
AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next");
LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMBasicBlockRef lastBlock;
for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1)
{
lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]);
}
LLVMBuildBr(builder, checkBlock);
LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
LLVMValueRef incomingValues[2];
incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
incomingValues[1] = nextValue;
LLVMBasicBlockRef incomingBlocks[2];
incomingBlocks[0] = entryBlock;
incomingBlocks[1] = lastBlock;
LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, 2);
LLVMPositionBuilderAtEnd(builder, afterLoopBlock);
PopScopeFrame(scope);
return afterLoopBlock;
}
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{ {
switch (statement->syntaxKind) switch (statement->syntaxKind)
{ {
case Assignment: case Assignment:
CompileAssignment(builder, statement); return CompileAssignment(builder, function, statement);
return 0;
case Declaration:
return CompileFunctionVariableDeclaration(builder, function, statement);
case ForLoop:
return CompileForLoopStatement(builder, function, statement);
case FunctionCallExpression: case FunctionCallExpression:
CompileFunctionCallExpression(builder, statement); CompileFunctionCallExpression(builder, statement);
return 0; return LLVMGetLastBasicBlock(function);
case Declaration:
CompileFunctionVariableDeclaration(builder, statement);
return 0;
case IfStatement: case IfStatement:
CompileIfStatement(builder, function, statement); return CompileIfStatement(builder, function, statement);
return 0;
case IfElseStatement: case IfElseStatement:
CompileIfElseStatement(builder, function, statement); return CompileIfElseStatement(builder, function, statement);
return 0;
case Return: case Return:
CompileReturn(builder, statement); return CompileReturn(builder, function, statement);
return 1;
case ReturnVoid: case ReturnVoid:
CompileReturnVoid(builder); return CompileReturnVoid(builder, function);
return 1;
} }
fprintf(stderr, "Unknown statement kind!\n"); fprintf(stderr, "Unknown statement kind!\n");
return 0; return NULL;
} }
static void CompileFunction( static void CompileFunction(
@ -800,14 +888,16 @@ static void CompileFunction(
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName); LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy); LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName); free(ptrName);
AddLocalVariable(scope, argumentCopy, functionSignature->children[2]->children[i]->children[1]->value.string); AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string);
} }
for (i = 0; i < functionBody->childCount; i += 1) for (i = 0; i < functionBody->childCount; i += 1)
{ {
hasReturn |= CompileStatement(builder, function, functionBody->children[i]); CompileStatement(builder, function, functionBody->children[i]);
} }
hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn) if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{ {
LLVMBuildRetVoid(builder); LLVMBuildRetVoid(builder);