kit

kit
git clone https://git.ryansepassi.com/git/kit.git
Log | Files | Refs | README

commit c55dd9714a923b709a594fcabba01ed5a9447fc6
parent 0f85365d6934d1507a8da340e71118dc91b8ce4a
Author: Ryan Sepassi <rsepassi@gmail.com>
Date:   Thu, 14 May 2026 12:26:27 -0700

Improve O1 live info for regalloc

Diffstat:
M src/opt/ir.h | 10 +++++++++-
M src/opt/pass_cfg.c | 1 +
M src/opt/pass_lower.c | 136 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------
M test/opt/opt_test.c | 215 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 336 insertions(+), 26 deletions(-)

diff --git a/src/opt/ir.h b/src/opt/ir.h @@ -246,6 +246,7 @@ typedef struct Block { u64* live_out; u64* live_use; u64* live_def; + u64** live_after; /* live set after each instruction, length ninsts */ } Block; typedef enum OptAllocKind { @@ -258,7 +259,12 @@ typedef struct OptValInfo { u32 first_pos; u32 last_pos; u32 live_length; - u32 frequency; + u32 frequency; /* legacy aggregate priority score */ + u32 use_freq; + u32 def_freq; + u32 live_block_freq; + u32 live_across_call_freq; + u32 spill_cost; i32 tied_hard_reg; /* -1 means no fixed/tied physical register need. */ Reg hard_reg; FrameSlot spill_slot; @@ -314,8 +320,10 @@ typedef struct Func { u8 opt_has_target; u8 opt_rewritten; u16 opt_live_words; + u16 opt_conflict_words; u32 opt_position_count; OptValInfo* val_info; /* indexed by Val, length nvals */ + u64* val_conflicts; /* nvals x opt_conflict_words bit matrix */ Reg opt_hard_regs[OPT_REG_CLASSES][OPT_MAX_HARD_REGS]; u32 opt_hard_reg_count[OPT_REG_CLASSES]; diff --git a/src/opt/pass_cfg.c b/src/opt/pass_cfg.c @@ -106,6 +106,7 @@ static void prune_unreachable(Func* f, const u8* reachable) { bl->succ[0] = 0; bl->succ[1] = 0; bl->nsucc = 0; + bl->live_after = NULL; } u32 w = 0; diff --git a/src/opt/pass_lower.c b/src/opt/pass_lower.c @@ -172,11 +172,13 @@ static void collect_use_def(Func* f, Inst* in, Operand* op, int is_def, OptValInfo* vi = &f->val_info[v]; if (vi->first_pos == 0 || c->pos < vi->first_pos) vi->first_pos = c->pos; if (c->pos > vi->last_pos) vi->last_pos = c->pos; - vi->frequency += c->freq ? c->freq : 1u; + u32 freq = c->freq ? 
c->freq : 1u; vi->cls = f->val_cls[v]; if (is_def) { + vi->def_freq += freq; bit_set(c->def, v); } else { + vi->use_freq += freq; if (!bit_has(c->def, v)) bit_set(c->use, v); } } @@ -199,6 +201,40 @@ static void collect_bits(Func* f, Inst* in, Operand* op, int is_def, bit_set(c->use, v); } +static u64* conflict_row(Func* f, Val v) { + return f->val_conflicts + ((size_t)v * f->opt_conflict_words); +} + +static void add_conflict(Func* f, Val a, Val b) { + if (a == VAL_NONE || b == VAL_NONE || a == b) return; + if (a >= f->nvals || b >= f->nvals) return; + if (f->val_cls[a] != f->val_cls[b]) return; + bit_set(conflict_row(f, a), b); + bit_set(conflict_row(f, b), a); +} + +static void add_conflicts_with_set(Func* f, Val v, const u64* live, + const u64* skip) { + if (v == VAL_NONE || v >= f->nvals) return; + for (Val other = 1; other < f->nvals; ++other) { + if (skip && bit_has(skip, other)) continue; + if (bit_has(live, other)) add_conflict(f, v, other); + } +} + +static void add_pairwise_conflicts(Func* f, const u64* live) { + for (Val a = 1; a < f->nvals; ++a) { + if (!bit_has(live, a)) continue; + for (Val b = a + 1u; b < f->nvals; ++b) { + if (bit_has(live, b)) add_conflict(f, a, b); + } + } +} + +static void copy_bits(u64* dst, const u64* src, u32 words) { + for (u32 w = 0; w < words; ++w) dst[w] = src[w]; +} + void opt_machinize(Func* f, CGTarget* target) { f->opt_target = target->c->target; f->opt_has_target = 1; @@ -402,15 +438,25 @@ void opt_build_loop_tree(Func* f) { void opt_live_info(Func* f) { f->opt_live_words = (u16)bit_words(f->nvals); + f->opt_conflict_words = f->opt_live_words; f->opt_position_count = 0; u32 words = f->opt_live_words; - if (!f->val_info) f->val_info = arena_zarray(f->arena, OptValInfo, f->nvals); + int new_val_info = f->val_info == NULL; + if (new_val_info) f->val_info = arena_zarray(f->arena, OptValInfo, f->nvals); + f->val_conflicts = + arena_zarray(f->arena, u64, (size_t)f->nvals * f->opt_conflict_words); for (u32 v = 0; v < 
f->nvals; ++v) { + i32 tied = new_val_info ? -1 : f->val_info[v].tied_hard_reg; f->val_info[v].first_pos = 0; f->val_info[v].last_pos = 0; f->val_info[v].live_length = 0; f->val_info[v].frequency = 0; - f->val_info[v].tied_hard_reg = -1; + f->val_info[v].use_freq = 0; + f->val_info[v].def_freq = 0; + f->val_info[v].live_block_freq = 0; + f->val_info[v].live_across_call_freq = 0; + f->val_info[v].spill_cost = 0; + f->val_info[v].tied_hard_reg = tied; f->val_info[v].hard_reg = REG_NONE; f->val_info[v].spill_slot = FRAME_SLOT_NONE; f->val_info[v].alloc_kind = OPT_ALLOC_NONE; @@ -422,6 +468,7 @@ void opt_live_info(Func* f) { bl->live_out = arena_zarray(f->arena, u64, words); bl->live_use = arena_zarray(f->arena, u64, words); bl->live_def = arena_zarray(f->arena, u64, words); + bl->live_after = NULL; UseDefCtx c; memset(&c, 0, sizeof c); c.use = bl->live_use; @@ -459,14 +506,66 @@ void opt_live_info(Func* f) { if (vi->first_pos == 0) vi->first_pos = 1; if (vi->last_pos < f->opt_position_count) vi->last_pos = f->opt_position_count; - vi->frequency += bl->frequency; + vi->live_block_freq += bl->frequency; + } + } + } + + for (u32 b = 0; b < f->nblocks; ++b) { + Block* bl = &f->blocks[b]; + bl->live_after = + arena_array(f->arena, u64*, bl->ninsts ? 
bl->ninsts : 1u); + u64* live = arena_zarray(f->arena, u64, words); + copy_bits(live, bl->live_out, words); + + add_pairwise_conflicts(f, bl->live_in); + + for (u32 ri = bl->ninsts; ri > 0; --ri) { + u32 i = ri - 1u; + Inst* in = &bl->insts[i]; + u64* after = arena_zarray(f->arena, u64, words); + copy_bits(after, live, words); + bl->live_after[i] = after; + + u64* use = arena_zarray(f->arena, u64, words); + u64* def = arena_zarray(f->arena, u64, words); + BitsCtx bc = {use, def}; + walk_inst_operands(f, in, collect_bits, &bc); + + for (Val v = 1; v < f->nvals; ++v) { + if (bit_has(def, v)) { + add_conflicts_with_set(f, v, after, def); + for (Val other = v + 1u; other < f->nvals; ++other) + if (bit_has(after, v) && bit_has(def, other) && + bit_has(after, other)) + add_conflict(f, v, other); + } + if (bit_has(use, v)) { + add_conflicts_with_set(f, v, after, def); + } } + add_pairwise_conflicts(f, use); + + if ((IROp)in->op == IR_CALL) { + for (Val v = 1; v < f->nvals; ++v) { + if (bit_has(def, v)) continue; + if (bit_has(after, v)) + f->val_info[v].live_across_call_freq += bl->frequency; + } + } + + for (u32 w = 0; w < words; ++w) + live[w] = (live[w] & ~def[w]) | use[w]; } } + for (Val v = 1; v < f->nvals; ++v) { OptValInfo* vi = &f->val_info[v]; if (vi->first_pos && vi->last_pos >= vi->first_pos) vi->live_length = vi->last_pos - vi->first_pos + 1u; + vi->spill_cost = (vi->use_freq * 2u) + vi->def_freq + + vi->live_across_call_freq + vi->live_block_freq; + vi->frequency = vi->spill_cost; } } @@ -482,6 +581,8 @@ static int val_higher_priority(Func* f, Val a, Val b) { int bt = bv->tied_hard_reg >= 0; if (at != bt) return at > bt; if (av->frequency != bv->frequency) return av->frequency > bv->frequency; + if (av->live_across_call_freq != bv->live_across_call_freq) + return av->live_across_call_freq > bv->live_across_call_freq; if (av->live_length != bv->live_length) return av->live_length < bv->live_length; return a < b; } @@ -491,7 +592,11 @@ static int 
hard_conflicts(Func* f, Val* assigned, u32 nassigned, Val v, Reg r) { Val ov = assigned[i]; if (f->val_info[ov].alloc_kind != OPT_ALLOC_HARD) continue; if (f->val_info[ov].hard_reg != r) continue; - if (ranges_overlap(&f->val_info[ov], &f->val_info[v])) return 1; + if (f->val_conflicts && f->opt_conflict_words) { + if (bit_has(conflict_row(f, v), ov)) return 1; + } else if (ranges_overlap(&f->val_info[ov], &f->val_info[v])) { + return 1; + } } return 0; } @@ -625,29 +730,10 @@ static void rewrite_one_operand(Func* f, Inst* owner, Operand* op, int is_def, } } -static u64** compute_block_live_after(Func* f, Block* bl) { - u32 words = f->opt_live_words; - u64** live_after = arena_array(f->arena, u64*, bl->ninsts ? bl->ninsts : 1u); - u64* live = arena_zarray(f->arena, u64, words); - for (u32 w = 0; w < words; ++w) live[w] = bl->live_out[w]; - for (u32 ri = bl->ninsts; ri > 0; --ri) { - u32 i = ri - 1u; - live_after[i] = arena_zarray(f->arena, u64, words); - for (u32 w = 0; w < words; ++w) live_after[i][w] = live[w]; - u64* use = arena_zarray(f->arena, u64, words); - u64* def = arena_zarray(f->arena, u64, words); - BitsCtx bc = {use, def}; - walk_inst_operands(f, &bl->insts[i], collect_bits, &bc); - for (u32 w = 0; w < words; ++w) - live[w] = (live[w] & ~def[w]) | use[w]; - } - return live_after; -} - static void rewrite_func(Func* f) { for (u32 b = 0; b < f->nblocks; ++b) { Block* bl = &f->blocks[b]; - u64** live_after = compute_block_live_after(f, bl); + u64** live_after = bl->live_after; RewriteList out; memset(&out, 0, sizeof out); diff --git a/test/opt/opt_test.c b/test/opt/opt_test.c @@ -128,6 +128,10 @@ static Val add_val(Func* f, CfreeCgTypeId ty) { return ir_alloc_val(f, ty, RC_INT); } +static Val add_val_cls(Func* f, CfreeCgTypeId ty, RegClass cls) { + return ir_alloc_val(f, ty, cls); +} + static Inst* emit_load_imm(Func* f, u32 b, Val dst, CfreeCgTypeId ty, i64 imm) { Inst* in = ir_emit(f, b, IR_LOAD_IMM); @@ -211,6 +215,11 @@ static int bit_has(const u64* 
bits, Val v) { return (bits[v / 64u] & (1ull << (v % 64u))) != 0; } +static int val_conflicts(const Func* f, Val a, Val b) { + const u64* bits = f->val_conflicts + ((size_t)a * f->opt_conflict_words); + return bit_has(bits, b); +} + static int count_op(Func* f, IROp op) { int n = 0; for (u32 b = 0; b < f->nblocks; ++b) @@ -384,6 +393,206 @@ static void opt_liveness_branch(void) { tc_fini(&tc); } +static void opt_live_after_linear(void) { + TestCtx tc; + tc_init(&tc); + Func* f = new_func(&tc); + u32 b = f->entry; + Val a = add_val(f, tc.i32); + Val bv = add_val(f, tc.i32); + Val c = add_val(f, tc.i32); + + emit_load_imm(f, b, a, tc.i32, 1); + u32 ia = f->blocks[b].ninsts - 1u; + emit_load_imm(f, b, bv, tc.i32, 2); + u32 ib = f->blocks[b].ninsts - 1u; + emit_binop(f, b, c, a, bv, tc.i32); + u32 ic = f->blocks[b].ninsts - 1u; + emit_ret_val(f, b, c, tc.i32); + u32 ir = f->blocks[b].ninsts - 1u; + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(f->blocks[b].live_after != NULL, + "live_info should record per-instruction live_after sets"); + EXPECT(bit_has(f->blocks[b].live_after[ia], a), + "a should be live after its def"); + EXPECT(!bit_has(f->blocks[b].live_after[ia], bv), + "b should not be live before its def"); + EXPECT(bit_has(f->blocks[b].live_after[ib], a), + "a should be live after b's def"); + EXPECT(bit_has(f->blocks[b].live_after[ib], bv), + "b should be live after its def"); + EXPECT(!bit_has(f->blocks[b].live_after[ic], a), + "a should die after add"); + EXPECT(!bit_has(f->blocks[b].live_after[ic], bv), + "b should die after add"); + EXPECT(bit_has(f->blocks[b].live_after[ic], c), + "c should be live after add"); + EXPECT(!bit_has(f->blocks[b].live_after[ir], c), + "return should consume c"); + tc_fini(&tc); +} + +static void opt_interference_branch_disjoint(void) { + TestCtx tc; + tc_init(&tc); + MockCGTarget mock; + mock_init(&mock, tc.c); + static const Reg pool[] = {19}; + static const Reg scratch[] = {9, 10}; + 
mock_set_pool(&mock, RC_INT, pool, 1, scratch, 2, 0); + + Func* f = new_func(&tc); + opt_machinize(f, &mock.base); + u32 entry = f->entry; + u32 then_b = ir_block_new(f); + u32 else_b = ir_block_new(f); + ir_note_emit(f, then_b); + ir_note_emit(f, else_b); + Val a = add_val(f, tc.i32); + Val b = add_val(f, tc.i32); + + emit_test_branch(f, entry, then_b, else_b, tc.i32); + emit_load_imm(f, then_b, a, tc.i32, 11); + emit_ret_val(f, then_b, a, tc.i32); + emit_load_imm(f, else_b, b, tc.i32, 22); + emit_ret_val(f, else_b, b, tc.i32); + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(!val_conflicts(f, a, b), + "branch-local values v%u and v%u should not conflict", a, b); + + opt_regalloc(f, 0); + EXPECT(f->val_info[a].alloc_kind == OPT_ALLOC_HARD, + "then value should get the one hard register"); + EXPECT(f->val_info[b].alloc_kind == OPT_ALLOC_HARD, + "else value should share the one hard register"); + EXPECT(f->val_info[a].hard_reg == f->val_info[b].hard_reg, + "disjoint branch values should share a hard register"); + tc_fini(&tc); +} + +static void opt_interference_def_live_out(void) { + TestCtx tc; + tc_init(&tc); + Func* f = new_func(&tc); + u32 b = f->entry; + Val a = add_val(f, tc.i32); + Val bv = add_val(f, tc.i32); + Val c = add_val(f, tc.i32); + + emit_load_imm(f, b, a, tc.i32, 1); + emit_load_imm(f, b, bv, tc.i32, 2); + emit_binop(f, b, c, a, bv, tc.i32); + emit_ret_val(f, b, c, tc.i32); + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(val_conflicts(f, a, bv), + "a and b should conflict while both are live"); + EXPECT(val_conflicts(f, bv, a), + "conflicts should be symmetric"); + EXPECT(!val_conflicts(f, a, c), + "a should not conflict with c after c's def kills a"); + EXPECT(!val_conflicts(f, bv, c), + "b should not conflict with c after c's def kills b"); + tc_fini(&tc); +} + +static void opt_loop_frequency_weights_live_info(void) { + TestCtx tc; + tc_init(&tc); + Func* f = new_func(&tc); + u32 entry 
= f->entry; + u32 header = ir_block_new(f); + u32 body = ir_block_new(f); + u32 exit = ir_block_new(f); + ir_note_emit(f, header); + ir_note_emit(f, body); + ir_note_emit(f, exit); + Val loop_v = add_val(f, tc.i32); + Val exit_v = add_val(f, tc.i32); + Val out = add_val(f, tc.i32); + + emit_load_imm(f, entry, loop_v, tc.i32, 1); + emit_br_to(f, entry, header); + emit_test_branch(f, header, body, exit, tc.i32); + emit_binop(f, body, out, loop_v, loop_v, tc.i32); + emit_br_to(f, body, header); + emit_load_imm(f, exit, exit_v, tc.i32, 2); + emit_ret_val(f, exit, exit_v, tc.i32); + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(f->val_info[loop_v].use_freq > f->val_info[exit_v].use_freq, + "loop-used value should have higher weighted use frequency"); + EXPECT(f->val_info[loop_v].spill_cost > f->val_info[exit_v].spill_cost, + "loop-used value should have higher spill cost"); + tc_fini(&tc); +} + +static void opt_live_across_call_frequency(void) { + TestCtx tc; + tc_init(&tc); + Func* f = new_func(&tc); + u32 b = f->entry; + Val live = add_val(f, tc.i32); + Val dead = add_val(f, tc.i32); + + emit_load_imm(f, b, live, tc.i32, 11); + emit_load_imm(f, b, dead, tc.i32, 12); + emit_call_void(f, b); + emit_ret_val(f, b, live, tc.i32); + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(f->val_info[live].live_across_call_freq > 0, + "live value should be marked live across call"); + EXPECT(f->val_info[dead].live_across_call_freq == 0, + "dead value should not be marked live across call"); + tc_fini(&tc); +} + +static void opt_conflict_symmetry_and_class(void) { + TestCtx tc; + tc_init(&tc); + Func* f = new_func(&tc); + u32 b = f->entry; + Val i0 = add_val(f, tc.i32); + Val i1 = add_val(f, tc.i32); + Val fp = add_val_cls(f, tc.i64, RC_FP); + + emit_load_imm(f, b, i0, tc.i32, 1); + emit_load_imm(f, b, i1, tc.i32, 2); + emit_load_imm(f, b, fp, tc.i64, 3); + emit_binop(f, b, i1, i0, i1, tc.i32); + emit_ret_val(f, b, i1, 
tc.i32); + + opt_build_cfg(f); + opt_build_loop_tree(f); + opt_live_info(f); + + EXPECT(val_conflicts(f, i0, i1), "int values should conflict"); + EXPECT(val_conflicts(f, i1, i0), "int conflict should be symmetric"); + EXPECT(!val_conflicts(f, i0, fp), + "different register classes should not conflict"); + EXPECT(!val_conflicts(f, fp, i0), + "different-class conflict should be symmetric false"); + tc_fini(&tc); +} + static void opt_cfg_prunes_unreachable(void) { TestCtx tc; tc_init(&tc); @@ -960,6 +1169,12 @@ int main(void) { opt_loop_tree_nested_depths(); opt_loop_tree_does_not_mutate_cfg(); opt_liveness_branch(); + opt_live_after_linear(); + opt_interference_branch_disjoint(); + opt_interference_def_live_out(); + opt_loop_frequency_weights_live_info(); + opt_live_across_call_frequency(); + opt_conflict_symmetry_and_class(); opt_regalloc_priority(); opt_rewrite_spill_use_def(); opt_call_clobber_preservation();