aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/ir_simpl.c
diff options
context:
space:
mode:
authorlemon <lsof@mailbox.org>2026-03-17 13:22:00 +0100
committerlemon <lsof@mailbox.org>2026-03-17 13:22:00 +0100
commita8d6f8bf30c07edb775e56889f568ca20240bedf (patch)
treeb5a452b2675b2400f15013617291fe6061180bbf /src/ir_simpl.c
parent24f14b7ad1af08d872971d72ce089a529911f657 (diff)
REFACTOR: move sources to src/
Diffstat (limited to 'src/ir_simpl.c')
-rw-r--r--src/ir_simpl.c308
1 files changed, 308 insertions, 0 deletions
diff --git a/src/ir_simpl.c b/src/ir_simpl.c
new file mode 100644
index 0000000..a01879a
--- /dev/null
+++ b/src/ir_simpl.c
@@ -0,0 +1,308 @@
+#include "ir.h"
+
+/* Strength-reduce an integer multiply whose right operand is a constant.
+ *
+ * Handles multipliers whose magnitude is 2^y, 2^y + 1 or 2^y - 1, rewriting
+ * `ins` in place as a shift, a shift plus add, or a shift minus the operand;
+ * a negative multiplier gets a trailing negate.  Each helper instruction is
+ * inserted into `blk` at index *curi, and *curi is incremented once per
+ * insertion.
+ *
+ * Returns 1 when `ins` was rewritten, 0 when the constant has no such form. */
+static int
+mulk(struct instr *ins, struct block *blk, int *curi)
+{
+   vlong iv = intconval(ins->r);
+   enum irclass cls = ins->cls;
+   assert((uvlong)iv > 1 && "trivial mul not handled by irbinop() ?");
+   bool neg = iv < 0;
+   /* Negate and step by one through the unsigned type: the most negative
+    * and most positive constants would overflow in signed arithmetic,
+    * which is undefined behaviour. */
+   if (neg) iv = (vlong)(0 - (uvlong)iv);
+   vlong below = (vlong)((uvlong)iv - 1),
+         above = (vlong)((uvlong)iv + 1);
+   /* This can be generalized to any sequence of shifts and
+    * adds/subtracts, but whether that's worth it depends on the number of them
+    * and the microarchitecture.. clang seems to stop after two shifts. Should
+    * we compute approximate cost of instrs to determine? For now just handle
+    * the po2 (+/- 1) case */
+   if (ispo2(iv)) {
+      /* x * 2^y ==> x << y */
+      ins->op = Oshl;
+      ins->r = mkref(RICON, ilog2(iv));
+   } else if (ispo2(below)) {
+      /* x * 5 ==> (x << 2) + x */
+      ins->op = Oadd;
+      ins->r = ins->l;
+      ins->l = insertinstr(blk, (*curi)++, mkinstr(Oshl, cls, ins->l, mkref(RICON, ilog2(below))));
+   } else if (ispo2(above)) {
+      /* x * 7 ==> (x << 3) - x */
+      ins->op = Osub;
+      ins->r = ins->l;
+      ins->l = insertinstr(blk, (*curi)++, mkinstr(Oshl, cls, ins->l, mkref(RICON, ilog2(above))));
+   } else return 0;
+   if (neg) {
+      /* emit the positive product as its own instruction, then negate it */
+      ins->l = insertinstr(blk, (*curi)++, *ins);
+      ins->op = Oneg;
+      ins->r = NOREF;
+   }
+   return 1;
+}
+
+/* Strength-reduce an integer division or remainder by a constant power of
+ * two (for the signed ops Odiv/Orem, also by a negated power of two).
+ *
+ * Unsigned ops become a single shift or mask.  Signed ops first bias a
+ * negative dividend so that the arithmetic shift rounds toward zero; the
+ * helper instructions are inserted into `blk` at index *curi, and *curi is
+ * incremented once per insertion.
+ *
+ * Returns 1 when `ins` was rewritten, 0 when the divisor is not handled. */
+static int
+divmodk(struct instr *ins, struct block *blk, int *curi)
+{
+   enum op op = ins->op;
+   enum irclass cls = ins->cls;
+   vlong iv = intconval(ins->r);
+   uint nbit = 8 * cls2siz[cls];
+   bool neg = (op == Odiv || op == Orem) && iv < 0;
+   /* magnitude of the divisor; negated through the unsigned type so the
+    * most negative constant cannot overflow */
+   vlong mag = neg ? (vlong)(0 - (uvlong)iv) : iv;
+   union ref temp;
+   uint s;
+
+   if (!ispo2(iv) && !(neg && ispo2(mag)))
+      return 0; /* only the simple po2 cases are handled */
+   s = ilog2(mag);
+
+   switch (op) {
+   default: assert(0);
+   case Oudiv: /* x / 2^s ==> x >> s */
+      ins->op = Oslr, ins->r = mkref(RICON, s);
+      return 1;
+   case Ourem: /* x % 2^s ==> x & 2^s-1 */
+      ins->op = Oand, ins->r = mkintcon(cls, iv - 1);
+      return 1;
+   case Odiv: case Orem:
+      /* have to adjust to round negatives toward zero */
+      /* x' = (((x < 0 ? -1 : 0) >>> (Nbit - s)) + x) */
+      /* NOTE(review): a divisor of +-1 would give s == 0 and a shift by
+       * nbit here; presumably irbinop() folds that case before this
+       * function is reached -- confirm */
+      temp = insertinstr(blk, (*curi)++, mkinstr(Osar, cls, ins->l, mkref(RICON, nbit - 1)));
+      temp = insertinstr(blk, (*curi)++, mkinstr(Oslr, cls, temp, mkref(RICON, nbit - s)));
+      temp = insertinstr(blk, (*curi)++, mkinstr(Oadd, cls, ins->l, temp));
+      if (op == Odiv) {
+         /* (-) (x' >> s) */
+         struct instr sar = mkinstr(Osar, cls, temp, mkref(RICON, s));
+         if (!neg) *ins = sar;
+         else {
+            temp = insertinstr(blk, (*curi)++, sar);
+            ins->op = Oneg, ins->l = temp, ins->r = NOREF;
+         }
+      } else {
+         /* x - (x' & -(2^s)) */
+         temp = insertinstr(blk, (*curi)++, mkinstr(Oand, cls, temp, mkintcon(cls, neg ? iv : -iv)));
+         ins->op = Osub, ins->r = temp;
+      }
+      /* the instruction was rewritten, so report the change */
+      return 1;
+   }
+}
+
+/* Try to simplify one instruction in place.
+ *
+ * Arithmetic ops are first handed to irunop()/irbinop(); a non-zero result
+ * replaces every use of the instruction and the instruction becomes a nop.
+ * Otherwise a few op-specific rewrites are attempted: redundant copies,
+ * integer multiply/divide/remainder by a constant, and equality tests on the
+ * result of an add/sub/xor with a constant.
+ *
+ * Returns non-zero when something was changed, 0 otherwise. */
+static int
+doins(struct instr *ins, struct block *blk, int *curi)
+{
+   enum op op = ins->op;
+
+   if (oisarith(op)) {
+      union ref folded;
+      if (opnarg[op] == 1)
+         folded = irunop(NULL, op, ins->cls, ins->l);
+      else
+         folded = irbinop(NULL, op, ins->cls, ins->l, ins->r);
+      if (folded.bits) {
+         *ins = mkinstr(Onop,0,);
+         replcuses(mkref(RTMP, ins - instrtab), folded);
+         return 1;
+      }
+   }
+
+   enum irclass k = ins->cls;
+   switch (op) {
+   case Ocopy: {
+      /* a copy is redundant when its source already has class k */
+      if (!((ins->l.t == RTMP && k == instrtab[ins->l.i].cls)
+            || (isnumcon(ins->l) && k == concls(ins->l))
+            || (kisint(k) && ins->l.t == RICON)))
+         break;
+      union ref src = ins->l;
+      *ins = mkinstr(Onop,0,);
+      replcuses(mkref(RTMP, ins - instrtab), src);
+      return 1;
+   }
+   case Omul:
+      if (!kisflt(k)) {
+         /* put const in rhs */
+         if (isnumcon(ins->l)) rswap(ins->l, ins->r);
+         if (isintcon(ins->r)) return mulk(ins, blk, curi);
+      }
+      break;
+   case Odiv: case Oudiv: case Orem: case Ourem:
+      if (!kisflt(k) && isintcon(ins->r))
+         return divmodk(ins, blk, curi);
+      break;
+   case Oequ: case Oneq: {
+      /* optimize `x <op> C <cmp> Q` */
+      /* could apply with add/sub to lth/lte/gth/gte iff no signed wraparound, which isn't assumed */
+      if (ins->l.t != RTMP || !isintcon(ins->r))
+         break;
+      struct instr *def = &instrtab[ins->l.i];
+      enum op dop = def->op;
+      if (ins->cls != def->cls)
+         break;
+      if (dop != Oadd && dop != Osub && dop != Oxor)
+         break;
+      if (!isintcon(def->r))
+         break;
+      uvlong want = intconval(ins->r), konst = intconval(def->r);
+      if (dop == Oadd)
+         want -= konst;      /* x + 3 == C ==> x == C - 3 */
+      else if (dop == Osub)
+         want += konst;      /* x - 3 == C ==> x == C + 3 */
+      else
+         want ^= konst;      /* x ^ 3 == C ==> x == C ^ 3 */
+      ins->l = def->l, ins->r = mkintcon(ins->cls, want);
+      deluse(blk, ins - instrtab, mkref(RTMP, def - instrtab));
+      /* drop the defining instruction once nothing uses it */
+      if (!instruse[def - instrtab]) *def = mkinstr(Onop,0,);
+      return 1;
+   }
+   }
+   return 0;
+}
+
+/* Redirect *pblk past blocks that have a recorded forwarding target.
+ *
+ * final[] is indexed by block id.  When the entry for *pblk is set and that
+ * target has no phis, the entry is itself resolved first (so the table entry
+ * is shortened along the way) and *pblk is then replaced by it. */
+static void
+jmpfind(struct block **final, struct block **pblk)
+{
+   struct block **slot = final + (*pblk)->id;
+
+   if (!*slot || (*slot)->phi.n)
+      return;
+   jmpfind(final, slot);
+   *pblk = *slot;
+}
+
+/* Walk the blocks reachable from `blk`, registering each visited block as a
+ * predecessor of its successors.  Successors that carry phis are left out of
+ * the addpred() calls.  The first successor is handled by recursion only
+ * when there are two; otherwise the walk just continues in the loop. */
+static void
+fillpredsrec(struct block *blk)
+{
+   for (;;) {
+      if (!blk || wasvisited(blk))
+         return;
+      markvisited(blk);
+
+      struct block *first = blk->s1;
+      if (!first)
+         return;
+      if (!first->phi.n)
+         addpred(first, blk);
+
+      struct block *second = blk->s2;
+      if (second) {
+         if (!second->phi.n)
+            addpred(second, blk);
+         fillpredsrec(first);
+         blk = second;
+      } else {
+         blk = first;
+      }
+   }
+}
+
+/* Rebuild predecessor lists from the successor edges, then free every
+ * non-entry block that ended up with no predecessor.
+ *
+ * Blocks that carry phis keep their existing predecessor list (presumably
+ * because phi arguments are positional with respect to it -- confirm). */
+static void
+fillpreds(struct function *fn)
+{
+   struct block *blk = fn->entry, *next;
+   do {
+      if (blk->phi.n) continue;
+      /* only lists with more than one entry are released here */
+      if (blk->npred > 1)
+         xbfree(blk->_pred);
+      blk->npred = 0;
+      blk->_pred = NULL;
+   } while ((blk = blk->lnext) != fn->entry);
+   startbbvisit();
+   fillpredsrec(fn->entry);
+   /* remove dead blocks; test before the first iteration so that a function
+    * consisting of the entry block alone never frees its entry, and fetch
+    * the successor before freeblk() is called on the block */
+   for (blk = fn->entry->lnext; blk != fn->entry; blk = next) {
+      next = blk->lnext;
+      if (!blk->npred) freeblk(fn, blk);
+   }
+}
+
+/* Append block `s` to its sole predecessor `p` and free `s`.
+ *
+ * `s` must be the successor of `p` reached through p->s1, must have exactly
+ * one predecessor and must have no phis.  The instructions, jump and
+ * successors of `s` are taken over by `p`, and the successors' predecessor
+ * entries are retargeted from `s` to `p`. */
+static void
+mergeblks(struct function *fn, struct block *p, struct block *s)
+{
+   assert(s->npred == 1 && !s->phi.n);
+   assert(p->s1 == s);
+   vpushn(&p->ins, s->ins.p, s->ins.n);
+   p->jmp = s->jmp;
+   /* go through a local array instead of indexing past the s1 member */
+   struct block *succ[2] = { s->s1, s->s2 };
+   for (int i = 0; i < 2; ++i) {
+      struct block *ss = succ[i];
+      if (!ss) continue;
+      for (int j = 0; j < ss->npred; ++j) {
+         if (blkpred(ss, j) == s) {
+            blkpred(ss, j) = p;
+            break;
+         }
+      }
+   }
+   /* breaks use for temps used in s */
+   if (s->ins.n || s->jmp.arg[0].t == RTMP || s->jmp.arg[1].t == RTMP)
+      fn->prop &= ~FNUSE;
+   p->s1 = s->s1;
+   p->s2 = s->s2;
+   freeblk(fn, s);
+}
+
+/* Simplification pass over a whole function.
+ *
+ *  1. merge straight-line block pairs;
+ *  2. per block: delete trivial phis, simplify instructions with doins(),
+ *     fold conditional branches on integer constants, free blocks left
+ *     without predecessors, and record empty jump-only blocks for threading;
+ *  3. if any jump was recorded, retarget successors through the recorded
+ *     table, rebuild predecessor lists and re-sort the blocks.
+ *
+ * Use information is refreshed whenever a step cleared FNUSE, and FNDOM is
+ * cleared on exit. */
+void
+simpl(struct function *fn)
+{
+   FREQUIRE(FNUSE);
+   int inschange = 0,
+       blkchange = 0;
+   struct block **jmpfinal = allocz(fn->passarena, fn->nblk * sizeof *jmpfinal, 0);
+   struct block *blk = fn->entry, *next;
+
+   do {
+      /* merge blocks:
+       *  @blk:
+       *    ... (1)
+       *    b @s
+       *  @s:
+       *    ... (2)
+       *    b {...}
+       *=>
+       *  @blk:
+       *    ... (1)
+       *    ... (2)
+       *    b {...}
+       * */
+      /* the successor must have no phis (mergeblks() asserts it) and must
+       * not be blk itself, which would merge a block into itself */
+      if (blk != fn->entry) while (blk->s1 && !blk->s2 && blk->s1 != blk
+                                   && blk->s1->npred == 1
+                                   && !blk->phi.n && !blk->s1->phi.n) {
+         mergeblks(fn, blk, blk->s1);
+      }
+   } while ((blk = blk->lnext) != fn->entry);
+
+   if (!(fn->prop & FNUSE)) filluses(fn);
+
+   do {
+      /* NOTE(review): i still advances after delphi(); whether that skips
+       * the phi that followed depends on how delphi() compacts the list --
+       * confirm */
+      for (int i = 0; i < blk->phi.n; ++i) {
+         int phi = blk->phi.p[i];
+         /* delete trivial phis */
+         union ref *args = phitab.p[instrtab[phi].l.i],
+                   same = *args;
+         if (same.t == RTMP && instrtab[same.i].op == Ophi) goto Next;
+         if (blk->npred > 1) for (int j = 1; j < blk->npred; ++j) {
+            if (args[j].bits != same.bits) goto Next;
+         }
+         if (!(fn->prop & FNUSE)) filluses(fn);
+         replcuses(mkref(RTMP, phi), same);
+         delphi(blk, i);
+      Next:;
+      }
+
+      int curi = 0;
+   DoIns:
+      for (; curi < blk->ins.n; ++curi) {
+         struct instr *ins = &instrtab[blk->ins.p[curi]];
+         if (ins->op != Onop) {
+            if (!(fn->prop & FNUSE)) filluses(fn);
+            inschange += doins(ins, blk, &curi);
+         }
+      }
+
+      if (blk->s2 && isintcon(blk->jmp.arg[0])) {
+         /* simplify known conditional branch */
+         struct block *s = intconval(blk->jmp.arg[0]) ? blk->s1 : blk->s2;
+         delpred(s == blk->s1 ? blk->s2 : blk->s1, blk);
+         blk->s1 = s, blk->s2 = NULL;
+         blk->jmp.arg[0] = NOREF;
+         if (blk->s1 && !blk->s2 && blk->s1 != blk
+             && blk->s1->npred == 1 && blk->s1->phi.n == 0) {
+            mergeblks(fn, blk, blk->s1);
+            /* resume at curi to cover the instructions just appended */
+            goto DoIns;
+         }
+      }
+
+      /* fetch the list successor before blk may be freed; this is done
+       * only now because the merge above can free the following block */
+      next = blk->lnext;
+      if (blk != fn->entry && blk->npred == 0) {
+         freeblk(fn, blk);
+      } else if (!blk->phi.n && !blk->ins.n) { /* thread jumps.. */
+         if (blk->jmp.t == Jb && !blk->s2) {
+            jmpfind(jmpfinal, &blk->s1);
+            if (blk->s1 != blk) {
+               ++blkchange;
+               jmpfinal[blk->id] = blk->s1;
+            }
+         }
+      }
+   } while ((blk = next) != fn->entry);
+
+   if (blkchange) {
+      do {
+         if (blk->s1) jmpfind(jmpfinal, &blk->s1);
+         if (blk->s2) jmpfind(jmpfinal, &blk->s2);
+         if (blk->s1 && blk->s1 == blk->s2) {
+            /* both arms lead to the same block: make the jump unconditional */
+            memset(blk->jmp.arg, 0, sizeof blk->jmp.arg);
+            blk->s2 = NULL;
+         }
+      } while ((blk = blk->lnext) != fn->entry);
+      fillpreds(fn);
+      sortrpo(fn);
+   }
+   if (!(fn->prop & FNUSE)) filluses(fn);
+   fn->prop &= ~FNDOM;
+}
+
+/* vim:set ts=3 sw=3 expandtab: */