support function calls

generics
cosmonaut 2021-04-20 10:47:40 -07:00
parent 5c147d80ec
commit b5d256251e
7 changed files with 155 additions and 79 deletions

38
ast.c
View File

@ -26,6 +26,8 @@ const char* SyntaxKindString(SyntaxKind syntaxKind)
case Comment: return "Comment";
case Declaration: return "Declaration";
case DeclarationSequence: return "DeclarationSequence";
case FunctionArgumentSequence: return "FunctionArgumentSequence";
case FunctionCallExpression: return "FunctionCallExpression";
case FunctionDeclaration: return "FunctionDeclaration";
case FunctionSignature: return "FunctionSignature";
case FunctionSignatureArguments: return "FunctionSignatureArguments";
@ -239,6 +241,36 @@ Node* MakeDeclarationSequenceNode(
return node;
}
Node *MakeFunctionArgumentSequenceNode(
Node **pArgumentNodes,
uint32_t argumentCount
) {
int32_t i;
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionArgumentSequence;
node->childCount = argumentCount;
node->children = (Node**) malloc(sizeof(Node*) * node->childCount);
for (i = argumentCount - 1; i >= 0; i -= 1)
{
node->children[argumentCount - 1 - i] = pArgumentNodes[i];
}
return node;
}
Node* MakeFunctionCallExpressionNode(
Node *identifierNode,
Node *argumentSequenceNode
) {
int32_t i;
Node* node = (Node*) malloc(sizeof(Node));
node->syntaxKind = FunctionCallExpression;
node->children = (Node**) malloc(sizeof(Node*) * 2);
node->childCount = 2;
node->children[0] = identifierNode;
node->children[1] = argumentSequenceNode;
return node;
}
static const char* PrimitiveTypeToString(PrimitiveType type)
{
switch (type)
@ -262,6 +294,10 @@ static void PrintBinaryOperator(BinaryOperator expression)
case Subtract:
printf("-");
break;
case Multiply:
printf("*");
break;
}
}
@ -308,5 +344,3 @@ void PrintTree(Node *node, uint32_t tabCount)
PrintTree(node->children[i], tabCount + 1);
}
}

13
ast.h
View File

@ -12,6 +12,8 @@ typedef enum
DeclarationSequence,
Expression,
ForLoop,
FunctionArgumentSequence,
FunctionCallExpression,
FunctionDeclaration,
FunctionSignature,
FunctionSignatureArguments,
@ -33,7 +35,8 @@ typedef enum
typedef enum
{
Add,
Subtract
Subtract,
Multiply
} BinaryOperator;
typedef enum
@ -130,6 +133,14 @@ Node* MakeDeclarationSequenceNode(
Node **pNodes,
uint32_t nodeCount
);
Node *MakeFunctionArgumentSequenceNode(
Node **pArgumentNodes,
uint32_t argumentCount
);
Node* MakeFunctionCallExpressionNode(
Node *identifierNode,
Node *argumentSequenceNode
);
void PrintTree(Node *node, uint32_t tabCount);

View File

