Compare commits

...

9 Commits

13 changed files with 1161 additions and 200 deletions

View File

@ -43,10 +43,14 @@ add_executable(
src/codegen.h
src/identcheck.h
src/parser.h
src/typeutils.h
src/util.h
src/ast.c
src/codegen.c
src/identcheck.c
src/parser.c
src/typeutils.c
src/util.c
src/main.c
# Generated code
${BISON_Parser_OUTPUTS}

View File

@ -307,14 +307,38 @@ Body : LEFT_BRACE Statements RIGHT_BRACE
$$ = $2;
}
FunctionSignature : Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
GenericArgument : Identifier
{
$$ = MakeFunctionSignatureNode($1, $6, $3, MakeFunctionModifiersNode(NULL, 0));
$$ = MakeGenericArgumentNode($1, NULL);
}
| STATIC Identifier LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
GenericArguments : GenericArgument
{
$$ = StartGenericArgumentsNode($1);
}
| GenericArguments COMMA GenericArgument
{
$$ = AddGenericArgument($1, $3);
}
GenericArgumentsClause : LESS_THAN GenericArguments GREATER_THAN
{
$$ = $2;
}
|
{
$$ = MakeEmptyGenericArgumentsNode();
}
FunctionSignature : Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{
$$ = MakeFunctionSignatureNode($1, $7, $4, MakeFunctionModifiersNode(NULL, 0), $2);
}
| STATIC Identifier GenericArgumentsClause LEFT_PAREN SignatureArguments RIGHT_PAREN COLON Type
{
Node *modifier = MakeStaticNode();
$$ = MakeFunctionSignatureNode($2, $7, $4, MakeFunctionModifiersNode(&modifier, 1));
$$ = MakeFunctionSignatureNode($2, $8, $5, MakeFunctionModifiersNode(&modifier, 1), $3);
}
FunctionDeclaration : FunctionSignature Body

18
generic.w Normal file
View File

@ -0,0 +1,18 @@
struct Foo {
static Func2<U>(u: U) : U {
return u;
}
static Func<T>(t: T): T {
foo: T = t;
return Func2(foo);
}
}
struct Program {
static main(): int {
x: int = 4;
y: int = Foo.Func(x);
return x;
}
}

256
src/ast.c
View File

