about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
-rw-r--r-- Makefile 2
-rw-r--r-- c/c.c 3
-rw-r--r-- common.h 2
-rw-r--r-- ir/inliner.c 300
-rw-r--r-- ir/ir.c 21
-rw-r--r-- ir/ir.h 10
-rw-r--r-- ir/simpl.c 13
-rw-r--r-- main.c 6
-rw-r--r-- test/16-inline.c 22
9 files changed, 364 insertions, 15 deletions
diff --git a/Makefile b/Makefile
index 4aab646..7543ed8 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
SRC=main.c io.c mem.c c/c.c c/lex.c c/eval.c c/builtin.c type.c targ.c \
ir/ir.c ir/builder.c ir/fold.c ir/dump.c ir/ssa.c ir/cfg.c ir/intrin.c ir/abi0.c ir/mem2reg.c ir/regalloc.c ir/simpl.c ir/stack.c \
- ir/cse.c \
+ ir/cse.c ir/inliner.c \
x86_64/sysv.c x86_64/isel.c x86_64/emit.c \
aarch64/aapcs.c aarch64/isel.c aarch64/emit.c \
obj/obj.c obj/elf.c \
diff --git a/c/c.c b/c/c.c
index aa291e1..a7f48a4 100644
--- a/c/c.c
+++ b/c/c.c
@@ -4574,7 +4574,8 @@ tldecl(struct comp *cm)
decl.isdef = 1;
int idecl = putdecl(cm, &decl);
struct decl *d = &declsbuf.p[idecl];
- struct function fn = { &cm->fnarena, .name = d->sym, .globl = d->scls != SCSTATIC, .fnty = decl.ty, .retty = td->ret };
+ if (d->inlin && decl.scls != SCSTATIC) fatal(&d->span, "non-static inline is unimplemented");
+ struct function fn = { &cm->fnarena, .name = d->sym, .globl = d->scls != SCSTATIC, .fnty = decl.ty, .retty = td->ret, .inlin = d->inlin };
irinit(&fn);
function(cm, &fn, st.pnames, st.pspans, st.pqual);
if (!nerror && ccopt.dbg.p)
diff --git a/common.h b/common.h
index 81e34c5..0e10a8e 100644
--- a/common.h
+++ b/common.h
@@ -136,12 +136,14 @@ struct option {
enum optz {
OPT0 = -1,
OPT1 = 1,
+ OPT2 = 2,
} o;
union {
struct {
bool p : 1, /* after parsing */
a : 1, /* after abi0 */
m : 1, /* after mem */
+ y : 1, /* after inline */
o : 1, /* after optimizations */
s : 1, /* after stack */
i : 1, /* after isel */
diff --git a/ir/inliner.c b/ir/inliner.c
new file mode 100644
index 0000000..0853d26
--- /dev/null
+++ b/ir/inliner.c
@@ -0,0 +1,300 @@
+#include "ir.h"
+/* Snapshot of one function's IR (its blocks plus copies of the global IR tables), kept so that later call sites can inline it. */
+struct savedfunc {
+ int mark; /* doinline() visitmark under which this callee was last inlined; 0 when freshly stashed */
+ uint ninstr; /* number of entries copied into instrtab below */
+ struct instr *instrtab; /* copy of the global instrtab taken at stash time */
+ struct xcon *conht; /* copy of the global conht (1<<12 entries are copied) */
+ struct call *calltab; /* copy of calltab.p; stays NULL when calltab.n was 0 */
+ union ref **phitab; /* copies of the phi argument arrays in phitab.p; stays NULL when phitab.n was 0 */
+ struct block *entry; /* first copied block; the copies are chained through lnext and end in NULL */
+ union type fnty, retty; /* copied from fn->fnty and fn->retty */
+ struct abiarg *abiarg, abiret[2]; /* copies of the function's ABI argument and return descriptors */
+ ushort nabiarg, nabiret; /* element counts for abiarg and abiret */
+ ushort nretpoints; /* number of blocks whose jump type is Jret */
+};
+/* Size limits tested by maybeinlinee(); they also dimension the bmap and instrmap arrays used while stashing and inlining. */
+enum { MAX_INLINED_FN_NINS = 50,
+ MAX_INLINED_FN_NBLK = 16, };
+/* Stashed inline candidates; maybeinlinee() stores each one under fn->name and doinline() looks callees up by symbol. */
+static pmap_of(struct savedfunc *) savedfns;
+/* Decide whether fn may be inlined later; if so, copy its IR into a private arena and register it in savedfns under fn->name. Returns 1 if stashed, 0 if rejected. */
+bool
+maybeinlinee(struct function *fn)
+{
+ static struct arena *savearena; /* arena holding every stashed copy; initialised on first use below */
+ extern int ninstr; /* instruction count: checked against the limit and used to size the instrtab copy */
+ /* eligibility checks: any failing test rejects the function */
+ // TODO better heuristics
+ if (ccopt.o < OPT1) return 0; /* nothing is stashed below OPT1 */
+ if (!fn->inlin && ccopt.o < OPT2) return 0; /* functions without the inlin flag are only considered from OPT2 up */
+ if (ninstr > MAX_INLINED_FN_NINS) return 0; /* too many instructions */
+ if (fn->nblk > MAX_INLINED_FN_NBLK) return 0; /* too many blocks */
+ for (int i = 0; i < fn->nabiarg; ++i) {
+ /* TODO inlining functions with stack args */
+ if (fn->abiarg[i].isstk) return 0;
+ }
+ if (fn->nabiret > 1) return 0; /* TODO 2reg scalar return */
+ /* first call only: point savearena at a static buffer with capacity N */
+ if (!savearena) {
+ enum { N = 1<<12 };
+ static union { char m[sizeof(struct arena) + N]; struct arena *_align; } amem; /* union with a pointer member so the byte buffer is at least pointer-aligned */
+ savearena = (void *)amem.m;
+ savearena->cap = N;
+ }
+ /* optional debug trace (dbg.y) */
+ if (ccopt.dbg.y) {
+ bfmt(ccopt.dbgout, "> stashing '%s' for inlining\n", fn->name);
+ }
+ struct savedfunc *sv = allocz(&savearena, sizeof *sv, 0); /* zeroed, so mark, nretpoints and the table pointers start at 0/NULL */
+ sv->fnty = fn->fnty, sv->retty = fn->retty;
+ if (fn->abiarg)
+ sv->abiarg = alloccopy(&savearena, fn->abiarg, sizeof *sv->abiarg * fn->nabiarg, 0);
+ sv->nabiarg = fn->nabiarg;
+ if ((sv->nabiret = fn->nabiret) > 0)
+ memcpy(sv->abiret, fn->abiret, sizeof sv->abiret);
+ struct block *bmap[MAX_INLINED_FN_NBLK]; /* block id -> copied block */
+ struct block *b = fn->entry;
+ int id = 0;
+ do { /* pass 1: renumber fn's blocks in lnext order and copy each one together with its arrays */
+ b->id = id++;
+ struct block *q = alloccopy(&savearena, b, sizeof *b, 0);
+ if (q->phi.n)
+ q->phi.p = alloccopy(&savearena, q->phi.p, sizeof *q->phi.p * q->phi.n, 0);
+ if (q->ins.n)
+ q->ins.p = alloccopy(&savearena, q->ins.p, sizeof *q->ins.p * q->ins.n, 0);
+ if (q->npred > 1) /* only multi-pred blocks get a _pred array copy; a lone pred is handled through _pred0 in pass 2 */
+ q->_pred = alloccopy(&savearena, q->_pred, sizeof *q->_pred * q->npred, 0);
+ q->lprev = NULL; /* lprev and idom are cleared in the copy */
+ q->idom = NULL;
+ bmap[b->id] = q;
+ sv->nretpoints += b->jmp.t == Jret; /* count returning blocks */
+ } while ((b = b->lnext) != fn->entry);
+ b = sv->entry = bmap[0];
+ do { /* pass 2: redirect successor, predecessor and lnext pointers from fn's blocks to the copies */
+ if (b->s1) b->s1 = bmap[b->s1->id];
+ if (b->s2) b->s2 = bmap[b->s2->id];
+ if (b->npred == 1)
+ b->_pred0 = bmap[b->_pred0->id];
+ else for (int i = 0; i < b->npred; ++i)
+ b->_pred[i] = bmap[b->_pred[i]->id];
+ b->lnext = b->lnext == fn->entry ? NULL : bmap[b->lnext->id]; /* the copied list is NULL-terminated rather than circular */
+ } while ((b = b->lnext));
+ /* copy the global IR tables that the saved blocks index into */
+ sv->instrtab = alloccopy(&savearena, instrtab, sizeof *instrtab * (sv->ninstr = ninstr), 0);
+ sv->conht = alloccopy(&savearena, conht, sizeof *conht * (1<<12), 0); // HACK: hard-coded entry count for the copied table
+ if (calltab.n) {
+ sv->calltab = alloccopy(&savearena, calltab.p, sizeof *calltab.p * calltab.n, 0);
+ for (int i = 0; i < calltab.n; ++i) { /* each call's abiarg array is copied as well */
+ if (sv->calltab[i].abiarg)
+ sv->calltab[i].abiarg = alloccopy(&savearena, sv->calltab[i].abiarg,
+ sv->calltab[i].narg * sizeof *sv->calltab[i].abiarg, 0);
+ }
+ }
+ if (phitab.n) {
+ sv->phitab = alloc(&savearena, sizeof *phitab.p * phitab.n, 0);
+ for (int i = 0; i < phitab.n; ++i) { /* copy xbcap() elements of every phi argument array */
+ sv->phitab[i] = alloccopy(&savearena, phitab.p[i], sizeof *phitab.p[i] * xbcap(phitab.p[i]), 0);
+ }
+ }
+ pmap_set(&savedfns, fn->name, sv); /* make the copy findable by doinline() */
+ return 1;
+}
+/* Translate a ref taken from saved function sv for use in the caller: RTMP indices go through instrmap, RXCON refs are re-registered with newcon(); anything else is returned unchanged. */
+static union ref
+mapref(short *instrmap, struct savedfunc *sv, union ref r)
+{
+ int newcon(const struct xcon *con); /* defined in ir/ir.c; declared locally here */
+ if (r.t == RTMP) return r.i = instrmap[r.i], r; /* saved instruction index -> freshly allocated one */
+ if (r.t == RXCON) return r.i = newcon(&sv->conht[r.i]), r; /* look the constant up in the saved table and add it to the current one */
+ assert(r.t != RADDR); /* RADDR and RSTACK refs are not handled by this mapping */
+ assert(r.t != RSTACK);
+ return r;
+}
+/* Inline saved function sv at the call instruction blk->ins.p[curi] of fn: split blk in front of the call, clone sv's blocks between the two halves, and return the block that now starts with the former call. */
+static struct block *
+inlcall(struct function *fn, struct block *blk, int curi, struct savedfunc *sv)
+{
+ int res = blk->ins.p[curi], res2; /* res: instrtab index of the call being replaced */
+ struct instr *ins = &instrtab[res];
+ struct call *call = &calltab.p[ins->r.i];
+ union ref retvals[64]; /* one value per Jret block of the callee */
+ union ref args[64]; /* operands of the call's Oarg instructions, first argument at index 0 */
+ assert(sv->nabiret < 2 && sv->nretpoints < countof(retvals));
+ for (int n = call->narg, i = curi-1; n > 0; --i) { /* scan backwards from the call until narg Oarg instructions were seen */
+ assert(i >= 0);
+ struct instr *ins = &instrtab[blk->ins.p[i]];
+ if (ins->op == Oarg) {
+ args[--n] = ins->r;
+ *ins = mkinstr(Onop,0,); /* the Oarg itself is turned into a nop */
+ }
+ }
+ if (call->abiret[1].ty.bits) { /* second return slot in use: the next instruction must be an Ocall2r */
+ assert(curi+1 < blk->ins.n);
+ res2 = blk->ins.p[curi+1]; /* NOTE(review): res2 is only read by the assert below */
+ assert(instrtab[res2].op == Ocall2r);
+ }
+ struct block *exit = blksplitafter(fn, blk, curi-1); /* exit receives the call instruction and everything after it */
+ if (!ins->cls) { /* call instruction has no result class: just drop it */
+ *ins = mkinstr(Onop,0,);
+ } else {
+ if (sv->nretpoints == 1) { /* one return point: the result becomes a copy whose operand is filled in at the end */
+ *ins = mkinstr(Ocopy, ins->cls, );
+ } else { /* NOTE(review): also taken when nretpoints == 0, in which case the phi is never filled below - confirm this is intended */
+ /* turn into a phi */
+ *ins = mkinstr(Ophi, ins->cls, );
+ exit->ins.p[0] = newinstr(blk, mkinstr(Onop,0,)); /* the call's slot in exit's instruction list is taken by a fresh nop */
+ vpush(&exit->phi, res); /* and res is listed among exit's phis instead */
+ }
+ }
+ /* maps from the saved function's block ids and instruction indices to their clones in fn */
+ struct block *bmap[MAX_INLINED_FN_NBLK];
+ short instrmap[MAX_INLINED_FN_NINS];
+ for (struct block *b = sv->entry; b; b = b->lnext) { /* one new block in fn per saved block */
+ bmap[b->id] = newblk(fn);
+ }
+ for (int i = 0; i < sv->ninstr; ++i) { /* one new instruction slot per saved instrtab entry */
+ /* TODO don't wastefully allocate for tombstone instructions
+ * - mark instructions some way when they are freed (put in instrfreelist)
+ * */
+ int allocinstr(void);
+ instrmap[i] = allocinstr();
+ }
+ /* exit loses its predecessors; blk's successor becomes the clone of the callee's entry block */
+ exit->npred = 0;
+ exit->_pred0 = NULL;
+ blk->s1 = bmap[0];
+ int iret = 0; /* next free slot in retvals */
+ for (struct block *b = sv->entry, *prev = blk, *new; b; prev = new, b = b->lnext) { /* fill in each clone, keeping the clones linked between blk and exit */
+ new = bmap[b->id];
+ new->id = fn->nblk++;
+ new->lprev = prev;
+ prev->lnext = new;
+ new->lnext = exit; /* provisional: overwritten by the next iteration unless this is the last clone */
+ exit->lprev = new;
+ if (b->npred == 1) new->_pred0 = bmap[b->_pred0->id]; /* predecessors are translated through bmap */
+ else if (b->npred > 0) {
+ xbgrow(&new->_pred, b->npred);
+ for (int i = 0; i < b->npred; ++i)
+ new->_pred[i] = bmap[b->_pred[i]->id];
+ }
+ new->npred = b->npred;
+ vresize(&new->phi, b->phi.n);
+ for (int i = 0; i < b->phi.n; ++i) { /* clone phis: each gets a new argument array pushed onto the global phitab */
+ int t = b->phi.p[i];
+ union ref *refs = NULL,
+ *src = sv->phitab[sv->instrtab[t].l.i]; /* the saved phi's l.i indexes the saved phitab */
+ xbgrow(&refs, b->npred);
+ for (int i = 0; i < b->npred; ++i) /* this inner i shadows the phi index only for this statement */
+ refs[i] = mapref(instrmap, sv, src[i]);
+ vpush(&phitab, refs);
+ instrtab[instrmap[t]] = mkinstr(Ophi, sv->instrtab[t].cls, .l.i = phitab.n-1);
+ new->phi.p[i] = instrmap[t];
+ }
+ vresize(&new->ins, b->ins.n);
+ for (int i = 0; i < b->ins.n; ++i) { /* clone ordinary instructions */
+ int t = b->ins.p[i];
+ struct instr *ins = &instrtab[instrmap[t]]; /* shadows the outer ins (the call) inside this loop */
+ *ins = sv->instrtab[t];
+ if (ins->op == Oparam) { /* a parameter becomes a copy of the corresponding call argument */
+ ins->op = Ocopy;
+ assert(ins->l.t == RICON && (uint) ins->l.i < call->narg);
+ ins->l = args[ins->l.i];
+ } else if (ins->op == Ocall || ins->op == Ointrin) { /* calls and intrinsics: remap the left operand and give the clone its own calltab entry */
+ ins->l = mapref(instrmap, sv, ins->l);
+ for (struct instr *ins2;
+ ins->l.t == RTMP && (ins2 = &instrtab[ins->l.i])->op == Ocopy
+ && (isaddrcon(ins2->l,0) || (ins2->l.t == RTMP && instrtab[ins->l.i].cls == KPTR));) {
+ /* for an indirect function call, eagerly copy-propagate the callee. this allows
+ * subsequent inlining of function pointers statically known after the inlining */
+ ins->l = ins2->l;
+ }
+ vpush(&calltab, sv->calltab[ins->r.i]); /* r.i still holds the saved calltab index at this point */
+ ins->r.i = calltab.n-1;
+ } else { /* everything else: remap whichever operands are present */
+ if (ins->l.t) ins->l = mapref(instrmap, sv, ins->l);
+ if (ins->r.t) ins->r = mapref(instrmap, sv, ins->r);
+ }
+ new->ins.p[i] = instrmap[t];
+ }
+ if (b->jmp.t == Jret) { /* a return becomes a branch to exit, and the returned ref is remembered */
+ new->jmp.t = Jb;
+ new->s1 = exit;
+ retvals[iret++] = mapref(instrmap, sv, b->jmp.arg[0]);
+ addpred(exit, new);
+ } else { /* other jumps keep their type; their arguments are remapped */
+ new->jmp.t = b->jmp.t;
+ for (int i = 0; i < 2; ++i) {
+ if (!b->jmp.arg[i].bits) break; /* stop at the first empty jump argument */
+ new->jmp.arg[i] = mapref(instrmap, sv, b->jmp.arg[i]);
+ }
+ }
+ if (b->s1) { /* successors are translated through bmap */
+ new->s1 = bmap[b->s1->id];
+ if (b->s2) new->s2 = bmap[b->s2->id];
+ }
+ }
+ exit->id = fn->nblk;
+ fn->prop &= ~FNUSE; /* clear FNUSE; irfini() calls filluses() when the flag is unset */
+ if (ins->cls && sv->nretpoints > 0) { /* ins is the original call again here: supply its value */
+ assert(sv->nretpoints == exit->npred);
+ if (sv->nretpoints == 1) {
+ /* fill copy */
+ ins->l = retvals[0];
+ } else {
+ /* fill phi */
+ union ref *refs = NULL;
+ xbgrow(&refs, sv->nretpoints);
+ memcpy(refs, retvals, sizeof *refs * sv->nretpoints);
+ vpush(&phitab, refs);
+ ins->l = mkref(RXXX, phitab.n-1); /* the phi refers to its argument array by phitab index */
+ }
+ }
+ bmap[0]->_pred0 = blk; /* the cloned entry block has blk as its only predecessor */
+ bmap[0]->npred = 1;
+ return exit;
+}
+/* Walk fn's blocks and inline calls whose callee is an address constant naming a function stashed in savedfns, provided the call's ABI description matches the stashed one. */
+void
+doinline(struct function *fn)
+{
+ if (calltab.n == 0) return; /* no call table entries: nothing to inline */
+ struct block *b = fn->entry;
+ static int visitmark = 1; /* stops infinite recursion */
+ struct block *vnext = NULL; /* block returned by the most recent inlcall() */
+ bool dumpbefore = 0; /* makes sure the "before" dump is printed at most once */
+ do {
+ if (b == vnext) ++visitmark; /* reached the block following the last inlined body: callees may be inlined again */
+ for (int i = 0; i < b->ins.n; ++i) {
+ struct instr *ins = &instrtab[b->ins.p[i]];
+ if (ins->op != Ocall) continue;
+ if (!isaddrcon(ins->l,0)) continue; /* skip calls whose callee operand is not an address constant */
+ struct call *call = &calltab.p[ins->r.i];
+ internstr fname = xcon2sym(ins->l.i);
+ struct savedfunc **pcallee, *sv;
+ if ((pcallee = pmap_get(&savedfns, fname))
+ && (sv = *pcallee)->nabiarg == call->narg && call->vararg == -1 /* presumably vararg == -1 means a non-variadic call - TODO confirm */
+ && sv->mark != visitmark /* not already inlined under the current mark */
+ && call->narg == sv->nabiarg /* NOTE(review): repeats the argument-count comparison made two lines above */
+ && (!call->narg || !memcmp(sv->abiarg, call->abiarg, sizeof *sv->abiarg * sv->nabiarg))
+ && !memcmp(sv->abiret, call->abiret, sizeof sv->abiret)) {
+ sv->mark = visitmark;
+ if (ccopt.dbg.y) {
+ if (!dumpbefore) {
+ bfmt(ccopt.dbgout, "<< Before inlining >>\n");
+ irdump(fn);
+ dumpbefore = 1;
+ }
+ }
+ vnext = inlcall(fn, b, i, *pcallee);
+ if (ccopt.dbg.y) {
+ bfmt(ccopt.dbgout, "<< After inlining '%s' >>\n", fname);
+ irdump(fn);
+ }
+ break; /* b was split at the call; the remaining instructions now sit in the returned block, reached later via lnext */
+ }
+ }
+ } while ((b = b->lnext) != fn->entry);
+}
+
+/* vim:set ts=3 sw=3 expandtab: */
diff --git a/ir/ir.c b/ir/ir.c
index 83861af..01d8060 100644
--- a/ir/ir.c
+++ b/ir/ir.c
@@ -100,7 +100,7 @@ newaddr(const struct addr *addr)
}
}
-static int
+int
newcon(const struct xcon *con)
{
uint h = hashb(0, con, sizeof *con);
@@ -292,8 +292,8 @@ freeblk(struct function *fn, struct block *blk)
vfree(&blk->ins);
if (blk->id != -1)
--fn->nblk;
- if (blk->lnext) blk->lnext->lprev = blk->lprev;
if (blk->lprev) blk->lprev->lnext = blk->lnext;
+ if (blk->lnext) blk->lnext->lprev = blk->lprev;
blk->id = 1u<<31;
}
@@ -330,9 +330,14 @@ blksplitafter(struct function *fn, struct block *blk, int idx)
new->lnext = blk->lnext;
blk->lnext = new;
new->lnext->lprev = new;
- if (idx < blk->ins.n-1)
- vpushn(&new->ins, &blk->ins.p[idx+1], blk->ins.n-idx-1);
- blk->ins.n -= new->ins.n;
+ if (idx == -1) {
+ new->ins = blk->ins;
+ memset(&blk->ins, 0, sizeof blk->ins);
+ } else {
+ if (idx < blk->ins.n-1)
+ vpushn(&new->ins, &blk->ins.p[idx+1], blk->ins.n-idx-1);
+ blk->ins.n -= new->ins.n;
+ }
new->jmp = blk->jmp;
blk->jmp.t = Jb;
memset(blk->jmp.arg, 0, sizeof blk->jmp.arg);
@@ -647,12 +652,17 @@ irfini(struct function *fn)
copyopt(fn);
}
if (ccopt.o >= OPT1) {
+ doinline(fn);
filldom(fn);
+ if (!(fn->prop & FNUSE)) filluses(fn);
cselim(fn);
freearena(fn->passarena);
simpl(fn);
freearena(fn->passarena);
}
+ if (maybeinlinee(fn)) {
+ // goto Fin; XXX do this by having inline function rematerialization when symbol is actually referenced
+ }
lowerstack(fn);
freearena(fn->passarena);
if (ccopt.dbg.o) {
@@ -665,6 +675,7 @@ irfini(struct function *fn)
if (objout.code)
mctarg->emit(fn);
+Fin:
freearena(fn->passarena);
freefn(fn);
}
diff --git a/ir/ir.h b/ir/ir.h
index d32885e..73a50bd 100644
--- a/ir/ir.h
+++ b/ir/ir.h
@@ -185,8 +185,10 @@ struct function {
ushort nabiarg, nabiret;
bool globl;
bool isleaf;
+ bool inlin;
regset regusage;
};
+
#define FREQUIRE(_prop) assert((fn->prop & (_prop)) == (_prop) && "preconditions not met")
enum objkind { OBJELF };
@@ -311,6 +313,9 @@ bool foldunop(union ref *to, enum op, enum irclass, union ref);
/** irdump.c **/
void irdump(struct function *);
+/** mem2reg.c **/
+void mem2reg(struct function *);
+
/** ssa.c **/
void copyopt(struct function *);
@@ -328,8 +333,9 @@ void simpl(struct function *);
/** cselim.c **/
void cselim(struct function *);
-/** mem2reg.c **/
-void mem2reg(struct function *);
+/** inliner.c **/
+bool maybeinlinee(struct function *);
+void doinline(struct function *);
/** intrin.c **/
void lowerintrin(struct function *);
diff --git a/ir/simpl.c b/ir/simpl.c
index c52f809..a01879a 100644
--- a/ir/simpl.c
+++ b/ir/simpl.c
@@ -82,7 +82,7 @@ divmodk(struct instr *ins, struct block *blk, int *curi)
}
static int
-ins(struct instr *ins, struct block *blk, int *curi)
+doins(struct instr *ins, struct block *blk, int *curi)
{
int narg = opnarg[ins->op];
if (oisarith(ins->op)) {
@@ -257,7 +257,11 @@ simpl(struct function *fn)
int curi = 0;
DoIns:
for (; curi < blk->ins.n; ++curi) {
- inschange += ins(&instrtab[blk->ins.p[curi]], blk, &curi);
+ struct instr *ins = &instrtab[blk->ins.p[curi]];
+ if (ins->op != Onop) {
+ if (!(fn->prop & FNUSE)) filluses(fn);
+ inschange += doins(ins, blk, &curi);
+ }
}
if (blk->s2 && isintcon(blk->jmp.arg[0])) {
@@ -272,8 +276,9 @@ simpl(struct function *fn)
}
}
- /* thread jumps.. */
- if (!blk->phi.n && !blk->ins.n) {
+ if (blk != fn->entry && blk->npred == 0) {
+ freeblk(fn, blk);
+ } else if (!blk->phi.n && !blk->ins.n) { /* thread jumps.. */
if (blk->jmp.t == Jb && !blk->s2) {
jmpfind(jmpfinal, &blk->s1);
if (blk->s1 != blk) {
diff --git a/main.c b/main.c
index 74e61e4..47664c2 100644
--- a/main.c
+++ b/main.c
@@ -172,6 +172,7 @@ optparse(char **args)
case 'o': ccopt.dbg.o = 1; break;
case 's': ccopt.dbg.s = 1; break;
case 'i': ccopt.dbg.i = 1; break;
+ case 'y': ccopt.dbg.y = 1; break;
case 'l': ccopt.dbg.l = 1; break;
case 'r': ccopt.dbg.r = 1; break;
default: warn(NULL, "-d: invalid debug flag %'c", *arg);
@@ -219,7 +220,8 @@ optparse(char **args)
/* TODO debug info */
} else if (*arg == 'O') {
if (!arg[1]) ccopt.o = 0; /* default opts */
- else if ((uint)arg[1] - '1' < 9) ccopt.o = OPT1;
+ else if (arg[1] == '1') ccopt.o = OPT1;
+ else if ((uint)arg[1] - '1' < 9) ccopt.o = OPT2;
else if (arg[1] == '0') ccopt.o = OPT0;
else goto Bad;
} else if (*arg == 'D' || *arg == 'U') {
@@ -651,7 +653,7 @@ prihelp(void)
" -help \tPrint this help message\n"
" -std=<..> \tSet C standard (c89, c99, c11, c23)\n"
" -pedantic \tWarnings for strict standards compliance\n"
- " -d{pamosilr} \tDebug print IR after {parse, abi, mem, opts, stack, isel, live, rega}\n"
+ " -d{pamyosilr} \tDebug print IR after {parse, abi, mem, inlining, opts, stack, isel, live, rega}\n"
" -o <file> \tPlace the output into <file>\n"
" -v \tVerbose output\n"
" -c \tEmit object file but do not link\n"
diff --git a/test/16-inline.c b/test/16-inline.c
new file mode 100644
index 0000000..983d9b2
--- /dev/null
+++ b/test/16-inline.c
@@ -0,0 +1,22 @@
+
+/* CFLAGS: -O1 */
+/* EXPECT:
+25;0
+*/
+
+static inline int sqr(int x) {return x*x;} /* passed to ind() as a function pointer by main() */
+static inline int foo(int);
+static inline int ind(int (*f)(int), int arg) { /* calls f indirectly */
+ return f(arg);
+}
+
+#include <stdio.h>
+int main() { /* prints the two results as "%d;%d"; the EXPECT header above lists 25;0 */
+ int q = ind(sqr, 5); /* sqr(5) == 25 */
+ printf("%d;%d\n", q, ind(foo,-2)); /* foo is only declared at this point; foo(-2) == 0 */
+}
+
+static inline int foo(int w) { /* defined after main(), which uses it through the earlier declaration */
+ return w+2;
+}
+