riscv64.py (10206B)
# riscv64 backend for the P1 portable assembler. Maps P1 register names to
# native RISC-V register numbers and encodes each P1 row kind as hex
# machine-code text using the shared `le32` / `round_up` helpers.
from common import (
    AddI,
    ArchDef,
    BranchReg,
    CondB,
    CondBZ,
    Enter,
    La,
    LaBr,
    LdArg,
    Li,
    LogI,
    Mem,
    Mov,
    Nullary,
    Rrr,
    ShiftI,
    le32,
    round_up,
)


# P1 register name -> native RISC-V register number (x0..x31).
# P1 names do not always match the native ABI names: P1 's0' is x9 and 'fp'
# is x8, for example. 'br' holds branch/call targets, 'scratch' is used by
# LDARG, and 'save0'..'save2' preserve P1 a1..a3 across SYSCALL.
# 'save2' and 'a6' share x16; the SYSCALL sequence only loads a0..a5 and a7,
# so it never needs both at once.
NAT = {
    'a0': 10,
    'a1': 11,
    'a2': 12,
    'a3': 13,
    'a4': 14,
    'a5': 15,
    'a6': 16,
    'a7': 17,
    't0': 5,
    't1': 6,
    't2': 7,
    's0': 9,
    's1': 18,
    's2': 19,
    's3': 20,
    'sp': 2,
    'zero': 0,
    'ra': 1,
    'fp': 8,
    'br': 31,
    'scratch': 30,
    'save0': 29,
    'save1': 28,
    'save2': 16,
}


# R-type base words (funct7 | funct3 | opcode) for register-register ops.
# SHL/SHR/SAR are native SLL/SRL/SRA; MUL/DIV/REM come from the M extension.
RRR_BASE = {
    'ADD': 0x00000033,
    'SUB': 0x40000033,
    'AND': 0x00007033,
    'OR': 0x00006033,
    'XOR': 0x00004033,
    'SHL': 0x00001033,
    'SHR': 0x00005033,
    'SAR': 0x40005033,
    'MUL': 0x02000033,
    'DIV': 0x02004033,
    'REM': 0x02006033,
}


# Inverted-condition B-type opcodes for the skip-taken-over-jalr pattern:
# the skip fires when the P1 condition is FALSE, so the jalr below is the
# taken target.
CONDB_INV_BASE = {
    'BEQ': 0x00001063,   # native BNE -- skip when not equal
    'BNE': 0x00000063,   # native BEQ -- skip when equal
    'BLT': 0x00005063,   # native BGE -- skip when ra >= rb (signed)
    'BLTU': 0x00007063,  # native BGEU -- skip when ra >= rb (unsigned)
}


# Same pattern for compare-against-zero branches; the second operand is x0.
CONDBZ_INV_BASE = {
    'BEQZ': 0x00001063,  # native BNE ra, zero -- skip when ra != 0
    'BNEZ': 0x00000063,  # native BEQ ra, zero -- skip when ra == 0
    'BLTZ': 0x00005063,  # native BGE ra, zero -- skip when ra >= 0
}


# Linux riscv64 (asm-generic) syscall numbers.
SYSCALL_NUMBERS = {
    'SYS_READ': 63,
    'SYS_WRITE': 64,
    'SYS_CLOSE': 57,
    'SYS_OPENAT': 56,
    'SYS_EXIT': 93,
    'SYS_CLONE': 220,
    'SYS_EXECVE': 221,
    'SYS_WAITID': 95,
}


# I-type base words for the logical-immediate ops.
_LOGI_BASE = {
    'ANDI': 0x00007013,
    'ORI': 0x00006013,
}


# Signed range of a 12-bit I/S-type immediate.
_IMM12_MIN = -2048
_IMM12_MAX = 2047


def _imm12_bits(imm12):
    """Return the 12-bit two's-complement field for a signed immediate.

    Raises ValueError when imm12 does not fit in a signed 12-bit field.
    Masking such a value would silently encode a different, wrapped offset
    (e.g. 2056 would become -2040).
    """
    if not _IMM12_MIN <= imm12 <= _IMM12_MAX:
        raise ValueError(f'immediate out of signed 12-bit range: {imm12}')
    return imm12 & 0xFFF


def rv_r_type(base, rd, ra, rb):
    """Encode an R-type instruction: rd <- op(ra, rb)."""
    d = NAT[rd]
    a = NAT[ra]
    b = NAT[rb]
    return le32(base | (b << 20) | (a << 15) | (d << 7))