@ -39,6 +39,12 @@ const char *SyntaxKindString(SyntaxKind syntaxKind)
return "FunctionSignature";
case FunctionSignatureArguments:
return "FunctionSignatureArguments";
case GenericArgument:
return "GenericArgument";
case GenericArguments:
return "GenericArguments";
case GenericTypeNode:
return "GenericTypeNode";
case Identifier:
return "Identifier";
case IfStatement:
@ -271,7 +277,8 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode,
Node *typeNode,
Node *arguments,
Node *modifiersNode)
Node *modifiersNode,
Node *genericArgumentsNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = FunctionSignature;
@ -279,6 +286,7 @@ Node *MakeFunctionSignatureNode(
node->functionSignature.type = typeNode;
node->functionSignature.arguments = arguments;
node->functionSignature.modifiers = modifiersNode;
node->functionSignature.genericArguments = genericArgumentsNode;
return node;
}
@ -359,6 +367,54 @@ Node *MakeEmptyFunctionArgumentSequenceNode()
return node;
}
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArgument;
node->genericArgument.identifier = identifierNode;
node->genericArgument.constraint = constraintNode;
return node;
}
Node *StartGenericArgumentsNode(Node *genericArgumentNode)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = (Node **)malloc(sizeof(Node *));
node->genericArguments.arguments[0] = genericArgumentNode;
node->genericArguments.count = 1;
return node;
}
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode)
{
genericArgumentsNode->genericArguments.arguments = (Node **)realloc(
genericArgumentsNode->genericArguments.arguments,
sizeof(Node *) * (genericArgumentsNode->genericArguments.count + 1));
genericArgumentsNode->genericArguments
.arguments[genericArgumentsNode->genericArguments.count] =
genericArgumentNode;
genericArgumentsNode->genericArguments.count += 1;
return genericArgumentsNode;
}
Node *MakeEmptyGenericArgumentsNode()
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericArguments;
node->genericArguments.arguments = NULL;
node->genericArguments.count = 0;
return node;
}
Node *MakeGenericTypeNode(char *name)
{
Node *node = (Node *)malloc(sizeof(Node));
node->syntaxKind = GenericTypeNode;
node->genericType.name = strdup(name);
return node;
}
Node *MakeFunctionCallExpressionNode(
Node *identifierNode,
Node *argumentSequenceNode)
@ -557,6 +613,7 @@ void PrintNode(Node *node, uint32_t tabCount)
case FunctionSignature:
printf("\n");
PrintNode(node->functionSignature.identifier, tabCount + 1);
PrintNode(node->functionSignature.genericArguments, tabCount + 1);
PrintNode(node->functionSignature.arguments, tabCount + 1);
PrintNode(node->functionSignature.type, tabCount + 1);
PrintNode(node->functionSignature.modifiers, tabCount + 1);
@ -572,6 +629,24 @@ void PrintNode(Node *node, uint32_t tabCount)
}
return;
case GenericArgument:
printf("\n");
PrintNode(node->genericArgument.identifier, tabCount + 1);
/* Constraint nodes are not implemented. */
/* PrintNode(node->genericArgument.constraint, tabCount + 1); */
return;
case GenericArguments:
printf("\n");
for (i = 0; i < node->genericArguments.count; i += 1) {
PrintNode(node->genericArguments.arguments[i], tabCount + 1);
}
return;
case GenericTypeNode:
printf("%s\n", node->genericType.name);
return;
case Identifier:
if (node->typeTag == NULL)
{
@ -651,6 +726,165 @@ void PrintNode(Node *node, uint32_t tabCount)
}
}
void Recurse(Node *node, void (*func)(Node*))
{
uint32_t i;
switch (node->syntaxKind)
{
case AccessExpression:
func(node->accessExpression.accessee);
func(node->accessExpression.accessor);
return;
case AllocExpression:
func(node->allocExpression.type);
return;
case Assignment:
func(node->assignmentStatement.left);
func(node->assignmentStatement.right);
return;
case BinaryExpression:
func(node->binaryExpression.left);
func(node->binaryExpression.right);
return;
case Comment:
return;
case CustomTypeNode:
return;
case Declaration:
func(node->declaration.type);
func(node->declaration.identifier);
return;
case DeclarationSequence:
for (i = 0; i < node->declarationSequence.count; i += 1) {
func(node->declarationSequence.sequence[i]);
}
return;
case ForLoop:
func(node->forLoop.declaration);
func(node->forLoop.startNumber);
func(node->forLoop.endNumber);
func(node->forLoop.statementSequence);
return;
case FunctionArgumentSequence:
for (i = 0; i < node->functionArgumentSequence.count; i += 1) {
func(node->functionArgumentSequence.sequence[i]);
}
return;
case FunctionCallExpression:
func(node->functionCallExpression.identifier);
func(node->functionCallExpression.argumentSequence);
return;
case FunctionDeclaration:
func(node->functionDeclaration.functionSignature);
func(node->functionDeclaration.functionBody);
return;
case FunctionModifiers:
for (i = 0; i < node->functionModifiers.count; i += 1) {
func(node->functionModifiers.sequence[i]);
}
return;
case FunctionSignature:
func(node->functionSignature.identifier);
func(node->functionSignature.type);
func(node->functionSignature.arguments);
func(node->functionSignature.modifiers);
func(node->functionSignature.genericArguments);
return;
case FunctionSignatureArguments:
for (i = 0; i < node->functionSignatureArguments.count; i += 1) {
func(node->functionSignatureArguments.sequence[i]);
}
return;
case GenericArgument:
func(node->genericArgument.identifier);
func(node->genericArgument.constraint);
return;
case GenericArguments:
for (i = 0; i < node->genericArguments.count; i += 1) {
func(node->genericArguments.arguments[i]);
}
return;
case GenericTypeNode:
return;
case Identifier:
return;
case IfStatement:
func(node->ifStatement.expression);
func(node->ifStatement.statementSequence);
return;
case IfElseStatement:
func(node->ifElseStatement.ifStatement);
func(node->ifElseStatement.elseStatement);
return;
case Number:
return;
case PrimitiveTypeNode:
return;
case ReferenceTypeNode:
func(node->referenceType.type);
return;
case Return:
func(node->returnStatement.expression);
return;
case ReturnVoid:
return;
case StatementSequence:
for (i = 0; i < node->statementSequence.count; i += 1) {
func(node->statementSequence.sequence[i]);
}
return;
case StaticModifier:
return;
case StringLiteral:
return;
case StructDeclaration:
func(node->structDeclaration.identifier);
func(node->structDeclaration.declarationSequence);
return;
case Type:
return;
case UnaryExpression:
func(node->unaryExpression.child);
return;
default:
fprintf(stderr, "wraith: Unhandled SyntaxKind %s in recurse function.\n",
SyntaxKindString(node->syntaxKind));
return;
}
}
TypeTag *MakeTypeTag(Node *node)
{
if (node == NULL)
@ -698,6 +932,14 @@ TypeTag *MakeTypeTag(Node *node)
->functionSignature.type);
break;
case AllocExpression:
tag = MakeTypeTag(node->allocExpression.type);
break;
case GenericTypeNode:
tag->type = Generic;
tag->value.genericType = strdup(node->genericType.name);
default:
fprintf(
stderr,
@ -734,6 +976,16 @@ char *TypeTagToString(TypeTag *tag)
return result;
}
case Custom:
return tag->value.customType;
{
char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 8));
sprintf(result, "Custom<%s>", tag->value.customType);
return result;
}
case Generic:
{
char *result = malloc(sizeof(char) * (strlen(tag->value.customType) + 9));
sprintf(result, "Generic<%s>", tag->value.customType);
return result;
}
}
}

