From 75db572e8a7a144b3d6fdc2b148ac3921f77f2b2 Mon Sep 17 00:00:00 2001 From: lemon Date: Tue, 16 Sep 2025 17:24:52 +0200 Subject: implement switch statement --- c.c | 132 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 112 insertions(+), 20 deletions(-) (limited to 'c.c') diff --git a/c.c b/c.c index 7b64b8f..83a4df5 100644 --- a/c.c +++ b/c.c @@ -11,7 +11,8 @@ struct comp { struct arena *fnarena, *exarena; struct span fnblkspan; uint loopdepth, switchdepth; - struct block *loopbreak, *loopcont; + struct block *breakto, *loopcont; + struct switchstmt *switchstmt; struct label *labels; }; @@ -672,7 +673,7 @@ tkprec(int tt) /* param ident is a kludge to support block labels without backtracking or extra lookahead * see stmt() */ static struct expr -exprparse(struct comp *cm, int prec, const struct token *ident) +exprparse(struct comp *cm, int prec, const struct token *ident, bool fromstmt) { struct token tk, tk2; struct span span; @@ -710,7 +711,7 @@ Unary: unops[nunop].t0 = 0; unops[nunop].tt = tk.t; if (++nunop >= arraylength(unops)) { - ex = exprparse(cm, 999, NULL); + ex = exprparse(cm, 999, NULL, 0); break; } goto Unary; @@ -761,7 +762,7 @@ Unary: unops[nunop].span = tk.span; unops[nunop].ty = decl.ty; if (++nunop >= arraylength(unops)) { - ex = exprparse(cm, 999, NULL); + ex = exprparse(cm, 999, NULL, 0); break; } goto Unary; @@ -783,7 +784,7 @@ Unary: ex = mkexpr(ENUMLIT, span, mktype(targ_sizetype), .u = typesize(ty)); break; default: - fatal(&tk.span, "expected expression (near %'tk)", &tk); + fatal(&tk.span, "expected %s (near %'tk)", fromstmt ? "statement" : "expression", &tk); } /* postfix operators */ @@ -945,7 +946,7 @@ Postfix: /* ex OP rhs */ span.sl = tk.span.sl; span.ex = ex.span.ex; - rhs = exprparse(cm, opprec + leftassoc, NULL); + rhs = exprparse(cm, opprec + leftassoc, NULL, 0); if (!joinspan(&span.ex, tk.span.ex) || !joinspan(&span.ex, rhs.span.ex)) span.ex = tk.span.ex; ty = bintypecheck(&span, tk.t, &ex, &rhs); @@ -982,19 +983,19 @@ Postfix: static struct expr expr(struct comp *cm) { - return exprparse(cm, bintab['='].prec, NULL); /* non-comma expr */ + return exprparse(cm, bintab['='].prec, NULL, 0); /* non-comma expr */ } static struct expr constantexpr(struct comp *cm) { - return exprparse(cm, bintab['?'].prec, NULL); /* conditional-expr */ + return exprparse(cm, bintab['?'].prec, NULL, 0); /* conditional-expr */ } static struct expr commaexpr(struct comp *cm) { - return exprparse(cm, 1, NULL); + return exprparse(cm, 1, NULL, 0); } /****************/ @@ -3114,18 +3115,79 @@ loopbody(struct comp *cm, struct function *fn, struct block *brk, struct block * struct block *save[2]; bool terminates = 0; - save[0] = cm->loopbreak, save[1] = cm->loopcont; - cm->loopbreak = brk, cm->loopcont = cont; + save[0] = cm->breakto, save[1] = cm->loopcont; + cm->breakto = brk, cm->loopcont = cont; ++cm->loopdepth; terminates = stmt(cm, fn); --cm->loopdepth; - cm->loopbreak = save[0], cm->loopcont = save[1]; + cm->breakto = save[0], cm->loopcont = save[1]; return terminates; } +#define EMITS if (doemit && !nerror) + +struct swcase { + vlong val; + struct block *blk; +}; +struct switchstmt { + struct block *bdefault; + vec_of(struct swcase) cases; +}; + +static void +genswitch(struct comp *cm, struct function *fn, const struct expr *ex) +{ + union ref sel; + bool doemit = fn->curblk; + struct block *begin = NULL, *end = NULL, *breaksave = cm->breakto; + struct switchstmt *stsave = cm->switchstmt, st = {0}; + enum irclass k = type2cls[ex->ty.t]; + struct swcase casebuf[8]; + vinit(&st.cases, casebuf, arraylength(casebuf)); + + end = newblk(fn); + EMITS { + sel = exprvalue(fn, ex); + assert(isint(ex->ty)); + } + cm->switchstmt = &st; + cm->breakto = end; + begin = fn->curblk; + fn->curblk = NULL; + stmt(cm, fn); + doemit = fn->curblk; + cm->switchstmt = stsave; + cm->breakto = breaksave; + + EMITS putbranch(fn, end); + useblk(fn, begin); + doemit = 1; + if (!st.bdefault) st.bdefault = end; + /* TODO: optimize instead of generating the equivalent of if == .. else if .. chain + * 1. sort by case values (also for easy duplicates checking) + * 2. contiguous ranges (case a..b: -> x >= && x <= b) + * 3. binary search + * 4. jump tables? (harder) + */ + for (int i = 0; i < st.cases.n; ++i) { + struct swcase c = st.cases.p[i]; + EMITS { + struct block *next = i < st.cases.n - 1 ? newblk(fn) : st.bdefault; + putcondbranch(fn, addinstr(fn, mkinstr(Oequ, k, .l = sel, .r = mkintcon(k, c.val))), c.blk, next); + useblk(fn, next); + } + } + vfree(&st.cases); + if (fn->curblk != end) { + EMITS putbranch(fn, end); + useblk(fn, end); + } +} + static bool /* return 1 if stmt is terminating (ends with a jump) */ stmt(struct comp *cm, struct function *fn) { @@ -3142,20 +3204,41 @@ stmt(struct comp *cm, struct function *fn) bool terminates = 0; bool doemit = fn->curblk; -#define EMITS if (doemit && !nerror) - - while (match(cm, &tk, TKIDENT)) { - if (match(cm, NULL, ':')){ + while (match(cm, &tk, TKIDENT) || match(cm, &tk, TKWcase) || match(cm, &tk, TKWdefault)) { + if (tk.t == TKWcase) { + /* case ':' */ + if (!cm->switchstmt) error(&tk.span, "'case' outside of switch statement"); + ex = commaexpr(cm); + if (!eval(&ex, EVINTCONST)) error(&ex.span, "not an integer constant expression"); + expect(cm, ':', NULL); + begin = newblk(fn); + EMITS putbranch(fn, begin); + useblk(fn, begin); + if (cm->switchstmt) + vpush(&cm->switchstmt->cases, ((struct swcase) {ex.i, fn->curblk})); + } else if (tk.t == TKWdefault) { + /* default ':' */ + if (!cm->switchstmt) error(&tk.span, "'default' outside of switch statement"); + expect(cm, ':', NULL); + begin = newblk(fn); + EMITS putbranch(fn, begin); + useblk(fn, begin); + if (cm->switchstmt) { + if (cm->switchstmt->bdefault) error(&tk.span, "multiple 'default' labels in one switch"); + cm->switchstmt->bdefault = begin; + } + } else if (tk.t == TKIDENT && match(cm, NULL, ':')) { /*