def rv_i_type(base, rd, ra, imm12):
    """Encode an I-type instruction: rd <- op(ra, imm12).

    Raises ValueError if imm12 is outside [-2048, 2047].
    """
    d = NAT[rd]
    a = NAT[ra]
    imm = _imm12_bits(imm12)
    return le32(base | (imm << 20) | (a << 15) | (d << 7))


def rv_s_type(base, rs, ra, imm12):
    """Encode an S-type store: mem[ra + imm12] <- rs.

    The 12-bit immediate is split: bits 11:5 go to [31:25] and bits 4:0 go
    to [11:7]. Splitting the masked (unsigned) 12-bit pattern matches the
    m1pp encoder, which applies (>> imm 5) to the masked value.
    Raises ValueError if imm12 is outside [-2048, 2047].
    """
    s = NAT[rs]
    a = NAT[ra]
    imm = _imm12_bits(imm12)
    hi = (imm >> 5) & 0x7F
    lo = imm & 0x1F
    return le32(base | (hi << 25) | (s << 20) | (a << 15) | (lo << 7))


def rv_b_type_skip8(base, ra, rb):
    """Encode a B-type branch with a hard-coded +8 displacement.

    imm = 8 is encoded with imm[4:1]=4, imm[11]=0, imm[10:5]=0, imm[12]=0.
    The combined [11:7] field is (imm[4:1] << 1) | imm[11] = 8. A taken
    branch therefore skips exactly one following 4-byte instruction.
    """
    a = NAT[ra]
    b = NAT[rb]
    return le32(base | (b << 20) | (a << 15) | (8 << 7))


def rv_addi(rd, ra, imm12):
    return rv_i_type(0x00000013, rd, ra, imm12)


def rv_ld(rd, ra, imm12):
    return rv_i_type(0x00003003, rd, ra, imm12)


def rv_sd(rs, ra, imm12):
    return rv_s_type(0x00003023, rs, ra, imm12)


def rv_lbu(rd, ra, imm12):
    return rv_i_type(0x00004003, rd, ra, imm12)


def rv_sb(rs, ra, imm12):
    return rv_s_type(0x00000023, rs, ra, imm12)


def rv_lwu(rd, ra, imm12):
    return rv_i_type(0x00006003, rd, ra, imm12)


def rv_mov_rr(dst, src):
    """Register move, encoded as `addi dst, src, 0`."""
    return rv_addi(dst, src, 0)


def _rv_shift_imm(base, rd, ra, shamt):
    """Encode an RV64 shift-by-immediate (6-bit shamt in bits 25:20).

    Raises ValueError if shamt is outside 0..63. Masking it would silently
    shift by (shamt mod 64).
    """
    d = NAT[rd]
    a = NAT[ra]
    if not 0 <= shamt <= 63:
        raise ValueError(f'shift amount out of range 0..63: {shamt}')
    return le32(base | (shamt << 20) | (a << 15) | (d << 7))


def rv_slli(rd, ra, shamt):
    return _rv_shift_imm(0x00001013, rd, ra, shamt)


def rv_srli(rd, ra, shamt):
    return _rv_shift_imm(0x00005013, rd, ra, shamt)


def rv_srai(rd, ra, shamt):
    return _rv_shift_imm(0x40005013, rd, ra, shamt)


def rv_jalr(rd, rs, imm12):
    """Encode `jalr rd, imm12(rs)`; raises ValueError for out-of-range imm12."""
    d = NAT[rd]
    s = NAT[rs]
    imm = _imm12_bits(imm12)
    return le32(0x00000067 | (imm << 20) | (s << 15) | (d << 7))


def rv_ecall():
    return le32(0x00000073)


def rv_lit64_prefix(rd):
    """Return the prefix that loads a following 8-byte inline literal into rd.

    Sequence: auipc rd, 0 ; ld rd, 12(rd) ; jal x0, +12. The 8 bytes that
    follow in source become the literal at auipc+12, and the jal lands just
    past it.
    NOTE(review): the literal's 8-byte alignment depends on where this
    prefix is placed; a misaligned `ld` may be slow or trap on some
    platforms -- confirm placement guarantees.
    """
    d = NAT[rd]
    auipc = 0x00000017 | (d << 7)
    ld = 0x00C03003 | (d << 15) | (d << 7)
    jal = 0x00C0006F
    return le32(auipc) + le32(ld) + le32(jal)