@ -14,13 +14,13 @@ extern FILE *yyin;
Stack *stack;
Node *rootNode;
typedef struct VariableMapValue
typedef struct IdentifierMapValue
{
char *name;
LLVMValueRef variable;
} VariableMapValue;
LLVMValueRef value;
} IdentifierMapValue;
VariableMapValue *namedVariables;
IdentifierMapValue *namedVariables;
uint32_t namedVariableCount;
static LLVMValueRef CompileExpression(
@ -32,11 +32,11 @@ static LLVMValueRef CompileExpression(
static void AddNamedVariable(char *name, LLVMValueRef variable)
{
VariableMapValue mapValue;
IdentifierMapValue mapValue;
mapValue.name = name;
mapValue.variable = variable;
mapValue.value = variable;
namedVariables = realloc(namedVariables, namedVariableCount + 1);
namedVariables = realloc(namedVariables, sizeof(IdentifierMapValue) * (namedVariableCount + 1));
namedVariables[namedVariableCount] = mapValue;
namedVariableCount += 1;
@ -50,7 +50,7 @@ static LLVMValueRef FindVariableByName(char *name)
{
if (strcmp(namedVariables[i].name, name) == 0)
{
return namedVariables[i].variable;
return namedVariables[i].value;
}
}
@ -76,8 +76,7 @@ static LLVMValueRef CompileNumber(
LLVMBuilderRef builder,
LLVMValueRef function,
Node *numberExpression
)
{
) {
return LLVMConstInt(LLVMInt64Type(), numberExpression->value.number, 0);
}
@ -94,11 +93,36 @@ static LLVMValueRef CompileBinaryExpression(
{
case Add:
return LLVMBuildAdd(builder, left, right, "tmp");
case Subtract:
return LLVMBuildSub(builder, left, right, "tmp");
case Multiply:
return LLVMBuildMul(builder, left, right, "tmp");
}
return NULL;
}
static LLVMValueRef CompileFunctionCallExpression(
LLVMModuleRef module,
LLVMBuilderRef builder,
LLVMValueRef function,
Node *expression
) {
uint32_t i;
uint32_t argumentCount = expression->children[1]->childCount;
LLVMValueRef args[argumentCount];
for (i = 0; i < argumentCount; i += 1)
{
args[i] = CompileExpression(module, builder, function, expression->children[1]->children[i]);
}
return LLVMBuildCall(builder, FindVariableByName(expression->children[0]->value.string), args, argumentCount, "tmp");
}
static LLVMValueRef CompileExpression(
LLVMModuleRef module,
LLVMBuilderRef builder,
@ -110,6 +134,9 @@ static LLVMValueRef CompileExpression(
case BinaryExpression:
return CompileBinaryExpression(module, builder, function, expression);
case FunctionCallExpression:
return CompileFunctionCallExpression(module, builder, function, expression);
case Identifier:
return FindVariableByName(expression->value.string);
@ -167,6 +194,8 @@ static void CompileFunction(LLVMModuleRef module, Node *functionDeclaration)
{
CompileStatement(module, builder, function, functionBody->children[i]);
}
AddNamedVariable(functionSignature->children[0]->value.string, function);
}
static void Compile(LLVMModuleRef module, Node *node)

51
stack.c
View File

@ -10,12 +10,9 @@ Stack* CreateStack()
stack->stackFrames = (StackFrame*) malloc(sizeof(StackFrame) * stack->stackCapacity);
for (i = 0; i < stack->stackCapacity; i += 1)
{
stack->stackFrames[i].statements = NULL;
stack->stackFrames[i].statementCapacity = 0;
stack->stackFrames[i].statementCount = 0;
stack->stackFrames[i].declarations = NULL;
stack->stackFrames[i].declarationCapacity = 0;
stack->stackFrames[i].declarationCount = 0;
stack->stackFrames[i].nodes = NULL;
stack->stackFrames[i].nodeCapacity = 0;
stack->stackFrames[i].nodeCount = 0;
}
stack->stackIndex = 0;
return stack;
@ -30,12 +27,10 @@ void PushStackFrame(Stack *stack)
stack->stackCapacity += 1;
stack->stackFrames = (StackFrame*) realloc(stack->stackFrames, sizeof(StackFrame) * stack->stackCapacity);
stack->stackFrames[stack->stackIndex].statementCapacity = 0;
stack->stackFrames[stack->stackIndex].declarationCapacity = 0;
stack->stackFrames[stack->stackIndex].nodeCapacity = 0;
}
stack->stackFrames[stack->stackIndex].statementCount = 0;
stack->stackFrames[stack->stackIndex].declarationCount = 0;
stack->stackFrames[stack->stackIndex].nodeCount = 0;
}
void PopStackFrame(Stack *stack)
@ -43,38 +38,20 @@ void PopStackFrame(Stack *stack)
stack->stackIndex -= 1;
}
void AddStatement(Stack *stack, Node *statementNode)
void AddNode(Stack *stack, Node *statementNode)
{
StackFrame *stackFrame = &stack->stackFrames[stack->stackIndex];
if (stackFrame->statementCount == stackFrame->statementCapacity)
if (stackFrame->nodeCount == stackFrame->nodeCapacity)
{
stackFrame->statementCapacity += 1;
stackFrame->statements = (Node**) realloc(stackFrame->statements, stackFrame->statementCapacity);
stackFrame->nodeCapacity += 1;
stackFrame->nodes = (Node**) realloc(stackFrame->nodes, stackFrame->nodeCapacity);
}
stackFrame->statements[stackFrame->statementCount] = statementNode;
stackFrame->statementCount += 1;
stackFrame->nodes[stackFrame->nodeCount] = statementNode;
stackFrame->nodeCount += 1;
}
Node** GetStatements(Stack *stack, uint32_t *pCount)
Node** GetNodes(Stack *stack, uint32_t *pCount)
{
*pCount = stack->stackFrames[stack->stackIndex].statementCount;
return stack->stackFrames[stack->stackIndex].statements;
}
void AddDeclaration(Stack *stack, Node *declarationNode)
{
StackFrame *stackFrame = &stack->stackFrames[stack->stackIndex];
if (stackFrame->declarationCount == stackFrame->declarationCapacity)
{
stackFrame->declarationCapacity += 1;
stackFrame->declarations = (Node**) realloc(stackFrame->declarations, stackFrame->declarationCapacity);
}
stackFrame->declarations[stackFrame->declarationCount] = declarationNode;
stackFrame->declarationCount += 1;
}
Node** GetDeclarations(Stack *stack, uint32_t *pCount)
{
*pCount = stack->stackFrames[stack->stackIndex].declarationCount;
return stack->stackFrames[stack->stackIndex].declarations;
*pCount = stack->stackFrames[stack->stackIndex].nodeCount;
return stack->stackFrames[stack->stackIndex].nodes;
}

16
stack.h
View File

@ -5,13 +5,9 @@
typedef struct StackFrame
{
Node **statements;
uint32_t statementCount;
uint32_t statementCapacity;
Node **declarations;
uint32_t declarationCount;
uint32_t declarationCapacity;
Node **nodes;
uint32_t nodeCount;
uint32_t nodeCapacity;
} StackFrame;
typedef struct Stack
@ -25,9 +21,7 @@ Stack* CreateStack();
void PushStackFrame(Stack *stack);
void PopStackFrame(Stack *stack);
void AddStatement(Stack *stack, Node *statementNode);
Node** GetStatements(Stack *stack, uint32_t *pCount);
void AddDeclaration(Stack *stack, Node *declarationNode);
Node** GetDeclarations(Stack *stack, uint32_t *pCount);
void AddNode(Stack *stack, Node *statementNode);
Node** GetNodes(Stack *stack, uint32_t *pCount);
#endif /* WRAITH_STACK_H */

View File

@ -18,9 +18,9 @@
\"[a-zA-Z][a-zA-Z0-9]*\" return STRING_LITERAL;
"+" return PLUS;
"-" return MINUS;
"*" return MULTIPLY;
"/" return DIVIDE;
"%" return MOD;
"*" return STAR;
"/" return SLASH;
"%" return PERCENT;
"<" return LESS_THAN;
">" return GREATER_THAN;
"=" return EQUAL;

View File

@ -34,9 +34,9 @@ extern Node *rootNode;
%token STRING_LITERAL
%token PLUS
%token MINUS
%token MULTIPLY
%token DIVIDE
%token MOD
%token STAR
%token SLASH
%token PERCENT
%token EQUAL
%token LESS_THAN
%token GREATER_THAN
@ -71,7 +71,7 @@ Program : Declarations
Node *declarationSequence;
uint32_t declarationCount;
declarations = GetDeclarations(stack, &declarationCount);
declarations = GetNodes(stack, &declarationCount);
declarationSequence = MakeDeclarationSequenceNode(declarations, declarationCount);
PopStackFrame(stack);
@ -123,6 +123,7 @@ PrimaryExpression : Identifier
{
$$ = $2;
}
| FunctionCallExpression
;
UnaryExpression : BANG Expression
@ -138,6 +139,10 @@ BinaryExpression : Expression PLUS Expression
{
$$ = MakeBinaryNode(Subtract, $1, $3);
}
| Expression STAR Expression
{
$$ = MakeBinaryNode(Multiply, $1, $3);
}
Expression : PrimaryExpression
| UnaryExpression
@ -163,8 +168,20 @@ ReturnStatement : RETURN Expression
$$ = MakeReturnStatementNode($2);
}
FunctionCallExpression : Identifier LEFT_PAREN Arguments RIGHT_PAREN
{
Node **arguments;
uint32_t argumentCount;
arguments = GetNodes(stack, &argumentCount);
$$ = MakeFunctionCallExpressionNode($1, MakeFunctionArgumentSequenceNode(arguments, argumentCount));
PopStackFrame(stack);
}
PartialStatement : AssignmentStatement
| VariableDeclaration
| FunctionCallExpression
| ReturnStatement
;
@ -172,22 +189,36 @@ Statement : PartialStatement SEMICOLON;
Statements : Statement Statements
{
AddStatement(stack, $1);
AddNode(stack, $1);
}
|
{
PushStackFrame(stack);
}
Arguments : VariableDeclaration COMMA VariableDeclarations
Arguments : PrimaryExpression COMMA Arguments
{
AddDeclaration(stack, $1);
AddNode(stack, $1);
}
| PrimaryExpression
{
PushStackFrame(stack);
AddNode(stack, $1);
}
|
;
SignatureArguments : VariableDeclaration COMMA VariableDeclarations
{
AddNode(stack, $1);
}
| VariableDeclaration
{
PushStackFrame(stack);
AddDeclaration(stack, $1);
AddNode(stack, $1);
}
|
;
Body : LEFT_BRACE Statements RIGHT_BRACE
{
@ -195,19 +226,19 @@ Body : LEFT_BRACE Statements RIGHT_BRACE
Node *statementSequence;
uint32_t statementCount;
statements = GetStatements(stack, &statementCount);
statements = GetNodes(stack, &statementCount);
statementSequence = MakeStatementSequenceNode(statements, statementCount);
$$ = MakeStatementSequenceNode(statements, statementCount);
PopStackFrame(stack);
}
FunctionSignature : Type Identifier LEFT_PAREN Arguments RIGHT_PAREN
FunctionSignature : Type Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN
{
Node **declarations;
uint32_t declarationCount;
declarations = GetDeclarations(stack, &declarationCount);
declarations = GetNodes(stack, &declarationCount);
$$ = MakeFunctionSignatureNode($2, $1, MakeFunctionSignatureArgumentsNode(declarations, declarationCount));
PopStackFrame(stack);
@ -220,7 +251,7 @@ FunctionDeclaration : FunctionSignature Body
VariableDeclarations : VariableDeclaration SEMICOLON VariableDeclarations
{
AddDeclaration(stack, $1);
AddNode(stack, $1);
}
|
{
@ -233,7 +264,7 @@ StructDeclaration : STRUCT Identifier LEFT_BRACE VariableDeclarations RIGH
Node *declarationSequence;
uint32_t declarationCount;
declarations = GetDeclarations(stack, &declarationCount);
declarations = GetNodes(stack, &declarationCount);
declarationSequence = MakeDeclarationSequenceNode(declarations, declarationCount);
$$ = MakeStructDeclarationNode($2, declarationSequence);
@ -246,7 +277,7 @@ Declaration : StructDeclaration
Declarations : Declaration Declarations
{
AddDeclaration(stack, $1);
AddNode(stack, $1);
}
|
{