initial for loop range implementation

generics
cosmonaut 2021-04-29 23:49:35 -07:00
parent c728dd6b8c
commit 62f42e47b9
6 changed files with 196 additions and 37 deletions

19
euler001.w Normal file
View File

@ -0,0 +1,19 @@
struct Program
{
static Main(): int
{
sum: int;
sum = 0;
for (i in [1..1000])
{
if ((i % 3 == 0) || (i % 5 == 0))
{
sum = sum + i;
}
}
return sum;
}
}

View File

@ -19,6 +19,8 @@
"alloc" return ALLOC;
"if" return IF;
"else" return ELSE;
"in" return IN;
"for" return FOR;
[0-9]+ return NUMBER;
[a-zA-Z][a-zA-Z0-9]* return ID;
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;

View File

@ -32,6 +32,8 @@ extern FILE *yyin;
%token ALLOC
%token IF
%token ELSE
%token IN
%token FOR
%token NUMBER
%token ID
%token STRING_LITERAL
@ -65,10 +67,10 @@ extern FILE *yyin;
%define parse.error verbose
%left GREATER_THAN LESS_THAN
%left GREATER_THAN LESS_THAN EQUAL
%left PLUS MINUS
%left STAR
%left BANG
%left STAR PERCENT
%left BANG BAR
%left LEFT_PAREN RIGHT_PAREN
%%
@ -139,10 +141,12 @@ AccessExpression : Identifier POINT AccessExpression
$$ = $1;
}
PrimaryExpression : NUMBER
Number : NUMBER
{
$$ = MakeNumberNode(yytext);
}
PrimaryExpression : Number
| STRING
{
$$ = MakeStringNode(yytext);
@ -172,6 +176,10 @@ BinaryExpression : Expression PLUS Expression
{
$$ = MakeBinaryNode(Multiply, $1, $3);
}
| Expression PERCENT Expression
{
$$ = MakeBinaryNode(Mod, $1, $3);
}
| Expression LESS_THAN Expression
{
$$ = MakeBinaryNode(LessThan, $1, $3);
@ -180,6 +188,14 @@ BinaryExpression : Expression PLUS Expression
{
$$ = MakeBinaryNode(GreaterThan, $1, $3);
}
| Expression EQUAL EQUAL Expression
{
$$ = MakeBinaryNode(Equal, $1, $4);
}
| Expression BAR BAR Expression
{
$$ = MakeBinaryNode(LogicalOr, $1, $4);
}
Expression : BinaryExpression
| UnaryExpression
@ -240,8 +256,14 @@ Conditional : IfStatement
$$ = MakeIfElseNode($1, $3);
}
ForStatement : FOR LEFT_PAREN Identifier IN LEFT_BRACKET Number POINT POINT Number RIGHT_BRACKET RIGHT_PAREN LEFT_BRACE Statements RIGHT_BRACE
{
$$ = MakeForLoopNode($3, $6, $9, $13);
}
Statement : PartialStatement SEMICOLON
| Conditional
| ForStatement
;
Statements : Statement

View File

@ -440,6 +440,23 @@ Node* MakeIfElseNode(
return node;
}
Node* MakeForLoopNode(
Node *identifierNode,
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode
) {
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = ForLoop;
node->childCount = 4;
node->children = (Node**) malloc(sizeof(Node*) * 4);
node->children[0] = identifierNode;
node->children[1] = startNumberNode;
node->children[2] = endNumberNode;
node->children[3] = statementSequenceNode;
return node;
}
static const char* PrimitiveTypeToString(PrimitiveType type)
{
switch (type)

View File

@ -47,8 +47,11 @@ typedef enum
Add,
Subtract,
Multiply,
Mod,
Equal,
LessThan,
GreaterThan
GreaterThan,
LogicalOr
} BinaryOperator;
typedef enum
@ -200,6 +203,12 @@ Node* MakeIfElseNode(
Node *ifNode,
Node *statementSequenceNode
);
Node* MakeForLoopNode(
Node *identifierNode,
Node *startNumberNode,
Node *endNumberNode,
Node *statementSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount);

View File

@ -19,6 +19,7 @@ typedef struct LocalVariable
{
char *name;
LLVMValueRef pointer;
LLVMValueRef value;
} LocalVariable;
typedef struct FunctionArgument
@ -111,14 +112,19 @@ static void PopScopeFrame(Scope *scope)
scope->scopeStack = realloc(scope->scopeStack, sizeof(ScopeFrame) * scope->scopeStackCount);
}
static void AddLocalVariable(Scope *scope, LLVMValueRef pointer, char *name)
{
static void AddLocalVariable(
Scope *scope,
LLVMValueRef pointer, /* can be NULL */
LLVMValueRef value, /* can be NULL */
char *name
) {
ScopeFrame *scopeFrame = &scope->scopeStack[scope->scopeStackCount - 1];
uint32_t index = scopeFrame->localVariableCount;
scopeFrame->localVariables = realloc(scopeFrame->localVariables, sizeof(LocalVariable) * (scopeFrame->localVariableCount + 1));
scopeFrame->localVariables[index].name = strdup(name);
scopeFrame->localVariables[index].pointer = pointer;
scopeFrame->localVariables[index].value = value;
scopeFrame->localVariableCount += 1;
}
@ -219,11 +225,18 @@ static LLVMValueRef FindVariableValue(LLVMBuilderRef builder, char *name)
for (j = 0; j < scope->scopeStack[i].localVariableCount; j += 1)
{
if (strcmp(scope->scopeStack[i].localVariables[j].name, name) == 0)
{
if (scope->scopeStack[i].localVariables[j].value != NULL)
{
return scope->scopeStack[i].localVariables[j].value;
}
else
{
return LLVMBuildLoad(builder, scope->scopeStack[i].localVariables[j].pointer, name);
}
}
}
}
printf("Failed to find variable value!");
return NULL;
@ -414,6 +427,7 @@ static void AddStructVariablesToScope(
AddLocalVariable(
scope,
elementPointer,
NULL,
structTypeDeclarations[i].fields[j].name
);
}
@ -456,6 +470,15 @@ static LLVMValueRef CompileBinaryExpression(
case GreaterThan:
return LLVMBuildICmp(builder, LLVMIntSGT, left, right, "greaterThanResult");
case Mod:
return LLVMBuildSRem(builder, left, right, "modResult");
case Equal:
return LLVMBuildICmp(builder, LLVMIntEQ, left, right, "equalResult");
case LogicalOr:
return LLVMBuildOr(builder, left, right, "orResult");
}
return NULL;
@ -581,20 +604,22 @@ static LLVMValueRef CompileExpression(
return NULL;
}
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement);
static void CompileReturn(LLVMBuilderRef builder, Node *returnStatemement)
static LLVMBasicBlockRef CompileReturn(LLVMBuilderRef builder, LLVMValueRef function, Node *returnStatemement)
{
LLVMValueRef expression = CompileExpression(builder, returnStatemement->children[0]);
LLVMBuildRet(builder, expression);
return LLVMGetLastBasicBlock(function);
}
static void CompileReturnVoid(LLVMBuilderRef builder)
static LLVMBasicBlockRef CompileReturnVoid(LLVMBuilderRef builder, LLVMValueRef function)
{
LLVMBuildRetVoid(builder);
return LLVMGetLastBasicBlock(function);
}
static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
static LLVMBasicBlockRef CompileAssignment(LLVMBuilderRef builder, LLVMValueRef function, Node *assignmentStatement)
{
LLVMValueRef result = CompileExpression(builder, assignmentStatement->children[1]);
LLVMValueRef identifier;
@ -609,14 +634,16 @@ static void CompileAssignment(LLVMBuilderRef builder, Node *assignmentStatement)
else
{
printf("Identifier not found!");
return;
return LLVMGetLastBasicBlock(function);
}
LLVMBuildStore(builder, result, identifier);
return LLVMGetLastBasicBlock(function);
}
/* FIXME: path for reference types */
static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *variableDeclaration)
static LLVMBasicBlockRef CompileFunctionVariableDeclaration(LLVMBuilderRef builder, LLVMValueRef function, Node *variableDeclaration)
{
LLVMValueRef variable;
char *variableName = variableDeclaration->children[1]->value.string;
@ -627,10 +654,12 @@ static void CompileFunctionVariableDeclaration(LLVMBuilderRef builder, Node *var
free(ptrName);
AddLocalVariable(scope, variable, variableName);
AddLocalVariable(scope, variable, NULL, variableName);
return LLVMGetLastBasicBlock(function);
}
static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
static LLVMBasicBlockRef CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifStatement->children[0]);
@ -649,9 +678,11 @@ static void CompileIfStatement(LLVMBuilderRef builder, LLVMValueRef function, No
LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, afterCond);
return afterCond;
}
static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
static LLVMBasicBlockRef CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *ifElseStatement)
{
uint32_t i;
LLVMValueRef conditional = CompileExpression(builder, ifElseStatement->children[0]->children[0]);
@ -686,45 +717,102 @@ static void CompileIfElseStatement(LLVMBuilderRef builder, LLVMValueRef function
}
LLVMBuildBr(builder, afterCond);
LLVMPositionBuilderAtEnd(builder, afterCond);
return afterCond;
}
static uint8_t CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
static LLVMBasicBlockRef CompileForLoopStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *forLoopStatement)
{
uint32_t i;
LLVMBasicBlockRef entryBlock = LLVMAppendBasicBlock(function, "loopEntry");
LLVMBasicBlockRef checkBlock = LLVMAppendBasicBlock(function, "loopCheck");
LLVMBasicBlockRef bodyBlock = LLVMAppendBasicBlock(function, "loopBody");
LLVMBasicBlockRef afterLoopBlock = LLVMAppendBasicBlock(function, "afterLoop");
char *iteratorVariableName = forLoopStatement->children[0]->value.string;
PushScopeFrame(scope);
LLVMBuildBr(builder, entryBlock);
LLVMPositionBuilderAtEnd(builder, entryBlock);
LLVMBuildBr(builder, checkBlock);
LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorValue = LLVMBuildPhi(builder, LLVMInt64Type(), iteratorVariableName);
AddLocalVariable(scope, NULL, iteratorValue, iteratorVariableName);
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMValueRef nextValue = LLVMBuildAdd(builder, iteratorValue, LLVMConstInt(LLVMInt64Type(), 1, 0), "next");
LLVMPositionBuilderAtEnd(builder, checkBlock);
LLVMValueRef iteratorEndValue = CompileNumber(forLoopStatement->children[2]);
LLVMValueRef comparison = LLVMBuildICmp(builder, LLVMIntULE, iteratorValue, iteratorEndValue, "iteratorCompare");
LLVMBuildCondBr(builder, comparison, bodyBlock, afterLoopBlock);
LLVMPositionBuilderAtEnd(builder, bodyBlock);
LLVMBasicBlockRef lastBlock;
for (i = 0; i < forLoopStatement->children[3]->childCount; i += 1)
{
lastBlock = CompileStatement(builder, function, forLoopStatement->children[3]->children[i]);
}
LLVMBuildBr(builder, checkBlock);
LLVMPositionBuilderBefore(builder, LLVMGetFirstInstruction(checkBlock));
LLVMValueRef incomingValues[2];
incomingValues[0] = CompileNumber(forLoopStatement->children[1]);
incomingValues[1] = nextValue;
LLVMBasicBlockRef incomingBlocks[2];
incomingBlocks[0] = entryBlock;
incomingBlocks[1] = lastBlock;
LLVMAddIncoming(iteratorValue, incomingValues, incomingBlocks, 2);
LLVMPositionBuilderAtEnd(builder, afterLoopBlock);
PopScopeFrame(scope);
return afterLoopBlock;
}
static LLVMBasicBlockRef CompileStatement(LLVMBuilderRef builder, LLVMValueRef function, Node *statement)
{
switch (statement->syntaxKind)
{
case Assignment:
CompileAssignment(builder, statement);
return 0;
return CompileAssignment(builder, function, statement);
case Declaration:
return CompileFunctionVariableDeclaration(builder, function, statement);
case ForLoop:
return CompileForLoopStatement(builder, function, statement);
case FunctionCallExpression:
CompileFunctionCallExpression(builder, statement);
return 0;
case Declaration:
CompileFunctionVariableDeclaration(builder, statement);
return 0;
return LLVMGetLastBasicBlock(function);
case IfStatement:
CompileIfStatement(builder, function, statement);
return 0;
return CompileIfStatement(builder, function, statement);
case IfElseStatement:
CompileIfElseStatement(builder, function, statement);
return 0;
return CompileIfElseStatement(builder, function, statement);
case Return:
CompileReturn(builder, statement);
return 1;
return CompileReturn(builder, function, statement);
case ReturnVoid:
CompileReturnVoid(builder);
return 1;
return CompileReturnVoid(builder, function);
}
fprintf(stderr, "Unknown statement kind!\n");
return 0;
return NULL;
}
static void CompileFunction(
@ -800,14 +888,16 @@ static void CompileFunction(
LLVMValueRef argumentCopy = LLVMBuildAlloca(builder, LLVMTypeOf(argument), ptrName);
LLVMBuildStore(builder, argument, argumentCopy);
free(ptrName);
AddLocalVariable(scope, argumentCopy, functionSignature->children[2]->children[i]->children[1]->value.string);
AddLocalVariable(scope, argumentCopy, NULL, functionSignature->children[2]->children[i]->children[1]->value.string);
}
for (i = 0; i < functionBody->childCount; i += 1)
{
hasReturn |= CompileStatement(builder, function, functionBody->children[i]);
CompileStatement(builder, function, functionBody->children[i]);
}
hasReturn = LLVMGetBasicBlockTerminator(LLVMGetLastBasicBlock(function)) != NULL;
if (LLVMGetTypeKind(returnType) == LLVMVoidTypeKind && !hasReturn)
{
LLVMBuildRetVoid(builder);