#include "all.h"
#include <assert.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

/*
 * Compile-time constant folding of expression trees.
 *
 * fold() rewrites an expression in place into a literal node (Eintlit,
 * Eflolit, Eboolit, Estrlit, Enullit or Eenumval) when all of its operands
 * fold, releasing the absorbed child nodes with free(). Integer, boolean and
 * enum values are held in the `i` member, floating values in `f`; after
 * folding, integer values are wrapped to the width/signedness of the node's
 * type.
 */

int fold(struct expr *ex);

/* Truncates v to `size` bytes and sign- or zero-extends it back to 64 bits,
 * as storing it in the matching C integer type would. For 8 bytes (or any
 * unexpected size) the value is returned unchanged: the 64-bit bit pattern
 * already is the stored value. Each arm casts to int64_t explicitly so the
 * conditional does not convert the signed arm to the unsigned type. */
static int64_t
wrapint(int64_t v, int size, bool sgn)
{
	switch (size) {
	case 1:
		return sgn ? (int64_t)(int8_t)v : (int64_t)(uint8_t)v;
	case 2:
		return sgn ? (int64_t)(int16_t)v : (int64_t)(uint16_t)v;
	case 4:
		return sgn ? (int64_t)(int32_t)v : (int64_t)(uint32_t)v;
	default:
		return v;
	}
}

/* Converts a floating value to an integer `size` bytes wide.
 * Values in [2^63, 2^64) are only representable in an unsigned 64-bit
 * integer, so for unsigned targets they go through uint64_t; converting them
 * straight to int64_t would be undefined behaviour in the compiler itself.
 * NOTE(review): NaN and values outside (-2^63, 2^64) are still undefined to
 * convert; the folded source expression is presumably UB there as well. */
static int64_t
floattoint(double f, int size, bool sgn)
{
	if (!sgn && f >= 9223372036854775808.0)
		return wrapint((int64_t)(uint64_t)f, size, sgn);
	return wrapint((int64_t)f, size, sgn);
}

/* Brings a freshly computed value back into the range of ex->ty: integers
 * are wrapped to the type's width, f32 values are rounded to float
 * precision. Needed because numcast(ex, ex->ty) is a no-op when source and
 * destination types coincide. */
static void
fitvalue(struct expr *ex)
{
	const struct type *ty = ex->ty;

	if (ty->t == TYint)
		ex->i = wrapint(ex->i, ty->size, ty->int_signed);
	else if (ty->t == TYfloat && ty->size == 4)
		ex->f = (float)ex->f;
}

/* Converts the literal value held in ex from its current type to `to`
 * (both must be int, float, bool or enum; an enum source is treated as its
 * underlying integer type), then retags ex with type `to` and the matching
 * literal kind. */
static void
numcast(struct expr *ex, const struct type *to)
{
	const struct type *from = ex->ty->t == TYenum ? ex->ty->enu.intty : ex->ty;
	const struct type *ufrom = unconstify(from), *uto = unconstify(to);
	enum typetype t0 = from->t, t1 = to->t;
	int size = to->size;
	bool sgn = to->int_signed;

	assert(t0 == TYint || t0 == TYfloat || t0 == TYbool || t0 == TYenum);
	assert(t1 == TYint || t1 == TYfloat || t1 == TYbool || t1 == TYenum);

	if (ufrom == uto) {
		/* same type up to qualifiers: nothing to convert */
	} else if (t0 == TYfloat && t1 == TYint) {
		ex->i = floattoint(from->size == 4 ? (float)ex->f : ex->f, size, sgn);
	} else if (t0 == TYfloat && t1 == TYfloat && size == 4) {
		ex->f = (float)ex->f;
	} else if (t0 == TYfloat && t1 == TYfloat && size == 8) {
		ex->f = (double)ex->f;
	} else if (t0 == TYbool && t1 == TYfloat) {
		/* the value is in `i`; it must be converted, not reinterpreted */
		ex->f = ex->i ? 1.0 : 0.0;
	} else if (t0 == TYbool) {
		ex->i = !!ex->i;
	} else if (t0 == TYint && t1 == TYint
	           && (size == 1 || size == 2 || size == 4 || size == 8)) {
		ex->i = wrapint(ex->i, size, sgn);
	} else if (t0 == TYint && uto == ty_f64 && ufrom == ty_u64) {
		ex->f = (uint64_t)ex->i;
	} else if (t0 == TYint && uto == ty_f32 && ufrom == ty_u64) {
		ex->f = (float)(uint64_t)ex->i;
	} else if (t0 == TYint && uto == ty_f64) {
		ex->f = ex->i;
	} else if (t0 == TYint && uto == ty_f32) {
		ex->f = (float)ex->i;
	} else if (t0 == TYint && uto == ty_bool) {
		ex->i = !!ex->i;
	} else {
		assert(0);
	}
	ex->ty = to;
	ex->t = t1 == TYfloat ? Eflolit : t1 == TYint ? Eintlit : Eboolit;
}