View File

@ -30,6 +30,9 @@ typedef enum
FunctionModifiers,
FunctionSignature,
FunctionSignatureArguments,
GenericArgument,
GenericArguments,
GenericTypeNode,
Identifier,
IfStatement,
IfElseStatement,
@ -87,7 +90,8 @@ typedef struct TypeTag
Unknown,
Primitive,
Reference,
Custom
Custom,
Generic
} type;
union
{
@ -97,6 +101,8 @@ typedef struct TypeTag
struct TypeTag *referenceType;
/* Valid when type = Custom. */
char *customType;
/* Valid when type = Generic. */
char *genericType;
} value;
} TypeTag;
@ -192,6 +198,7 @@ struct Node
Node *type;
Node *arguments;
Node *modifiers;
Node *genericArguments;
} functionSignature;
struct
@ -200,6 +207,23 @@ struct Node
uint32_t count;
} functionSignatureArguments;
struct
{
Node *identifier;
Node *constraint;
} genericArgument;
struct
{
Node **arguments;
uint32_t count;
} genericArguments;
struct
{
char *name;
} genericType;
struct
{
char *name;
@ -306,10 +330,16 @@ Node *MakeFunctionSignatureNode(
Node *identifierNode,
Node *typeNode,
Node *argumentsNode,
Node *modifiersNode);
Node *modifiersNode,
Node *genericArgumentsNode);
Node *MakeFunctionDeclarationNode(
Node *functionSignatureNode,
Node *functionBodyNode);
Node *MakeGenericArgumentNode(Node *identifierNode, Node *constraintNode);
Node *MakeEmptyGenericArgumentsNode();
Node *StartGenericArgumentsNode(Node *genericArgumentNode);
Node *AddGenericArgument(Node *genericArgumentsNode, Node *genericArgumentNode);
Node *MakeGenericTypeNode(char *name);
Node *MakeStructDeclarationNode(
Node *identifierNode,
Node *declarationSequenceNode);
@ -337,6 +367,12 @@ Node *MakeForLoopNode(
void PrintNode(Node *node, uint32_t tabCount);
const char *SyntaxKindString(SyntaxKind syntaxKind);
/* Helper function for applying a void function generically over the children of an AST node.
* Used for functions that need to traverse the entire tree but only perform operations on a subset
* of node types. Such functions can match the syntaxKinds relevant to their purpose and invoke this
* function in all other cases. */
void Recurse(Node *node, void (*func)(Node*));
TypeTag *MakeTypeTag(Node *node);
char *TypeTagToString(TypeTag *tag);

File diff suppressed because it is too large Load Diff

View File

@ -59,9 +59,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
return NULL;
case AllocExpression:
AddChildToNode(
parent,
MakeIdTree(astNode->allocExpression.type, parent));
astNode->typeTag = MakeTypeTag(astNode);
return NULL;
case Assignment:
@ -152,6 +150,7 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
mainNode = MakeIdNode(Function, funcName, parent);
mainNode->typeTag = MakeTypeTag(astNode);
idNode->typeTag = mainNode->typeTag;
MakeIdTree(sigNode->functionSignature.genericArguments, mainNode);
MakeIdTree(sigNode->functionSignature.arguments, mainNode);
MakeIdTree(astNode->functionDeclaration.functionBody, mainNode);
break;
@ -167,6 +166,23 @@ IdNode *MakeIdTree(Node *astNode, IdNode *parent)
return NULL;
}
case GenericArgument:
{
char *name = astNode->genericArgument.identifier->identifier.name;
mainNode = MakeIdNode(GenericType, name, parent);
break;
}
case GenericArguments:
{
for (i = 0; i < astNode->genericArguments.count; i += 1)
{
Node *argNode = astNode->genericArguments.arguments[i];
AddChildToNode(parent, MakeIdTree(argNode, parent));
}
return NULL;
}
case Identifier:
{
char *name = astNode->identifier.name;
@ -302,6 +318,12 @@ void PrintIdNode(IdNode *node)
case Variable:
printf("%s : %s\n", node->name, TypeTagToString(node->typeTag));
break;
case GenericType:
printf("Generic type: %s\n", node->name);
break;
case Alloc:
printf("Alloc: %s\n", TypeTagToString(node->typeTag));
break;
}
}

