aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--ir.c21
-rw-r--r--ir.h1
-rw-r--r--parse.c166
-rw-r--r--test.c15
4 files changed, 158 insertions, 45 deletions
diff --git a/ir.c b/ir.c
index f6cf3a6..3d2ad8e 100644
--- a/ir.c
+++ b/ir.c
@@ -210,11 +210,32 @@ addphi2(struct function *fn, enum irclass cls,
return mkref(RTMP, ninstr++);
}
+union ref
+addphi(struct function *fn, enum irclass cls, struct block **blk, union ref *ref, uint n)
+{
+ struct phi phi = { .n = n, .cap = -1 };
+ struct instr ins = { Ophi, cls };
+ assert(n > 0);
+ phi.blk = alloc(&fn->arena, n*sizeof(struct block *) + n*sizeof(union ref), 0);
+ phi.ref = (union ref *)((char *)phi.blk + n*sizeof(struct block *));
+ memcpy(phi.blk, blk, n * sizeof(struct block *));
+ memcpy(phi.ref, ref, n * sizeof(union ref));
+ vpush(&phis, phi);
+ ins.l = mkref(REXT, phis.n-1);
+ assert(ninstr < arraylength(instr));
+ assert(fn->curblk != NULL);
+ assert(fn->curblk->ins.n == 0);
+ instr[ninstr] = ins;
+ vpush(&fn->curblk->phi, ninstr);
+ return mkref(RTMP, ninstr++);
+}
+
struct block *
newblk(struct function *fn)
{
struct block *blk = alloc(&fn->arena, sizeof(struct block), 0);
memset(blk, 0, sizeof *blk);
+ blk->id = -1;
return blk;
}
diff --git a/ir.h b/ir.h
index eee4859..b8095d9 100644
--- a/ir.h
+++ b/ir.h
@@ -117,6 +117,7 @@ union ref mkcall(struct function *, union type fnty, uint narg, union ref *, uni
union ref addinstr(struct function *, struct instr);
union ref addphi2(struct function *, enum irclass cls,
struct block *b1, union ref r1, struct block *b2, union ref r2);
+union ref addphi(struct function *, enum irclass cls, struct block **blk, union ref *ref, uint n);
struct block *newblk(struct function *);
void useblk(struct function *, struct block *);
void putjump(struct function *, enum jumpkind, union ref arg, struct block *t, struct block *f);
diff --git a/parse.c b/parse.c
index bc9af99..16293f9 100644
--- a/parse.c
+++ b/parse.c
@@ -1119,12 +1119,107 @@ genptrdiff(struct function *fn, uint siz, union ref a, union ref b)
return addinstr(fn, (struct instr) { Odiv, cls, a, mkintcon(fn, cls, siz) });
}
+/* used to emit the jumps in an in if (), while (), etc condition */
+static void
+condjump(struct function *fn, const struct expr *ex, struct block *tr, struct block *fl)
+{
+ struct block *next, *next2;
+Loop:
+ while (ex->t == ESEQ) {
+ (void)exprvalue(fn, &ex->sub[0]);
+ ex = &ex->sub[1];
+ }
+ if (ex->t == ELOGAND) {
+ next = newblk(fn);
+ condjump(fn, &ex->sub[0], next, fl);
+ useblk(fn, next);
+ ex = &ex->sub[1];
+ goto Loop;
+ } else if (ex->t == ELOGIOR) {
+ next = newblk(fn);
+ condjump(fn, &ex->sub[0], tr, next);
+ useblk(fn, next);
+ ex = &ex->sub[1];
+ goto Loop;
+ } else if (ex->t == ECOND) {
+ next = newblk(fn);
+ next2 = newblk(fn);
+ condjump(fn, &ex->sub[0], next, next2);
+ useblk(fn, next);
+ condjump(fn, &ex->sub[1], tr, fl);
+ useblk(fn, next2);
+ condjump(fn, &ex->sub[2], tr, fl);
+ } else {
+ putjump(fn, Jbcnd, exprvalue(fn, ex), tr, fl);
+ }
+}
+
+struct condphis {
+ vec_of(struct block *) blk;
+ vec_of(union ref) ref;
+};
+
+static void
+condexprrec(struct function *fn, const struct expr *ex, bool tobool,
+ struct condphis *phis, struct block *end, struct block *zero)
+{
+ struct block *tr, *fl, *next;
+ union ref r;
+ while (ex->t == ESEQ) {
+ (void)exprvalue(fn, &ex->sub[0]);
+ ex = &ex->sub[1];
+ }
+ if (ex->t == ELOGAND) {
+ next = newblk(fn);
+ condexprrec(fn, &ex->sub[0], 0, phis, next, end);
+ useblk(fn, next);
+ condexprrec(fn, &ex->sub[1], 1, phis, end, zero);
+ } else if (ex->t == ELOGIOR) {
+ next = newblk(fn);
+ condexprrec(fn, &ex->sub[0], 1, phis, end, next);
+ useblk(fn, next);
+ condexprrec(fn, &ex->sub[1], 1, phis, end, zero);
+ } else if (ex->t == ECOND) {
+ tr = newblk(fn);
+ fl = newblk(fn);
+ condjump(fn, &ex->sub[0], tr, fl);
+ useblk(fn, tr);
+ condexprrec(fn, &ex->sub[1], 0, phis, end, zero);
+ useblk(fn, fl);
+ condexprrec(fn, &ex->sub[2], 0, phis, end, zero);
+ } else {
+ r = exprvalue(fn, ex);
+ if (tobool) r = cvt(fn, TYBOOL, ex->ty.t, r);
+ vpush(&phis->blk, fn->curblk);
+ vpush(&phis->ref, r);
+ if (zero)
+ putjump(fn, Jbcnd, r, end, zero);
+ else
+ putjump(fn, Jb, NOREF, end, NULL);
+ }
+}
+
+/* the naive way to generate something like a ? b : c ? d : e, uses multiple phis,
+ * this code reduces such nested conditional expressions into one phi */
+static union ref
+condexprvalue(struct function *fn, const struct expr *ex)
+{
+ struct condphis phis = {0};
+ struct block *dst = newblk(fn);
+ union ref r;
+ condexprrec(fn, ex, 0, &phis, dst, NULL);
+ useblk(fn, dst);
+ r = addphi(fn, type2cls[ex->ty.t], phis.blk.p, phis.ref.p, phis.blk.n);
+ vfree(&phis.blk);
+ vfree(&phis.ref);
+ return r;
+}
+
static union ref
exprvalue(struct function *fn, const struct expr *ex)
{
union type ty;
union ref r, q;
- struct block *tr, *fl, *end;
enum irclass cls = type2cls[ex->ty.t];
struct instr ins = {0};
int swp = 0;
@@ -1145,7 +1240,11 @@ exprvalue(struct function *fn, const struct expr *ex)
case EGETF:
return genload(fn, ex->ty, expraddr(fn, ex));
case ECAST:
- if (ex->ty.t == TYVOID) return exprvalue(fn, sub);
+ if (ex->ty.t == TYVOID) {
+ (void)exprvalue(fn, sub);
+ return NOREF;
+ }
+ /* fallthru */
case EPLUS:
return cvt(fn, ex->ty.t, sub->ty.t, exprvalue(fn, sub));
case ENEG:
@@ -1341,41 +1440,23 @@ exprvalue(struct function *fn, const struct expr *ex)
return addinstr(fn, ins);
}
case ECOND:
- r = exprvalue(fn, &sub[0]);
- tr = newblk(fn);
- fl = newblk(fn);
- end = newblk(fn);
- putjump(fn, Jbcnd, r, tr, fl);
- useblk(fn, tr);
- r = cvt(fn, ex->ty.t, sub[1].ty.t, exprvalue(fn, &sub[1]));
- putjump(fn, Jb, NOREF, end, NULL);
- useblk(fn, fl);
- q = cvt(fn, ex->ty.t, sub[2].ty.t, exprvalue(fn, &sub[2]));
- putjump(fn, Jb, NOREF, end, NULL);
- useblk(fn, end);
- return addphi2(fn, type2cls[ex->ty.t], tr, r, fl, q);
+ if (ex->ty.t == TYVOID) {
+ struct block *tr, *fl, *end;
+ condjump(fn, &ex->sub[0], tr = newblk(fn), fl = newblk(fn));
+ useblk(fn, tr);
+ (void)exprvalue(fn, &ex->sub[1]);
+ end = newblk(fn);
+ putjump(fn, Jb, NOREF, end, NULL);
+ useblk(fn, fl);
+ (void)exprvalue(fn, &ex->sub[2]);
+ putjump(fn, Jb, NOREF, end, NULL);
+ useblk(fn, end);
+ return NOREF;
+ }
+ /* fallthru */
case ELOGAND:
- r = exprvalue(fn, &sub[0]);
- fl = fn->curblk;
- tr = newblk(fn);
- end = newblk(fn);
- putjump(fn, Jbcnd, r, tr, end);
- useblk(fn, tr);
- q = cvt(fn, TYBOOL, sub[1].ty.t, exprvalue(fn, &sub[1]));
- putjump(fn, Jb, NOREF, end, NULL);
- useblk(fn, end);
- return addphi2(fn, KI4, fl, mkref(RICON, 0), tr, q);
case ELOGIOR:
- r = exprvalue(fn, &sub[0]);
- tr = fn->curblk;
- fl = newblk(fn);
- end = newblk(fn);
- putjump(fn, Jbcnd, r, end, fl);
- useblk(fn, fl);
- q = cvt(fn, TYBOOL, sub[1].ty.t, exprvalue(fn, &sub[1]));
- putjump(fn, Jb, NOREF, end, NULL);
- useblk(fn, end);
- return addphi2(fn, KI4, tr, mkref(RICON, 1), fl, q);
+ return condexprvalue(fn, ex);
case ESEQ:
(void)exprvalue(fn, &sub[0]);
return exprvalue(fn, &sub[1]);
@@ -1426,11 +1507,8 @@ stmt(struct parser *pr, struct function *fn)
EMITS {
tr = newblk(fn);
fl = newblk(fn);
- r = exprvalue(fn, &ex);
- EMITS {
- putjump(fn, Jbcnd, r, tr, fl);
- useblk(fn, tr);
- }
+ condjump(fn, &ex, tr, fl);
+ useblk(fn, tr);
}
terminates = stmt(pr, fn);
if (!match(pr, NULL, TKWelse)) {
@@ -1461,11 +1539,8 @@ stmt(struct parser *pr, struct function *fn)
begin = newblk(fn);
putjump(fn, Jb, NOREF, begin, NULL);
useblk(fn, begin);
- r = exprvalue(fn, &ex);
- EMITS {
- putjump(fn, Jbcnd, r, tr = newblk(fn), end = newblk(fn));
- useblk(fn, tr);
- }
+ condjump(fn, &ex, tr = newblk(fn), end = newblk(fn));
+ useblk(fn, tr);
}
terminates = stmt(pr, fn);
EMITS {
@@ -1563,6 +1638,7 @@ block(struct parser *pr, struct function *fn)
}
break;
case SCTYPEDEF: break;
+ case SCEXTERN: break;
default: assert(0);
}
Err:
diff --git a/test.c b/test.c
index 461d42f..e16e415 100644
--- a/test.c
+++ b/test.c
@@ -17,4 +17,19 @@ int test0(struct foo *foo) { return foo->x ? foo->y : foo->z; }
int test1(int x, int y, int z) { return x && y || z; }
int test2(int x, int y, int z) { return x || y && z; }
+extern void f();
+int test3(int x, int y) {
+ if (x < 0 && y < 0 && 1) return x - y;
+ if ((x == 0 || x > 0) && y > 0) return y;
+ if (f(), x == y ? x && 1 : 0 || y) return x;
+ return x + y;
+}
+
+int test4(int c) {
+ return c == 'a' || c == 'x' ? 1
+ : (f(), c == 'b' || c == 'y') ? 2
+ : c == 'c' || c == 'z' ? 3
+ : 0;
+}
+
//