/* Evaluates integer binary operator `op` on a and b, both already cast to
 * `ty`, storing the result in *out. +, -, * and << are computed in uint64_t
 * so overflow wraps instead of being undefined; the caller narrows the
 * result to `ty`. Returns false when the expression must be left unfolded:
 * division/remainder by zero, INT64_MIN / -1, a shift count outside
 * [0, 64), or an unknown operator. */
static bool
foldint(int op, const struct type *ty, int64_t a, int64_t b, int64_t *out)
{
	/* narrower unsigned values are stored zero-extended and so compare and
	 * divide correctly as int64_t; only 64-bit unsigned needs uint64_t */
	bool unsigned64 = !ty->int_signed && ty->size == 8;
	uint64_t ua = (uint64_t)a, ub = (uint64_t)b;

	switch (op) {
	case '+':
		*out = (int64_t)(ua + ub);
		return true;
	case '-':
		*out = (int64_t)(ua - ub);
		return true;
	case '*':
		*out = (int64_t)(ua * ub);
		return true;
	case '/':
	case '%':
		if (b == 0)
			return false;
		if (unsigned64)
			*out = (int64_t)(op == '/' ? ua / ub : ua % ub);
		else if (a == INT64_MIN && b == -1)
			return false;
		else
			*out = op == '/' ? a / b : a % b;
		return true;
	case '&':
		*out = a & b;
		return true;
	case '|':
		*out = a | b;
		return true;
	case '^':
		*out = a ^ b;
		return true;
	case '<<':
		if (ub >= 64)
			return false;
		*out = (int64_t)(ua << ub);
		return true;
	case '>>':
		if (ub >= 64)
			return false;
		*out = ty->int_signed ? a >> ub : (int64_t)(ua >> ub);
		return true;
	case '==':
		*out = a == b;
		return true;
	case '<':
		*out = unsigned64 ? ua < ub : a < b;
		return true;
	case '>':
		*out = unsigned64 ? ua > ub : a > b;
		return true;
	case '<=':
		*out = unsigned64 ? ua <= ub : a <= b;
		return true;
	case '>=':
		*out = unsigned64 ? ua >= ub : a >= b;
		return true;
	default:
		assert(0);
		return false;
	}
}

/* Folds a prefix operator ('-', '~', 'not') applied to a constant operand.
 * The operand node's contents are copied into ex and the node is freed. */
static void
funary(struct expr *ex)
{
	struct expr *r = ex->unop.child;
	const struct type *ty = ex->ty;

	if (!fold(r))
		return;
	switch (ex->unop.op) {
	case '-':
		*ex = *r;
		ex->ty = ty;
		if (r->ty->t == TYint)
			/* negate in uint64_t: -INT64_MIN would be signed overflow */
			ex->i = (int64_t)(0 - (uint64_t)ex->i);
		else if (r->ty->t == TYfloat)
			ex->f = -ex->f;
		else
			assert(0);
		free(r);
		fitvalue(ex);
		numcast(ex, ty);
		break;
	case '~':
		*ex = *r;
		ex->ty = ty;
		assert(r->ty->t == TYint);
		ex->i = ~ex->i;
		free(r);
		fitvalue(ex);
		numcast(ex, ty);
		break;
	case 'not':
		*ex = *r;
		ex->ty = ty;
		assert(ty->t == TYbool);
		ex->i = !ex->i;
		free(r);
		break;
	default:
		break;
	}
}

/* Folds a binary operator whose operands are both constant. For a bool
 * result type (comparisons, 'and', 'or') the operands are evaluated in their
 * common type as given by typeof2(). Operations that would be undefined
 * (see foldint) leave ex unfolded. */
static void
fbinop(struct expr *ex)
{
	struct expr *l = ex->binop.lhs, *r = ex->binop.rhs;
	const struct type *ty = ex->ty;
	bool boolres = unconstify(ty) == ty_bool;
	int op = ex->binop.op;
	int64_t ires;

	if (!fold(l) || !fold(r))
		return;
	if (boolres)
		ty = typeof2(l->ty, r->ty);
	numcast(l, ty);
	numcast(r, ty);

	if (op == 'and') {
		ex->i = l->i && r->i;
	} else if (op == 'or') {
		ex->i = l->i || r->i;
	} else if (ty->t == TYint) {
		if (!foldint(op, ty, l->i, r->i, &ires))
			return;
		ex->i = ires;
	} else if (ty->t == TYfloat) {
		/* arithmetic yields a float in `f`; comparisons a bool in `i` */
		switch (op) {
		case '+':  ex->f = l->f + r->f; break;
		case '-':  ex->f = l->f - r->f; break;
		case '*':  ex->f = l->f * r->f; break;
		case '/':  ex->f = l->f / r->f; break;
		case '==': ex->i = l->f == r->f; break;
		case '<':  ex->i = l->f < r->f; break;
		case '>':  ex->i = l->f > r->f; break;
		case '<=': ex->i = l->f <= r->f; break;
		case '>=': ex->i = l->f >= r->f; break;
		default:
			assert(0);
			return;
		}
	} else if (ty->t == TYbool && op == '==') {
		ex->i = l->i == r->i;
	} else {
		assert(0);
		return;
	}

	free(l);
	free(r);
	if (boolres) {
		ex->t = Eboolit;
	} else {
		fitvalue(ex);
		numcast(ex, ty);
	}
}