View File

@ -17,7 +17,9 @@ typedef enum NodeType
OrderedScope,
Struct,
Function,
Variable
Variable,
GenericType,
Alloc
} NodeType;
typedef struct IdNode

View File

@ -4,6 +4,7 @@
#include "codegen.h"
#include "identcheck.h"
#include "parser.h"
#include "typeutils.h"
int main(int argc, char *argv[])
{
@ -87,9 +88,19 @@ int main(int argc, char *argv[])
{
{
IdNode *idTree = MakeIdTree(rootNode, NULL);
printf("\n");
PrintIdTree(idTree, /*tabCount=*/0);
printf("\nConverting custom types in the ID-tree.\n");
ConvertIdCustomsToGenerics(idTree);
printf("\n");
PrintIdTree(idTree, /*tabCount=*/0);
printf("\nConverting custom type nodes in the AST.\n");
ConvertASTCustomsToGenerics(rootNode);
printf("\n");
PrintNode(rootNode, /*tabCount=*/0);
}
exitCode = Codegen(rootNode, optimizationLevel);
}

83
src/typeutils.c Normal file
View File

@ -0,0 +1,83 @@
#include "typeutils.h"
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
void ConvertIdCustomsToGenerics(IdNode *node) {
uint32_t i;
switch(node->type)
{
case UnorderedScope:
case OrderedScope:
case Struct:
/* FIXME: This case will need to be modified to handle type parameters over structs. */
for (i = 0; i < node->childCount; i += 1) {
ConvertIdCustomsToGenerics(node->children[i]);
}
return;
case Variable: {
TypeTag *varType = node->typeTag;
if (varType->type == Custom) {
IdNode *x = LookupId(node->parent, node, varType->value.customType);
if (x != NULL && x->type == GenericType) {
varType->type = Generic;
}
}
return;
}
case Function: {
TypeTag *funcType = node->typeTag;
if (funcType->type == Custom) {
/* For functions we have to handle the type lookup manually since the generic type
* identifiers are declared as children of the function's IdNode. */
for (i = 0; i < node->childCount; i += 1) {
IdNode *child = node->children[i];
if (child->type == GenericType && strcmp(child->name, funcType->value.customType) == 0) {
funcType->type = Generic;
}
}
}
for (i = 0; i < node->childCount; i += 1) {
ConvertIdCustomsToGenerics(node->children[i]);
}
return;
}
}
}
/* FIXME: This function will need to be modified to handle type parameters over structs. */
void ConvertASTCustomsToGenerics(Node *node) {
uint32_t i;
switch (node->syntaxKind)
{
case Declaration:
{
Node *type = node->declaration.type->type.typeNode;
Node *id = node->declaration.identifier;
if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) {
free(node->declaration.type);
node->declaration.type = MakeGenericTypeNode(id->typeTag->value.genericType);
}
return;
}
case FunctionSignature:
{
Node *id = node->functionSignature.identifier;
Node *type = node->functionSignature.type;
if (id->typeTag->type == Generic && type->syntaxKind == CustomTypeNode) {
free(node->functionSignature.type);
node->functionSignature.type = MakeGenericTypeNode(id->typeTag->value.genericType);
}
ConvertASTCustomsToGenerics(node->functionSignature.arguments);
return;
}
default:
recurse(node, *ConvertASTCustomsToGenerics);
}
}

13
src/typeutils.h Normal file
View File

@ -0,0 +1,13 @@
/* Helper functions for working with types in the AST and ID-tree. */
#ifndef WRAITH_TYPEUTILS_H
#define WRAITH_TYPEUTILS_H
#include "ast.h"
#include "identcheck.h"
/* FIXME: These two functions will need to be modified to handle type parameters over structs. */
void ConvertIdCustomsToGenerics(IdNode *node);
void ConvertASTCustomsToGenerics(Node *node);
#endif /* WRAITH_TYPEUTILS_H */

View File

@ -1,7 +1,6 @@
#include "util.h"
#include <stdlib.h>
#include <string.h>
char *strdup(const char *s)
{
@ -15,3 +14,16 @@ char *strdup(const char *s)
memcpy(result, s, slen + 1);
return result;
}
uint64_t str_hash(char *str)
{
uint64_t hash = 5381;
size_t c;
while ((c = *str++))
{
hash = ((hash << 5) + hash) + c; /* hash * 33 + c */
}
return hash;
}

View File

@ -1,8 +1,10 @@
#ifndef WRAITH_UTIL_H
#define WRAITH_UTIL_H
#include <stdint.h>
#include <string.h>
char *strdup(const char *s);
uint64_t str_hash(char *str);
#endif /* WRAITH_UTIL_H */