#include #include #include #include #include #include #include "expr.h" #include "finalize.h" #include "memregion.h" #include "sym.h" struct Expr { enum ExprKind kind; union { struct { const struct Expr *left, *right; } binary; const struct Expr *unary; struct { struct { uint64_t uint; } literal; const struct UStr *identifier; } primary; }; }; // constructors static struct Expr * newExpr(void) { struct Expr *expr = allocFromMemRegion(sizeof(*expr)); return expr; } struct Expr * newUnsignedLiteralExpr(uint64_t uint) { struct Expr *expr = newExpr(); expr->kind = EK_UNSIGNED_LITERAL; expr->primary.literal.uint = uint; return expr; } struct Expr * newIdentifierExpr(const struct UStr *identifier) { assert(identifier); struct Expr *expr = newExpr(); expr->kind = EK_IDENTIFIER; expr->primary.identifier = identifier; return expr; } struct Expr * newUnaryExpr(enum ExprKind kind, const struct Expr *unary) { assert(kind >= EK_UNARY && kind < EK_UNARY_END); assert(unary); struct Expr *expr = newExpr(); expr->kind = kind; expr->unary = unary; return expr; } struct Expr * newBinaryExpr(enum ExprKind kind, const struct Expr *left, const struct Expr *right) { assert(kind >= EK_BINARY && kind < EK_BINARY_END); assert(left); assert(right); struct Expr *expr = newExpr(); expr->kind = kind; expr->binary.left = left; expr->binary.right = right; return expr; } // destructor void deleteAllExpr(void) { releaseMemRegion(); } // methods bool isLValueExpr(const struct Expr *expr) { return expr->kind == EK_IDENTIFIER; } bool isConstExpr(const struct Expr *expr) { return expr->kind == EK_UNSIGNED_LITERAL; } void loadExprAddr(const struct Expr *expr, GenReg dest) { assert(isLValueExpr(expr)); if (expr->kind == EK_IDENTIFIER) { genLoadLabel(expr->primary.identifier->cstr, dest); return; } fprintf(stderr, "loadExprAddr: kind = %d", expr->kind); assert(0); } static enum GenOp makeOp3r(enum ExprKind kind) { switch (kind) { case EK_ADD: return GEN_ADD_R; case EK_SUB: return GEN_SUB_R; case EK_MUL: return GEN_IMUL_R; case EK_DIV: return GEN_DIV_R; case EK_MOD: return GEN_MOD_R; default: fprintf(stderr, "makeOp3r: kind = %d", kind); finalizeExit(1); return 0; // never reached } } static enum GenOp makeOp3i(enum ExprKind kind) { switch (kind) { case EK_ADD: return GEN_ADD_I; case EK_SUB: return GEN_SUB_I; case EK_MUL: return GEN_IMUL_I; case EK_DIV: return GEN_DIV_I; case EK_MOD: return GEN_MOD_I; default: fprintf(stderr, "makeOp3i: kind = %d", kind); finalizeExit(1); return 0; // never reached } } static enum GenOp makeCondJmp(enum ExprKind kind) { switch (kind) { case EK_EQUAL: return GEN_EQUAL; case EK_NOT_EQUAL: return GEN_NOT_EQUAL; case EK_GREATER: return GEN_ABOVE; case EK_GREATER_EQUAL: return GEN_ABOVE_EQUAL; case EK_LESS: return GEN_BELOW; case EK_LESS_EQUAL: return GEN_BELOW_EQUAL; default: fprintf(stderr, "makeCondJmp: kind = %d", kind); finalizeExit(1); return 0; // never reached } } void condJmpExpr(const struct Expr *expr, GenReg dest, const char *trueLabel, const char *falseLabel) { assert(expr); assert(expr->kind >= EK_BINARY && expr->kind < EK_PRIMARY_END); assert(trueLabel || falseLabel); assert(!trueLabel || !falseLabel); switch (expr->kind) { case EK_LESS: case EK_LESS_EQUAL: case EK_GREATER: case EK_GREATER_EQUAL: case EK_EQUAL: case EK_NOT_EQUAL: { const struct Expr *left = expr->binary.left; const struct Expr *right = expr->binary.right; loadExpr(left, dest); if (isConstExpr(right)) { uint64_t rVal = right->primary.literal.uint; genOp2i(GEN_CMP_I, rVal, dest); } else { GenReg tmp = genGetReg(); loadExpr(right, tmp); genOp2r(GEN_CMP_R, tmp, dest); genUngetReg(tmp); } genCondJmp(makeCondJmp(expr->kind), trueLabel, falseLabel); } return; case EK_LOGICAL_NOT: condJmpExpr(expr->unary, dest, falseLabel, trueLabel); return; default: loadExpr(expr, dest); genOp2i(GEN_CMP_I, 0, dest); genCondJmp(GEN_NOT_EQUAL, trueLabel, falseLabel); } } void loadExpr(const struct Expr *expr, GenReg dest) { assert(expr); assert(expr->kind >= EK_BINARY && expr->kind < EK_PRIMARY_END); switch (expr->kind) { case EK_ADD: case EK_SUB: case EK_MUL: case EK_DIV: case EK_MOD: { const struct Expr *left = expr->binary.left; const struct Expr *right = expr->binary.right; loadExpr(left, dest); if (isConstExpr(right)) { uint64_t rVal = right->primary.literal.uint; genOp3i(makeOp3i(expr->kind), rVal, dest, dest); } else { GenReg tmp = genGetReg(); loadExpr(right, tmp); genOp3r(makeOp3r(expr->kind), tmp, dest, dest); genUngetReg(tmp); } } return; case EK_ASSIGN: { const struct Expr *left = expr->binary.left; const struct Expr *right = expr->binary.right; assert(isLValueExpr(left)); loadExpr(right, dest); GenReg addr = genGetReg(); genLoadLabel(left->primary.identifier->cstr, addr); genStore(dest, addr); genUngetReg(addr); } return; case EK_LOGICAL_NOT: case EK_GREATER: case EK_GREATER_EQUAL: case EK_LESS: case EK_LESS_EQUAL: case EK_EQUAL: case EK_NOT_EQUAL: { const char *elseLabel = genGetLabel(); const char *endLabel = genGetLabel(); condJmpExpr(expr, dest, elseLabel, 0); genLoadUInt(0, dest); genJmp(endLabel); genLabelDef(elseLabel); genLoadUInt(1, dest); genLabelDef(endLabel); } return; case EK_UNARY_MINUS: loadExpr(expr->unary, dest); genOp2r(GEN_UNARYMINUS_R, dest, dest); return; case EK_UNARY_PLUS: loadExpr(expr->unary, dest); return; case EK_IDENTIFIER: genLoadLabel(expr->primary.identifier->cstr, dest); genFetch(dest, dest); return; case EK_UNSIGNED_LITERAL: genLoadUInt(expr->primary.literal.uint, dest); return; default: ; } fprintf(stderr, "loadExpr: internal error. kind = %d\n", expr->kind); finalizeExit(1); } static void printIndent(size_t indent, FILE *out) { for (size_t i = 0; i < indent * 4; ++i) { fprintf(out, " "); } } static void printExprNode(const struct Expr *expr, size_t indent, FILE *out) { assert(expr); printIndent(indent, out); if (expr->kind >= EK_BINARY && expr->kind < EK_UNARY_END) { switch (expr->kind) { case EK_ADD: fprintf(out, "[ +\n"); return; case EK_ASSIGN: fprintf(out, "[ {=}\n"); return; case EK_SUB: fprintf(out, "[ -\n"); return; case EK_MUL: fprintf(out, "[ *\n"); return; case EK_DIV: fprintf(out, "[ /\n"); return; case EK_MOD: fprintf(out, "[ $\\bmod$\n"); return; case EK_UNARY_MINUS: fprintf(out, "[ -\n"); return; case EK_UNARY_PLUS: fprintf(out, "[ +\n"); return; case EK_EQUAL: fprintf(out, "[ {==}\n"); return; case EK_NOT_EQUAL: fprintf(out, "[ $\\neq$ \n"); return; case EK_GREATER: fprintf(out, "[ $>$ \n"); return; case EK_GREATER_EQUAL: fprintf(out, "[ $\\geq$ \n"); return; case EK_LESS: fprintf(out, "[ $<$ \n"); return; case EK_LESS_EQUAL: fprintf(out, "[ $\\leq$ \n"); return; case EK_LOGICAL_NOT: fprintf(out, "[ $\\lnot$ \n"); return; default:; } } else if (expr->kind == EK_UNSIGNED_LITERAL) { fprintf(out, "[ %" PRIu64 "]\n", expr->primary.literal.uint); return; } else if (expr->kind == EK_IDENTIFIER) { fprintf(out, "[ %s ]\n", expr->primary.identifier->cstr); return; } fprintf(stderr, "printExprNode: internal error. kind = %d\n", expr->kind); finalizeExit(1); } static void printExprTree_(const struct Expr *expr, size_t indent, FILE *out) { assert(expr); assert(expr->kind >= EK_BINARY && expr->kind < EK_PRIMARY_END); if (expr->kind >= EK_BINARY && expr->kind < EK_BINARY_END) { printExprNode(expr, indent, out); printExprTree_(expr->binary.left, indent + 1, out); printExprTree_(expr->binary.right, indent + 1, out); printIndent(indent, out); fprintf(out, "]\n"); } else if (expr->kind >= EK_UNARY && expr->kind < EK_UNARY_END) { printExprNode(expr, indent, out); printExprTree_(expr->unary, indent + 1, out); printIndent(indent, out); fprintf(out, "]\n"); } else { printExprNode(expr, indent, out); } } void printExprTree(const struct Expr *expr, FILE *out) { fprintf(out, "\\begin{center}\n"); fprintf(out, "\\begin{forest}\n"); fprintf(out, "for tree={draw,circle,calign=fixed edge angles}\n"); printExprTree_(expr, 0, out); fprintf(out, "\\end{forest}\n"); fprintf(out, "\\end{center}\n"); } //------------------------------------------------------------------------------ // stuff for constant folding static const struct Expr * constFoldBinary(const struct Expr *expr) { assert(expr); assert(expr->kind >= EK_BINARY && expr->kind < EK_BINARY_END); // const fold child nodes const struct Expr *left = constFoldExpr(expr->binary.left); const struct Expr *right = constFoldExpr(expr->binary.right); if (isConstExpr(left) && !isConstExpr(right)) { // swap operands if possible switch (expr->kind) { case EK_EQUAL: case EK_NOT_EQUAL: case EK_ADD: case EK_MUL: { expr = newBinaryExpr(expr->kind, right, left); left = expr->binary.left; right = expr->binary.right; } break; default: ; } } if (isConstExpr(right)) { // if right child node is constant, handle some special cases if (expr->kind == EK_ADD && right->primary.literal.uint == 0) { return left; } if (expr->kind == EK_MUL && right->primary.literal.uint == 1) { return left; } if (expr->kind == EK_MUL && right->primary.literal.uint == 0) { return newUnsignedLiteralExpr(0); } if (expr->kind == EK_DIV && right->primary.literal.uint == 1) { return left; } if (expr->kind == EK_MOD && right->primary.literal.uint == 1) { return newUnsignedLiteralExpr(0); } } if (!isConstExpr(left) || !isConstExpr(right)) { // nothing more can be done if (left == expr->binary.left && right == expr->binary.right) { return expr; } return newBinaryExpr(expr->kind, left, right); } // handle cases where node can be completely folded // NOTE: here we exploit that in our case constants are always unsigned uint64_t leftVal = left->primary.literal.uint; uint64_t rightVal = right->primary.literal.uint; switch (expr->kind) { case EK_ADD: return newUnsignedLiteralExpr(leftVal + rightVal); case EK_SUB: return newUnsignedLiteralExpr(leftVal - rightVal); case EK_MUL: return newUnsignedLiteralExpr(leftVal * rightVal); case EK_DIV: return newUnsignedLiteralExpr(leftVal / rightVal); case EK_MOD: return newUnsignedLiteralExpr(leftVal % rightVal); case EK_EQUAL: return newUnsignedLiteralExpr(leftVal == rightVal); case EK_NOT_EQUAL: return newUnsignedLiteralExpr(leftVal != rightVal); case EK_LESS: return newUnsignedLiteralExpr(leftVal < rightVal); case EK_LESS_EQUAL: return newUnsignedLiteralExpr(leftVal <= rightVal); case EK_GREATER: return newUnsignedLiteralExpr(leftVal > rightVal); case EK_GREATER_EQUAL: return newUnsignedLiteralExpr(leftVal >= rightVal); case EK_ASSIGN: assert(0); // internal error return 0; // prevent warning (never reached in debug mode) default: assert(0); // internal error (you can turn this into a warning) return newBinaryExpr(expr->kind, left, right); } } static const struct Expr * constFoldUnary(const struct Expr *expr) { assert(expr); assert(expr->kind >= EK_UNARY && expr->kind < EK_UNARY_END); // const fold child node const struct Expr *unary = constFoldExpr(expr->unary); // handle all cases where folding is not possible if (!isConstExpr(unary)) { if (unary == expr->unary) { return expr; } return newUnaryExpr(expr->kind, unary); } // otherwise return folded expression node // NOTE: here we exploit that in our case constants are always unsigned switch (expr->kind) { case EK_UNARY_PLUS: return newUnsignedLiteralExpr(unary->primary.literal.uint); case EK_UNARY_MINUS: return newUnsignedLiteralExpr(-unary->primary.literal.uint); case EK_LOGICAL_NOT: return newUnsignedLiteralExpr(!unary->primary.literal.uint); default: assert(0); // internal error (you can turn this into a warning) return newUnaryExpr(expr->kind, unary); } } const struct Expr * constFoldExpr(const struct Expr *expr) { assert(expr); assert(expr->kind >= EK_BINARY && expr->kind < EK_PRIMARY_END); if (expr->kind >= EK_BINARY && expr->kind < EK_BINARY_END) { return constFoldBinary(expr); } else if (expr->kind >= EK_UNARY && expr->kind < EK_UNARY_END) { return constFoldUnary(expr); } return expr; }