def rv_lit32_prefix(rd):
    """Return the prefix that loads a following 4-byte inline literal into rd.

    Sequence: auipc rd, 0 ; lwu rd, 12(rd) ; jal x0, +8. `lwu` zero-extends
    the 4-byte literal, which is enough for stage0 addresses.
    """
    d = NAT[rd]
    auipc = 0x00000017 | (d << 7)
    lwu = 0x00C06003 | (d << 15) | (d << 7)
    jal = 0x0080006F
    return le32(auipc) + le32(lwu) + le32(jal)


def rv_epilogue():
    """Frame teardown shared by ERET, TAIL and TAILR.

    Mirrors p1_eret/p1_tail in P1-riscv64.M1pp. It loads the saved ra from
    [sp+0] and the saved caller sp from [sp+8] into fp, then moves fp into
    sp. The caller appends the actual jalr.
    """
    return rv_ld('ra', 'sp', 0) + rv_ld('fp', 'sp', 8) + rv_mov_rr('sp', 'fp')


def encode_li(_arch, row):
    # 64-bit immediate: the literal itself follows in the source stream.
    return rv_lit64_prefix(row.rd)


def encode_la(_arch, row):
    # Address load: a 4-byte zero-extended literal follows.
    return rv_lit32_prefix(row.rd)


def encode_labr(_arch, _row):
    # Load a branch/call target into the dedicated 'br' register.
    return rv_lit32_prefix('br')


def encode_mov(_arch, row):
    # Portable sp is the frame-local base, which sits 16 bytes above
    # native sp (the backend's 2-word hidden header occupies the low end
    # of each frame). MOV rd, sp must therefore yield native_sp+16.
    # NOTE(review): MOV sp, rs is not adjusted by -16 -- confirm P1 never
    # writes sp through MOV.
    if row.rs == 'sp':
        return rv_addi(row.rd, 'sp', 16)
    return rv_mov_rr(row.rd, row.rs)


def encode_rrr(_arch, row):
    return rv_r_type(RRR_BASE[row.op], row.rd, row.ra, row.rb)


def encode_addi(_arch, row):
    # NOTE(review): unlike MOV and LD/ST, no +16 portable-sp adjustment is
    # applied when ra is 'sp' -- confirm P1 disallows ADDI rd, sp, imm with
    # rd != sp.
    return rv_addi(row.rd, row.ra, row.imm)


def encode_logi(_arch, row):
    return rv_i_type(_LOGI_BASE[row.op], row.rd, row.ra, row.imm)


def encode_shifti(_arch, row):
    if row.op == 'SHLI':
        return rv_slli(row.rd, row.ra, row.imm)
    if row.op == 'SHRI':
        return rv_srli(row.rd, row.ra, row.imm)
    if row.op == 'SARI':
        return rv_srai(row.rd, row.ra, row.imm)
    raise ValueError(f'unknown shift op: {row.op}')


def encode_mem(_arch, row):
    # Portable sp points to the frame-local base; the 2-word hidden header
    # at native_sp+0/+8 is not portable-addressable. Shift sp-relative
    # offsets past the header. The shifted offset must still fit in a
    # signed 12-bit immediate, or the encoders raise ValueError.
    off = row.off + 16 if row.rn == 'sp' else row.off
    if row.op == 'LD':
        return rv_ld(row.rt, row.rn, off)
    if row.op == 'ST':
        return rv_sd(row.rt, row.rn, off)
    if row.op == 'LB':
        # Byte loads zero-extend (lbu).
        return rv_lbu(row.rt, row.rn, off)
    if row.op == 'SB':
        return rv_sb(row.rt, row.rn, off)
    raise ValueError(f'unknown mem op: {row.op}')


def encode_ldarg(_arch, row):
    # LDARG loads the saved caller sp from [sp+8] (the hidden header slot),
    # then indexes the incoming stack-arg area off it. Slot 0 is at
    # caller_sp+16. riscv64's call instruction does not push a return
    # address, so this +16 is not forced by hardware; it mirrors the
    # aarch64 layout for stage0 frame uniformity.
    return rv_ld('scratch', 'sp', 8) + rv_ld(row.rd, 'scratch', 16 + 8 * row.slot)


def encode_branch_reg(_arch, row):
    if row.kind == 'BR':
        return rv_jalr('zero', row.rs, 0)
    if row.kind == 'CALLR':
        return rv_jalr('ra', row.rs, 0)
    if row.kind == 'TAILR':
        return rv_epilogue() + rv_jalr('zero', row.rs, 0)
    raise ValueError(f'unknown branch-reg kind: {row.kind}')


