boot2

Playing with the bootstrap
git clone https://git.ryansepassi.com/git/boot2.git
Log | Files | Refs | README

riscv64.py (10206B)


      1 from common import (
      2     AddI,
      3     ArchDef,
      4     BranchReg,
      5     CondB,
      6     CondBZ,
      7     Enter,
      8     La,
      9     LaBr,
     10     LdArg,
     11     Li,
     12     LogI,
     13     Mem,
     14     Mov,
     15     Nullary,
     16     Rrr,
     17     ShiftI,
     18     le32,
     19     round_up,
     20 )
     21 
     22 
     23 NAT = {
     24     'a0': 10,
     25     'a1': 11,
     26     'a2': 12,
     27     'a3': 13,
     28     'a4': 14,
     29     'a5': 15,
     30     'a6': 16,
     31     'a7': 17,
     32     't0': 5,
     33     't1': 6,
     34     't2': 7,
     35     's0': 9,
     36     's1': 18,
     37     's2': 19,
     38     's3': 20,
     39     'sp': 2,
     40     'zero': 0,
     41     'ra': 1,
     42     'fp': 8,
     43     'br': 31,
     44     'scratch': 30,
     45     'save0': 29,
     46     'save1': 28,
     47     'save2': 16,
     48 }
     49 
     50 
     51 RRR_BASE = {
     52     'ADD': 0x00000033,
     53     'SUB': 0x40000033,
     54     'AND': 0x00007033,
     55     'OR':  0x00006033,
     56     'XOR': 0x00004033,
     57     'SHL': 0x00001033,
     58     'SHR': 0x00005033,
     59     'SAR': 0x40005033,
     60     'MUL': 0x02000033,
     61     'DIV': 0x02004033,
     62     'REM': 0x02006033,
     63 }
     64 
     65 
     66 # Inverted-condition B-type opcodes for the skip-taken-over-jalr pattern:
     67 # the skip fires when the P1 condition is FALSE, so the jalr below is the
     68 # taken target.
     69 CONDB_INV_BASE = {
     70     'BEQ':  0x00001063,  # native BNE -- skip when not equal
     71     'BNE':  0x00000063,  # native BEQ -- skip when equal
     72     'BLT':  0x00005063,  # native BGE -- skip when ra >= rb (signed)
     73     'BLTU': 0x00007063,  # native BGEU -- skip when ra >= rb (unsigned)
     74 }
     75 
     76 
     77 CONDBZ_INV_BASE = {
     78     'BEQZ': 0x00001063,
     79     'BNEZ': 0x00000063,
     80     'BLTZ': 0x00005063,
     81 }
     82 
     83 
     84 SYSCALL_NUMBERS = {
     85     'SYS_READ': 63,
     86     'SYS_WRITE': 64,
     87     'SYS_CLOSE': 57,
     88     'SYS_OPENAT': 56,
     89     'SYS_EXIT': 93,
     90     'SYS_CLONE': 220,
     91     'SYS_EXECVE': 221,
     92     'SYS_WAITID': 95,
     93 }
     94 
     95 
     96 def rv_r_type(base, rd, ra, rb):
     97     d = NAT[rd]
     98     a = NAT[ra]
     99     b = NAT[rb]
    100     return le32(base | (b << 20) | (a << 15) | (d << 7))
    101 
    102 
    103 def rv_i_type(base, rd, ra, imm12):
    104     d = NAT[rd]
    105     a = NAT[ra]
    106     return le32(base | ((imm12 & 0xFFF) << 20) | (a << 15) | (d << 7))
    107 
    108 
    109 def rv_s_type(base, rs, ra, imm12):
    110     s = NAT[rs]
    111     a = NAT[ra]
    112     imm = imm12 & 0xFFF
    113     # arithmetic-shift the 12-bit signed value: bits 11:5 -> [31:25],
    114     # bits 4:0 -> [11:7]. We only need the unsigned 12-bit pattern here
    115     # because the m1pp encoder uses (>> imm 5) on the masked value.
    116     hi = (imm >> 5) & 0x7F
    117     lo = imm & 0x1F
    118     return le32(base | (hi << 25) | (s << 20) | (a << 15) | (lo << 7))
    119 
    120 
    121 def rv_b_type_skip8(base, ra, rb):
    122     # Hardcoded +8 branch: imm = 8, encoded with imm[4:1]=4, imm[11]=0,
    123     # imm[10:5]=0, imm[12]=0. The combined [11:7] field becomes
    124     # (imm[4:1] << 1) | imm[11] = 8.
    125     a = NAT[ra]
    126     b = NAT[rb]
    127     return le32(base | (b << 20) | (a << 15) | (8 << 7))
    128 
    129 
    130 def rv_addi(rd, ra, imm12):
    131     return rv_i_type(0x00000013, rd, ra, imm12)
    132 
    133 
    134 def rv_ld(rd, ra, imm12):
    135     return rv_i_type(0x00003003, rd, ra, imm12)
    136 
    137 
    138 def rv_sd(rs, ra, imm12):
    139     return rv_s_type(0x00003023, rs, ra, imm12)
    140 
    141 
    142 def rv_lbu(rd, ra, imm12):
    143     return rv_i_type(0x00004003, rd, ra, imm12)
    144 
    145 
    146 def rv_sb(rs, ra, imm12):
    147     return rv_s_type(0x00000023, rs, ra, imm12)
    148 
    149 
    150 def rv_lwu(rd, ra, imm12):
    151     return rv_i_type(0x00006003, rd, ra, imm12)
    152 
    153 
    154 def rv_mov_rr(dst, src):
    155     return rv_addi(dst, src, 0)
    156 
    157 
    158 def rv_slli(rd, ra, shamt):
    159     d = NAT[rd]
    160     a = NAT[ra]
    161     return le32(0x00001013 | ((shamt & 0x3F) << 20) | (a << 15) | (d << 7))
    162 
    163 
    164 def rv_srli(rd, ra, shamt):
    165     d = NAT[rd]
    166     a = NAT[ra]
    167     return le32(0x00005013 | ((shamt & 0x3F) << 20) | (a << 15) | (d << 7))
    168 
    169 
    170 def rv_srai(rd, ra, shamt):
    171     d = NAT[rd]
    172     a = NAT[ra]
    173     return le32(0x40005013 | ((shamt & 0x3F) << 20) | (a << 15) | (d << 7))
    174 
    175 
    176 def rv_jalr(rd, rs, imm12):
    177     d = NAT[rd]
    178     s = NAT[rs]
    179     return le32(0x00000067 | ((imm12 & 0xFFF) << 20) | (s << 15) | (d << 7))
    180 
    181 
    182 def rv_ecall():
    183     return le32(0x00000073)
    184 
    185 
    186 def rv_lit64_prefix(rd):
    187     # auipc rd, 0 ; ld rd, 12(rd) ; jal x0, +12.
    188     # The 8 bytes that follow in source become the literal.
    189     d = NAT[rd]
    190     auipc = 0x00000017 | (d << 7)
    191     ld = 0x00C03003 | (d << 15) | (d << 7)
    192     jal = 0x00C0006F
    193     return le32(auipc) + le32(ld) + le32(jal)
    194 
    195 
    196 def rv_lit32_prefix(rd):
    197     # auipc rd, 0 ; lwu rd, 12(rd) ; jal x0, +8.
    198     # lwu zero-extends a 4-byte literal; enough for stage0 addresses.
    199     d = NAT[rd]
    200     auipc = 0x00000017 | (d << 7)
    201     lwu = 0x00C06003 | (d << 15) | (d << 7)
    202     jal = 0x0080006F
    203     return le32(auipc) + le32(lwu) + le32(jal)
    204 
    205 
    206 def rv_epilogue():
    207     # Frame teardown shared by ERET, TAIL, TAILR. Mirrors p1_eret/p1_tail
    208     # in P1-riscv64.M1pp: load saved ra, load saved caller sp into fp,
    209     # then move fp into sp. The caller appends the actual jalr.
    210     return rv_ld('ra', 'sp', 0) + rv_ld('fp', 'sp', 8) + rv_mov_rr('sp', 'fp')
    211 
    212 
    213 def encode_li(_arch, row):
    214     return rv_lit64_prefix(row.rd)
    215 
    216 
    217 def encode_la(_arch, row):
    218     return rv_lit32_prefix(row.rd)
    219 
    220 
    221 def encode_labr(_arch, _row):
    222     return rv_lit32_prefix('br')
    223 
    224 
    225 def encode_mov(_arch, row):
    226     # Portable sp is the frame-local base, which sits 16 bytes above
    227     # native sp (the backend's 2-word hidden header occupies the low
    228     # end of each frame). MOV rd, sp must therefore yield native_sp+16.
    229     if row.rs == 'sp':
    230         return rv_addi(row.rd, 'sp', 16)
    231     return rv_mov_rr(row.rd, row.rs)
    232 
    233 
    234 def encode_rrr(_arch, row):
    235     return rv_r_type(RRR_BASE[row.op], row.rd, row.ra, row.rb)
    236 
    237 
    238 def encode_addi(_arch, row):
    239     return rv_addi(row.rd, row.ra, row.imm)
    240 
    241 
    242 def encode_logi(_arch, row):
    243     base = {
    244         'ANDI': 0x00007013,
    245         'ORI':  0x00006013,
    246     }[row.op]
    247     return rv_i_type(base, row.rd, row.ra, row.imm)
    248 
    249 
    250 def encode_shifti(_arch, row):
    251     if row.op == 'SHLI':
    252         return rv_slli(row.rd, row.ra, row.imm)
    253     if row.op == 'SHRI':
    254         return rv_srli(row.rd, row.ra, row.imm)
    255     if row.op == 'SARI':
    256         return rv_srai(row.rd, row.ra, row.imm)
    257     raise ValueError(f'unknown shift op: {row.op}')
    258 
    259 
    260 def encode_mem(_arch, row):
    261     # Portable sp points to the frame-local base; the 2-word hidden header
    262     # at native_sp+0/+8 is not portable-addressable. Shift sp-relative
    263     # offsets past the header.
    264     off = row.off + 16 if row.rn == 'sp' else row.off
    265     if row.op == 'LD':
    266         return rv_ld(row.rt, row.rn, off)
    267     if row.op == 'ST':
    268         return rv_sd(row.rt, row.rn, off)
    269     if row.op == 'LB':
    270         return rv_lbu(row.rt, row.rn, off)
    271     if row.op == 'SB':
    272         return rv_sb(row.rt, row.rn, off)
    273     raise ValueError(f'unknown mem op: {row.op}')
    274 
    275 
    276 def encode_ldarg(_arch, row):
    277     # LDARG loads the saved caller sp from [sp+8] (the hidden header
    278     # slot), then indexes the incoming stack-arg area off it. Slot 0 is
    279     # at caller_sp+16 because the native call instruction does not push
    280     # a return address on riscv64 -- the +16 matches the aarch64 layout
    281     # by convention for stage0 frame uniformity.
    282     return rv_ld('scratch', 'sp', 8) + rv_ld(row.rd, 'scratch', 16 + 8 * row.slot)
    283 
    284 
    285 def encode_branch_reg(_arch, row):
    286     if row.kind == 'BR':
    287         return rv_jalr('zero', row.rs, 0)
    288     if row.kind == 'CALLR':
    289         return rv_jalr('ra', row.rs, 0)
    290     if row.kind == 'TAILR':
    291         return rv_epilogue() + rv_jalr('zero', row.rs, 0)
    292     raise ValueError(f'unknown branch-reg kind: {row.kind}')
    293 
    294 
    295 def encode_condb(_arch, row):
    296     return rv_b_type_skip8(CONDB_INV_BASE[row.op], row.ra, row.rb) + rv_jalr('zero', 'br', 0)
    297 
    298 
    299 def encode_condbz(_arch, row):
    300     return rv_b_type_skip8(CONDBZ_INV_BASE[row.op], row.ra, 'zero') + rv_jalr('zero', 'br', 0)
    301 
    302 
    303 def encode_enter(arch, row):
    304     frame_bytes = round_up(arch.stack_align, 2 * arch.word_bytes + row.size)
    305     return (
    306         rv_addi('sp', 'sp', -frame_bytes)
    307         + rv_sd('ra', 'sp', 0)
    308         + rv_addi('fp', 'sp', frame_bytes)
    309         + rv_sd('fp', 'sp', 8)
    310     )
    311 
    312 
    313 def encode_nullary(_arch, row):
    314     if row.kind == 'B':
    315         return rv_jalr('zero', 'br', 0)
    316     if row.kind == 'CALL':
    317         return rv_jalr('ra', 'br', 0)
    318     if row.kind == 'RET':
    319         return rv_jalr('zero', 'ra', 0)
    320     if row.kind == 'ERET':
    321         return rv_epilogue() + rv_jalr('zero', 'ra', 0)
    322     if row.kind == 'TAIL':
    323         return rv_epilogue() + rv_jalr('zero', 'br', 0)
    324     if row.kind == 'SYSCALL':
    325         # P1: a0=number, a1..a3,t0,s0,s1 = args 0..5.
    326         # Linux riscv64: a7=number, a0..a5 = args 0..5, return in a0.
    327         # SYSCALL clobbers only P1 a0; restore a1/a2/a3 after ecall.
    328         return ''.join([
    329             rv_mov_rr('save0', 'a1'),
    330             rv_mov_rr('save1', 'a2'),
    331             rv_mov_rr('save2', 'a3'),
    332             rv_mov_rr('a7', 'a0'),
    333             rv_mov_rr('a0', 'save0'),
    334             rv_mov_rr('a1', 'save1'),
    335             rv_mov_rr('a2', 'save2'),
    336             rv_mov_rr('a3', 't0'),
    337             rv_mov_rr('a4', 's0'),
    338             rv_mov_rr('a5', 's1'),
    339             rv_ecall(),
    340             rv_mov_rr('a1', 'save0'),
    341             rv_mov_rr('a2', 'save1'),
    342             rv_mov_rr('a3', 'save2'),
    343         ])
    344     raise ValueError(f'unknown nullary kind: {row.kind}')
    345 
    346 
    347 def rv_start_stub():
    348     # Backend-owned :_start stub per docs/P1.md §Program Entry. Linux
    349     # riscv64 puts argc at [sp] and argv starting at [sp+8]; load argc
    350     # into a0, compute &argv[0] into a1, call p1_main under the one-word
    351     # direct-result convention, then issue sys_exit. Mirrors %p1_entry
    352     # in p1/P1-riscv64.M1pp.
    353     #
    354     # Raw hex outside DEFINE bodies must be single-quoted so bootstrap
    355     # M0 treats it as a literal byte run rather than a token.
    356     def q(hex_bytes):
    357         return f"'{hex_bytes}'"
    358     return [
    359         ':_start',
    360         q(rv_ld('a0', 'sp', 0)),
    361         q(rv_addi('a1', 'sp', 8)),
    362         q(rv_lit32_prefix('br')),
    363         '&p1_main',
    364         q(rv_jalr('ra', 'br', 0)),
    365         q(rv_addi('a7', 'zero', 93)),
    366         q(rv_ecall()),
    367     ]
    368 
    369 
    370 ENCODERS = {
    371     Li: encode_li,
    372     La: encode_la,
    373     LaBr: encode_labr,
    374     Mov: encode_mov,
    375     Rrr: encode_rrr,
    376     AddI: encode_addi,
    377     LogI: encode_logi,
    378     ShiftI: encode_shifti,
    379     Mem: encode_mem,
    380     LdArg: encode_ldarg,
    381     Nullary: encode_nullary,
    382     BranchReg: encode_branch_reg,
    383     CondB: encode_condb,
    384     CondBZ: encode_condbz,
    385     Enter: encode_enter,
    386 }
    387 
    388 
# Backend descriptor for riscv64, bundling the tables and callables defined
# in this module. encode_enter reads word_bytes and stack_align back from
# this object to size frames.
# NOTE(review): how ArchDef consumes the remaining fields is defined in
# `common` and is not visible here.
ARCH = ArchDef(
    name='riscv64',
    word_bytes=8,
    stack_align=16,
    syscall_numbers=SYSCALL_NUMBERS,
    encoders=ENCODERS,
    start_stub=rv_start_stub,
)