aboutsummaryrefslogtreecommitdiff
path: root/bootstrap/fold.c
diff options
context:
space:
mode:
Diffstat (limited to 'bootstrap/fold.c')
-rw-r--r--bootstrap/fold.c210
1 files changed, 210 insertions, 0 deletions
diff --git a/bootstrap/fold.c b/bootstrap/fold.c
new file mode 100644
index 0000000..8f080a0
--- /dev/null
+++ b/bootstrap/fold.c
@@ -0,0 +1,210 @@
+#include "all.h"
+#include <stdint.h>
+
+int fold(struct expr *ex);
+
+static void
+numcast(struct expr *ex, const struct type *to) {
+ enum typetype t0 = ex->ty->t;
+ enum typetype t1 = to->t;
+ const struct type *from = ex->ty;
+ const struct type *uto = unconstify(to);
+ const struct type *ufrom = unconstify(from);
+ int size = to->size;
+ bool sgn = to->int_signed;
+ assert(t0 == TYint || t0 == TYfloat || t0 == TYbool);
+ assert(t1 == TYint || t1 == TYfloat || t1 == TYbool);
+
+ if (ufrom == uto)
+ /* pass */;
+ else if (t1 == TYint && t0 == TYfloat)
+ ex->i = from->size == 4 ? (float)ex->f : ex->f;
+ else if (t0 == TYfloat && t1 == TYfloat && size == 4)
+ ex->f = (float)ex->f;
+ else if (t0 == TYfloat && t1 == TYfloat && size == 8)
+ ex->f = (double)ex->f;
+ else if (t0 == TYbool)
+ ex->i = !!ex->i;
+ #define intint(siz,sgnd,Ty) \
+ else if (t0 == TYint && t1 == TYint && size == (siz) && sgn == (sgnd)) \
+ ex->i = (Ty)ex->i;
+ intint(1,0,uint8_t)
+ intint(1,1,int8_t)
+ intint(2,0,uint16_t)
+ intint(2,1,int16_t)
+ intint(4,0,uint32_t)
+ intint(4,1,int32_t)
+ intint(8,0,uint64_t)
+ intint(8,1,int64_t)
+ else if (uto == ty_f64 && ufrom == ty_u64)
+ ex->f = (u64)ex->i;
+ else if (uto == ty_f32 && ufrom == ty_u64)
+ ex->f = (float)(u64)ex->i;
+ else if (uto == ty_f64)
+ ex->f = ex->i;
+ else if (uto == ty_f32)
+ ex->f = (float)ex->i;
+ else if (uto == ty_bool && t0 == TYint)
+ ex->i = !!ex->i;
+ else if (ufrom == ty_bool && t1 == TYint)
+ ex->i = !!ex->i;
+ else
+ assert(0);
+
+ ex->ty = to;
+ ex->t = to->t == TYfloat ? Eflolit
+ : to->t == TYint ? Eintlit
+ : Eboolit;
+}
+
+static void
+unary(struct expr *ex) {
+ struct expr *r = ex->unop.child;
+ const struct type *ty = ex->ty;
+
+ if (!fold(r))
+ return;
+
+ switch (ex->unop.op) {
+ case '-':
+ *ex = *r;
+ ex->ty = ty;
+ if (r->ty->t == TYint)
+ ex->i = -ex->i;
+ else if (r->ty->t == TYfloat)
+ ex->f = -ex->f;
+ else assert(0);
+ free(r);
+ numcast(ex, ty);
+ break;
+ case '~':
+ *ex = *r;
+ ex->ty = ty;
+ assert(r->ty->t == TYint);
+ ex->i = ~ex->i;
+ numcast(ex, ty);
+ free(r);
+ break;
+ case 'not':
+ *ex = *r;
+ ex->ty = ty;
+ assert(ty->t == TYbool);
+ ex->i = !ex->i;
+ free(r);
+ break;
+ }
+}
+
+static void
+binop(struct expr *ex) {
+ struct expr *l = ex->binop.lhs;
+ struct expr *r = ex->binop.rhs;
+ const struct type *ty = ex->ty,
+ *uty = unconstify(ty);
+ int op = ex->binop.op;
+
+ if (!fold(l) || !fold(r))
+ return;
+
+ if (uty == ty_bool)
+ ty = typeof2(l->ty, r->ty);
+ numcast(l, ty);
+ numcast(r, ty);
+ if (op == '/' && r->ty->t == TYint && r->i == 0) // div/0
+ return;
+
+ if (op == '+' && ty->t == TYint) ex->i = l->i + r->i;
+ else if (op == '+' && ty->t == TYfloat) ex->i = l->f + r->f;
+ else if (op == '-' && ty->t == TYint) ex->i = l->i - r->i;
+ else if (op == '-' && ty->t == TYfloat) ex->i = l->f - r->f;
+ else if (op == '*' && uty == ty_u64) ex->i = (u64)l->i * (u64)r->i;
+ else if (op == '*' && ty->t == TYint) ex->i = l->i * r->i;
+ else if (op == '*' && ty->t == TYfloat) ex->i = l->f * r->f;
+ else if (op == '/' && uty == ty_u64) ex->i = (u64)l->i / (u64)r->i;
+ else if (op == '/' && ty->t == TYint) ex->i = l->i / r->i;
+ else if (op == '/' && ty->t == TYfloat) ex->i = l->f / r->f;
+ else if (op == '%' && uty == ty_u64) ex->i = (u64)l->i % (u64)r->i;
+ else if (op == '%' && ty->t == TYint) ex->i = l->i % r->i;
+ else if (op == '&' && ty->t == TYint) ex->i = l->i & r->i;
+ else if (op == '|' && ty->t == TYint) ex->i = l->i | r->i;
+ else if (op == '^' && ty->t == TYint) ex->i = l->i ^ r->i;
+ else if (op == '<<' && ty->t == TYint) ex->i = l->i << r->i;
+ else if (op == '>>' && ty->t == TYint && !ty->int_signed)
+ ex->i = (u64)l->i >> r->i;
+ else if (op == '>>' && ty->t == TYint) ex->i = l->i >> r->i;
+ else if (op == '==' && ty->t == TYfloat) ex->i = l->f == r->f;
+ else if (op == '==' && ty->t == TYint) ex->i = l->i == r->i;
+ else if (op == '<' && ty->t == TYfloat) ex->i = l->f < r->f;
+ else if (op == '<' && uty == ty_u64) ex->i = (u64)l->i < (u64)r->i;
+ else if (op == '<' && ty->t == TYint) ex->i = l->i < r->i;
+ else if (op == '>' && ty->t == TYfloat) ex->i = l->f > r->f;
+ else if (op == '>' && uty == ty_u64) ex->i = (u64)l->i > (u64)r->i;
+ else if (op == '>' && ty->t == TYint) ex->i = l->i > r->i;
+ else if (op == '<=' && ty->t == TYfloat) ex->i = l->f <= r->f;
+ else if (op == '<=' && uty == ty_u64) ex->i = (u64)l->i <= (u64)r->i;
+ else if (op == '<=' && ty->t == TYint) ex->i = l->i <= r->i;
+ else if (op == '>=' && ty->t == TYfloat) ex->i = l->f >= r->f;
+ else if (op == '>=' && uty == ty_u64) ex->i = (u64)l->i >= (u64)r->i;
+ else if (op == '>=' && ty->t == TYint) ex->i = l->i >= r->i;
+ else if (op == 'and') ex->i = l->i && r->i;
+ else if (op == 'or') ex->i = l->i || r->i;
+ else assert(0);
+
+ free(l);
+ free(r);
+ if (uty != ty_bool)
+ numcast(ex, ty);
+ else
+ ex->t = Eboolit;
+}
+
+static void
+cond(struct expr *ex) {
+ struct expr *test = ex->cond.test,
+ *t = ex->cond.t,
+ *f = ex->cond.f;
+ const struct type *ty = ex->ty;
+
+ if (!fold(test) || !fold(f) || !fold(t))
+ return;
+ assert(test->ty->t == TYbool);
+ numcast(f, ty);
+ numcast(t, ty);
+ if (test->i)
+ ex->i = t->i;
+ else
+ ex->i = f->i;
+ numcast(ex, ty);
+ free(test);
+ free(t);
+ free(f);
+}
+
+int
+fold(struct expr *ex) {
+ switch (ex->t) {
+ case Eintlit: case Eflolit: case Eboolit:
+ numcast(ex, ex->ty);
+ return 1;
+ case Estrlit:case Enullit:
+ return 1;
+ case Eprefix:
+ unary(ex);
+ break;
+ case Ebinop:
+ binop(ex);
+ break;
+ case Econd:
+ cond(ex);
+ default:
+ break;
+ }
+
+ switch (ex->t) {
+ case Eintlit: case Eflolit: case Estrlit:
+ case Eboolit: case Enullit:
+ return 1;
+ default:
+ return 0;
+ }
+}