def encode_condb(_arch, row):
    # Inverted branch skips the jalr when the P1 condition is false.
    return rv_b_type_skip8(CONDB_INV_BASE[row.op], row.ra, row.rb) + rv_jalr('zero', 'br', 0)


def encode_condbz(_arch, row):
    # Same skip pattern, comparing ra against x0.
    return rv_b_type_skip8(CONDBZ_INV_BASE[row.op], row.ra, 'zero') + rv_jalr('zero', 'br', 0)


def encode_enter(arch, row):
    # Frame layout after ENTER (native sp):
    #   [sp+0] saved ra
    #   [sp+8] caller's sp (also left in fp)
    #   [sp+16 ..] portable locals
    # frame_bytes must keep both addi immediates in signed 12-bit range.
    frame_bytes = round_up(arch.stack_align, 2 * arch.word_bytes + row.size)
    return (
        rv_addi('sp', 'sp', -frame_bytes)
        + rv_sd('ra', 'sp', 0)
        + rv_addi('fp', 'sp', frame_bytes)
        + rv_sd('fp', 'sp', 8)
    )


def encode_nullary(_arch, row):
    if row.kind == 'B':
        return rv_jalr('zero', 'br', 0)
    if row.kind == 'CALL':
        return rv_jalr('ra', 'br', 0)
    if row.kind == 'RET':
        return rv_jalr('zero', 'ra', 0)
    if row.kind == 'ERET':
        return rv_epilogue() + rv_jalr('zero', 'ra', 0)
    if row.kind == 'TAIL':
        return rv_epilogue() + rv_jalr('zero', 'br', 0)
    if row.kind == 'SYSCALL':
        # P1: a0=number, a1..a3,t0,s0,s1 = args 0..5.
        # Linux riscv64: a7=number, a0..a5 = args 0..5, return in a0.
        # SYSCALL clobbers only P1 a0; restore a1/a2/a3 after ecall. The
        # sequence also overwrites native a4, a5, a7 and save0..save2.
        return ''.join([
            rv_mov_rr('save0', 'a1'),
            rv_mov_rr('save1', 'a2'),
            rv_mov_rr('save2', 'a3'),
            rv_mov_rr('a7', 'a0'),
            rv_mov_rr('a0', 'save0'),
            rv_mov_rr('a1', 'save1'),
            rv_mov_rr('a2', 'save2'),
            rv_mov_rr('a3', 't0'),
            rv_mov_rr('a4', 's0'),
            rv_mov_rr('a5', 's1'),
            rv_ecall(),
            rv_mov_rr('a1', 'save0'),
            rv_mov_rr('a2', 'save1'),
            rv_mov_rr('a3', 'save2'),
        ])
    raise ValueError(f'unknown nullary kind: {row.kind}')


def rv_start_stub():
    # Backend-owned :_start stub per docs/P1.md §Program Entry. Linux
    # riscv64 puts argc at [sp] and argv starting at [sp+8]. The stub loads
    # argc into a0, computes &argv[0] into a1, calls p1_main under the
    # one-word direct-result convention, then issues sys_exit with p1_main's
    # result still in a0. Mirrors %p1_entry in p1/P1-riscv64.M1pp.
    #
    # Raw hex outside DEFINE bodies must be single-quoted so bootstrap
    # M0 treats it as a literal byte run rather than a token.
    def q(hex_bytes):
        return f"'{hex_bytes}'"
    return [
        ':_start',
        q(rv_ld('a0', 'sp', 0)),
        q(rv_addi('a1', 'sp', 8)),
        q(rv_lit32_prefix('br')),
        '&p1_main',
        q(rv_jalr('ra', 'br', 0)),
        q(rv_addi('a7', 'zero', SYSCALL_NUMBERS['SYS_EXIT'])),
        q(rv_ecall()),
    ]


ENCODERS = {
    Li: encode_li,
    La: encode_la,
    LaBr: encode_labr,
    Mov: encode_mov,
    Rrr: encode_rrr,
    AddI: encode_addi,
    LogI: encode_logi,
    ShiftI: encode_shifti,
    Mem: encode_mem,
    LdArg: encode_ldarg,
    Nullary: encode_nullary,
    BranchReg: encode_branch_reg,
    CondB: encode_condb,
    CondBZ: encode_condbz,
    Enter: encode_enter,
}


ARCH = ArchDef(
    name='riscv64',
    word_bytes=8,
    stack_align=16,
    syscall_numbers=SYSCALL_NUMBERS,
    encoders=ENCODERS,
    start_stub=rv_start_stub,
)