/* Folds `test ? t : f` when all three operands are constant and the result
 * type is int, float or bool (the only kinds numcast can convert). */
static void
fcond(struct expr *ex)
{
	struct expr *test = ex->cond.test, *t = ex->cond.t, *f = ex->cond.f;
	const struct type *ty = ex->ty;
	struct expr *pick;

	if (!fold(test) || !fold(f) || !fold(t))
		return;
	assert(test->ty->t == TYbool);
	/* string/pointer/enum results would trip numcast's asserts */
	if (ty->t != TYint && ty->t != TYfloat && ty->t != TYbool)
		return;
	numcast(f, ty);
	numcast(t, ty);
	pick = test->i ? t : f;
	if (ty->t == TYfloat)
		ex->f = pick->f;
	else
		ex->i = pick->i;
	numcast(ex, ty);
	free(test);
	free(t);
	free(f);
}

/* Folds indexing of a string literal by a constant in-range index into an
 * integer literal holding that element. */
static void
findex(struct expr *ex)
{
	struct expr *l = ex->index.lhs, *r = ex->index.rhs;

	if (!fold(l) || !fold(r))
		return;
	if (l->t != Estrlit) {
		/* only string literals are expected here; never free the
		 * children while ex still refers to them */
		assert(0);
		return;
	}
	/* index n is one past the last element: leave it unfolded */
	if (r->u >= l->strlit.n)
		return;
	ex->t = Eintlit;
	ex->i = l->strlit.d[r->u];
	free(l);
	free(r);
}

/* Sets ex->enu.vname to the name of the first enumerator of ex's enum type
 * whose value equals ex->enu.i; leaves it untouched if none matches. */
static void
trysetenumvname(struct expr *ex)
{
	const struct type *t = ex->ty;

	for (int i = 0; i < t->enu.vals.n; ++i)
		if (ex->enu.i == t->enu.vals.d[i].i) {
			ex->enu.vname = t->enu.vals.d[i].name;
			break;
		}
}

/* Folds `child as to` when the child is a numeric, bool or enum constant and
 * the target is numeric or an enum. */
static void
fas(struct expr *ex)
{
	struct expr *child = ex->child;
	const struct type *to = ex->ty, *from;

	if (!fold(child))
		return;
	if (!isnumtype(to) && to->t != TYenum)
		return;
	from = child->ty;
	/* numcast only accepts numeric sources; pointer/string/null constants
	 * are left unfolded */
	if (from->t != TYint && from->t != TYfloat && from->t != TYbool
	    && from->t != TYenum)
		return;
	if (from->t == TYenum)
		from = from->enu.intty;

	ex->ty = from;
	if (from->t == TYfloat)
		ex->f = child->f;
	else
		ex->i = child->i;

	if (to->t == TYenum) {
		numcast(ex, to->enu.intty);
		ex->ty = to;
		ex->t = Eenumval;
		trysetenumvname(ex);
	} else {
		numcast(ex, to);
	}
	free(child);
}

/* Folds a zero-initializer of scalar type into the zero literal of that
 * type; other types are left unfolded. */
static void
fzeroini(struct expr *ex)
{
	const struct type *ty = ex->ty;

	switch (ty->t) {
	case TYint:
	case TYbool:
		ex->i = 0;
		numcast(ex, ty);
		break;
	case TYfloat:
		ex->f = 0.0;
		numcast(ex, ty);
		break;
	case TYenum:
		/* numcast cannot convert enum -> same enum; tag directly,
		 * mirroring what fas() does for enum targets */
		ex->i = 0;
		ex->t = Eenumval;
		trysetenumvname(ex);
		break;
	default:
		break;
	}
}

/* Attempts to fold ex in place.
 * Returns 1 if ex is (now) a literal node, 0 if it could not be folded. */
int
fold(struct expr *ex)
{
	switch (ex->t) {
	case Eintlit:
	case Eflolit:
	case Eboolit:
		numcast(ex, ex->ty);
		return 1;
	case Estrlit:
	case Enullit:
	case Eenumval:
		return 1;
	case Eprefix:
		funary(ex);
		break;
	case Ebinop:
		fbinop(ex);
		break;
	case Econd:
		fcond(ex);
		break;
	case Eindex:
		findex(ex);
		break;
	case Eas:
		fas(ex);
		break;
	case Ezeroini:
		fzeroini(ex);
		break;
	default:
		break;
	}

	switch (ex->t) {
	case Eintlit:
	case Eflolit:
	case Estrlit:
	case Eboolit:
	case Enullit:
	case Eenumval:
		return 1;
	default:
		return 0;
	}
}