#!/usr/bin/env python3 # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT from __future__ import annotations """Two-pass instruction-prefetch offset patcher. Round 1: build with koffset=0 so the compiler emits s_prefetch_inst_pc_rel with placeholder operands. Round 2: assemble the GPU .s via llvm-mc, disassemble with llvm-objdump to get exact hex addresses, compute correct PC-relative koffset/klength, then patch both the .s file and the GPU ELF inside the fat .o via direct binary patching (no recompilation needed, only a relink). If the computed prefetch region has zero in-bounds cachelines, the 8-byte s_prefetch_inst_pc_rel is replaced with 2× 4-byte s_nop 0. Labels are discovered automatically from [ck_prefetch] / [ck_label] comments in the generated .s assembly file — no source path needed. Standalone usage (runs both rounds): python patch_prefetch_offset.py \\ --build-dir /path/to/build \\ --target \\ --objdump-mcpu gfx1201 \\ [--dry-run] CMake PRE_LINK usage (round 1 already done by cmake, only patch the .o): python patch_prefetch_offset.py \\ --build-dir /path/to/build \\ --target \\ --objdump-mcpu gfx1201 \\ --skip-build-round1 """ import argparse import multiprocessing import re import shutil import subprocess import sys from pathlib import Path from typing import NamedTuple # --------------------------------------------------------------------------- # Module-level constants # --------------------------------------------------------------------------- CACHELINE_SIZE = 128 # bytes per instruction cache line KLENGTH_SHIFT = 6 KLENGTH_MASK = 0x7F << KLENGTH_SHIFT # klength occupies bits [12:6] of dw0 KOFFSET_MASK = 0x00FFFFFF # 24-bit signed PC-relative offset in dw1[23:0] S_NOP_ENCODING = 0xBF800000 # s_nop 0 — SOPP opcode 0, simm16=0 NOP_KLENGTH_SENTINEL = -1 # klength sentinel: replace prefetch with 2× s_nop PLACE_MODE_DEFAULT = 0 PLACE_MODE_BLOCK_ENTRY = 1 DIR_FORWARD = "forward" DIR_BACKWARD = "backward" # --------------------------------------------------------------------------- # Module-level regex patterns # --------------------------------------------------------------------------- # Function-header label in .s files FUNC_LABEL_RE = re.compile(r"^([A-Za-z_][A-Za-z0-9_$.]*):\s*(?:;.*)?$") # objdump function header (e.g. "0000000000001000 :") OBJDUMP_FUNC_RE = re.compile(r'^[0-9a-fA-F]+ <(.+?)>:\s*$') # objdump instruction address from trailing comment (e.g. "// 00001000: F4...") OBJDUMP_ADDR_RE = re.compile(r'//\s*([0-9a-fA-F]+):\s+[0-9a-fA-F]') # Block label in .s (e.g. ".LBB1_3:") BLOCK_LABEL_RE = re.compile(r'^\.[A-Za-z_]\w*:') # --------------------------------------------------------------------------- # Structured types for label classification # --------------------------------------------------------------------------- class PrefetchSite(NamedTuple): """A [ck_prefetch] marker in the merged .s ↔ objdump table.""" idx: int # index in merged list direction: str # DIR_FORWARD or DIR_BACKWARD offset_cl: int # cacheline offset from target class TargetSite(NamedTuple): """A [ck_label] marker (INST_PREFETCH_TARGET) in the merged table.""" idx: int # index in merged list mode: int # PLACE_MODE_DEFAULT or PLACE_MODE_BLOCK_ENTRY class _Tee: """Write to both stdout and a log file simultaneously.""" def __init__(self, log_path: Path): self._file = log_path.open("w", encoding="utf-8") self._stdout = sys.stdout def write(self, data: str) -> int: self._stdout.write(data) self._file.write(data) return len(data) def flush(self) -> None: self._stdout.flush() self._file.flush() def close(self) -> None: self._file.close() sys.stdout = self._stdout # --------------------------------------------------------------------------- # Instruction classification # --------------------------------------------------------------------------- def is_asm_instruction(line: str) -> bool: """Return True if the line is an instruction (not a comment/label/directive/blank).""" s = line.strip() if not s: return False if s[0] in (';', '/', '.', '#'): return False if s.split()[0].endswith(':'): return False return True # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def run(cmd: list[str], cwd: Path) -> subprocess.CompletedProcess: print(f"[run] {' '.join(cmd)}", flush=True) result = subprocess.run(cmd, cwd=cwd, text=True, capture_output=True) if result.stdout: print(result.stdout, end="") if result.stderr: print(result.stderr, end="", file=sys.stderr) if result.returncode != 0: sys.exit(f"Command failed with exit code {result.returncode}") return result def cmake_build(build_dir: Path, target: str, jobs: int) -> None: run(["cmake", "--build", str(build_dir), "--target", target, "-j", str(jobs), "--"], build_dir) def find_asm_file(search_dir: Path, cpp_stem: str, gpu_arch: str = "") -> Path: """Find the GPU .s file produced by --save-temps.""" all_candidates = sorted( search_dir.rglob(f"{cpp_stem}*.s"), key=lambda p: p.stat().st_mtime, reverse=True, ) if not all_candidates: sys.exit( f"No .s file matching '{cpp_stem}*.s' found under {search_dir}.\n" "Make sure --save-temps is in the target's compile options." ) def is_gpu(p: Path) -> bool: n = p.name if "-host-" in n: return False if "-hip-" in n: return True if gpu_arch and gpu_arch in n: return True return False gpu = [p for p in all_candidates if is_gpu(p)] chosen = gpu[0] if gpu else all_candidates[0] if not gpu: print(f"[warn] No GPU .s found; falling back to {chosen.name}") return chosen def find_obj_file(build_dir: Path, target: str) -> Path: """Find the most recent .o for the given CMake target. Uses ``**`` under ``{target}.dir/`` so that multi-config generators (e.g. Ninja Multi-Config, Visual Studio) whose object files live in a config subdirectory like ``{target}.dir/Release/`` are also found. """ candidates = sorted( build_dir.rglob(f"{target}.dir/**/*.o"), key=lambda p: p.stat().st_mtime, reverse=True, ) if not candidates: sys.exit( f"No .o file found under '*/{target}.dir/' in {build_dir}.\n" "Check that the target was built before running the patch script." ) return candidates[0] def run_objdump(objdump_path: str, mcpu: str, obj_path: Path) -> str: cmd = [objdump_path, f"--mcpu={mcpu}", "-d", str(obj_path)] print(f"[run] {' '.join(cmd)}", flush=True) result = subprocess.run(cmd, text=True, capture_output=True) if result.returncode != 0: sys.exit(f"objdump failed:\n{result.stderr}") return result.stdout def detect_mcpu_from_asm(asm_text: str) -> str: """Extract the GPU architecture from .amdgcn_target directive in the .s file. Looks for lines like: .amdgcn_target "amdgcn-amd-amdhsa--gfx1201" Returns the gfx* portion (e.g. "gfx1201") or empty string if not found. """ m = re.search(r'\.amdgcn_target\s+"[^"]*--(gfx[0-9a-zA-Z]+)', asm_text) return m.group(1) if m else "" # --------------------------------------------------------------------------- # Label discovery from .s # --------------------------------------------------------------------------- def find_prefetch_labels_from_asm(asm_text: str) -> list[str]: """Return unique label names from [ck_prefetch] comments in the .s file.""" label_re = re.compile(r';\s*\[ck_prefetch\].*\bname\s*=\s*(\w+)') seen: dict[str, None] = {} for line in asm_text.splitlines(): m = label_re.search(line) if m: seen.setdefault(m.group(1), None) return list(seen.keys()) # --------------------------------------------------------------------------- # Assembly / objdump helpers # --------------------------------------------------------------------------- def assemble_gpu_asm(asm_file: Path, mcpu: str, objdump_path: str) -> Path: """Assemble GPU .s → temp .o via llvm-mc. Returns path (caller deletes).""" llvm_mc = str(Path(objdump_path).parent / "llvm-mc") out_obj = asm_file.with_suffix(".ck_tmp_patching.o") run([llvm_mc, f"--mcpu={mcpu}", "--triple=amdgcn-amd-amdhsa", "--filetype=obj", "-o", str(out_obj), str(asm_file)], asm_file.parent) return out_obj def parse_objdump_functions(objdump_text: str) -> dict[str, list[tuple[int, str]]]: """Parse llvm-objdump -d output into per-function (addr, instr_text) lists.""" instr_re = re.compile(r'^\t(.+?)//\s*([0-9a-fA-F]+):\s+[0-9a-fA-F]') result: dict[str, list[tuple[int, str]]] = {} cur_name: str | None = None cur_entries: list[tuple[int, str]] = [] for line in objdump_text.splitlines(): m = OBJDUMP_FUNC_RE.match(line) if m: if cur_name is not None: result[cur_name] = cur_entries cur_name = m.group(1) cur_entries = [] elif cur_name is not None: m2 = instr_re.match(line) if m2: cur_entries.append((int(m2.group(2), 16), m2.group(1).strip())) if cur_name is not None: result[cur_name] = cur_entries return result def split_functions(asm_text: str) -> list[tuple[str, list[str]]]: """Split the .s file into per-function blocks.""" blocks: list[tuple[str, list[str]]] = [] current_name = "" current_lines: list[str] = [] for line in asm_text.splitlines(): m = FUNC_LABEL_RE.match(line) if m: if current_lines: blocks.append((current_name, current_lines)) current_name = m.group(1) current_lines = [line] else: current_lines.append(line) if current_lines: blocks.append((current_name, current_lines)) return blocks # --------------------------------------------------------------------------- # Merge .s ↔ objdump and compute koffsets # --------------------------------------------------------------------------- def _merge_s_and_objdump(s_lines: list[str], obj_entries: list[tuple[int, str]]) -> list[tuple[int | None, str]]: """Pair each .s instruction with its objdump entry by mnemonic matching. For each .s instruction we scan forward in objdump entries (up to MAX_LOOKAHEAD) to find a matching mnemonic. This self-corrects drift from assembler-inserted NOPs or classifier mismatches. .p2align directives advance obj_idx to the next aligned entry. Comment/directive/label lines get addr=None. """ MAX_LOOKAHEAD = 32 p2align_re = re.compile(r'\.p2align\s+(\d+)') merged: list[tuple[int | None, str]] = [] obj_idx = 0 for line in s_lines: m = p2align_re.search(line) if m: if obj_idx < len(obj_entries): align = 1 << int(m.group(1)) while obj_idx < len(obj_entries) and (obj_entries[obj_idx][0] % align) != 0: obj_idx += 1 merged.append((None, line)) continue if is_asm_instruction(line): if obj_idx < len(obj_entries): s_mnem = line.strip().split()[0].lower() for scan in range(obj_idx, min(obj_idx + MAX_LOOKAHEAD, len(obj_entries))): if obj_entries[scan][1].split()[0].lower() == s_mnem: obj_idx = scan break addr = obj_entries[obj_idx][0] obj_idx += 1 else: addr = None merged.append((addr, line)) else: merged.append((None, line)) return merged def _merge_all_functions(asm_text: str, objdump_text: str, dump_dir: Path | None = None ) -> dict[str, list[tuple[int | None, str]]]: """Merge .s ↔ objdump once per function. Returns {funcname: merged_list}. Optionally dumps one file per function (not per label).""" s_blocks = split_functions(asm_text) obj_funcs = parse_objdump_functions(objdump_text) merged_funcs: dict[str, list[tuple[int | None, str]]] = {} for name, s_lines in s_blocks: if name not in obj_funcs: continue merged = _merge_s_and_objdump(s_lines, obj_funcs[name]) merged_funcs[name] = merged if dump_dir is not None: safe = re.sub(r'[^A-Za-z0-9_]', '_', name)[:80] dump_path = dump_dir / f"merged_{safe}.txt" with dump_path.open('w', encoding='utf-8') as fh: for idx, (addr, line) in enumerate(merged): addr_str = f'0x{addr:08x}' if addr is not None else ' ' fh.write(f'[{idx:5d}] {addr_str} {line.rstrip()}\n') print(f"[dump] Merged table written to {dump_path}") return merged_funcs def _resolve_target_address( merged: list[tuple[int | None, str]], tgt_idx: int, tgt_mode: int, name: str, ) -> int | None: """Resolve the target address for a prefetch's INST_PREFETCH_TARGET marker. BLOCK_ENTRY mode (1): scan backward for the nearest block label, then use the first instruction after it. DEFAULT mode (0): use the first instruction after the [ck_label] comment. """ if tgt_mode == PLACE_MODE_BLOCK_ENTRY: block_idx: int | None = None for k in range(tgt_idx - 1, -1, -1): if BLOCK_LABEL_RE.match(merged[k][1].strip()): block_idx = k break scan_from = block_idx if block_idx is not None else tgt_idx target: int | None = None for k in range(scan_from + 1, len(merged)): if merged[k][0] is not None: target = merged[k][0] break if block_idx is not None and target is not None: orig_target: int | None = None for k in range(tgt_idx + 1, len(merged)): if merged[k][0] is not None: orig_target = merged[k][0] break if orig_target is not None and target != orig_target: print(f"[adjust] {name[:60]!r}: BLOCK_ENTRY mode — " f"block label at merged[{block_idx}] " f"→ target 0x{target:x} (was 0x{orig_target:x}, " f"saved {orig_target - target}B)") return target # DEFAULT mode (mode=0): first instruction after [ck_label]. for k in range(tgt_idx + 1, len(merged)): if merged[k][0] is not None: return merged[k][0] return None def _clamp_prefetch_region( name: str, pair_idx: int, pc_next: int, target: int, orig_klength: int, direction: str, offset_cl: int, func_end: int, ) -> tuple[int, int] | None: """Compute (koffset, klength) for one prefetch pair with OOB clamping. *klength* may be ``NOP_KLENGTH_SENTINEL`` if the prefetch is entirely out of bounds. Returns ``None`` if the pair should be skipped entirely (e.g. negative forward koffset). """ target_aligned = target & ~(CACHELINE_SIZE - 1) offset_bytes = offset_cl * CACHELINE_SIZE klength = orig_klength if direction == DIR_BACKWARD: region_end = target_aligned + CACHELINE_SIZE + offset_bytes region_start = region_end - (klength + 1) * CACHELINE_SIZE min_base = (pc_next & ~(CACHELINE_SIZE - 1)) + CACHELINE_SIZE if region_start < min_base: region_start = min_base usable = (region_end - region_start) // CACHELINE_SIZE if usable <= 0: klength = NOP_KLENGTH_SENTINEL print(f"[nop] {name[:60]!r}: backward prefetch fully OOB " f"(min_base 0x{min_base:x} >= region_end 0x{region_end:x}), " f"replacing with 2× s_nop") else: klength = usable - 1 print(f"[clamp] {name[:60]!r}: backward start clamped " f"(first cacheline after pc_next: 0x{min_base:x}), " f"klength {orig_klength} → {klength}") if klength == NOP_KLENGTH_SENTINEL: print(f"[debug] func={name[:60]!r} pair {pair_idx}: " f"pc_next=0x{pc_next:x} dir=backward → NOP (0 cachelines in bounds)") return (0, NOP_KLENGTH_SENTINEL) prefetch_base = region_start koffset = prefetch_base - pc_next print(f"[debug] func={name[:60]!r} pair {pair_idx}: " f"pc_next=0x{pc_next:x} target=0x{target:x} dir=backward " f"offset={offset_cl}cl prefetch_base=0x{prefetch_base:x} " f"koffset=0x{koffset:x} ({koffset}B) klength={klength} " f"region=[0x{region_start:x}, 0x{region_end:x}) " f"({(region_end - region_start)}B = {klength + 1} cachelines)") return (koffset, klength) # ── Forward direction ──────────────────────────────────────────────── prefetch_base = target_aligned + offset_bytes koffset = prefetch_base - pc_next if koffset < 0: print(f"[warn] {name[:60]!r}: negative koffset — target before prefetch, skipping") return None region_end = prefetch_base + (klength + 1) * CACHELINE_SIZE if region_end > func_end: needed = max(0, (func_end - prefetch_base + CACHELINE_SIZE - 1) // CACHELINE_SIZE) if needed == 0: klength = NOP_KLENGTH_SENTINEL print(f"[nop] {name[:60]!r}: forward prefetch fully OOB " f"(prefetch_base 0x{prefetch_base:x} >= func_end 0x{func_end:x}), " f"replacing with 2× s_nop") else: klength = needed - 1 region_end = prefetch_base + (klength + 1) * CACHELINE_SIZE print(f"[clamp] {name[:60]!r}: forward end clamped " f"(func_end 0x{func_end:x}), klength {orig_klength} → {klength}") if klength == NOP_KLENGTH_SENTINEL: print(f"[debug] func={name[:60]!r} pair {pair_idx}: " f"pc_next=0x{pc_next:x} dir=forward → NOP (0 cachelines in bounds)") else: region_start = prefetch_base print(f"[debug] func={name[:60]!r} pair {pair_idx}: " f"pc_next=0x{pc_next:x} target=0x{target:x} dir=forward " f"offset={offset_cl}cl " f"koffset=0x{koffset:x} ({koffset}B) klength={klength} " f"region=[0x{region_start:x}, 0x{region_end:x}) " f"({(region_end - region_start)}B = {klength + 1} cachelines)") return (koffset, klength) def find_best_koffset_hybrid(merged_funcs: dict[str, list[tuple[int | None, str]]], label: str) -> dict[str, list[tuple[int, int]]]: """Compute per-function (koffset, klength) for INST_PREFETCH/INST_PREFETCH_TARGET label pairs. Returns {funcname: [(koffset, klength), ...]} for each function containing the given label. klength is clamped so the prefetch does not extend past the end of the function; if no cachelines are in bounds, klength is set to NOP_KLENGTH_SENTINEL and the prefetch will be replaced with 2× s_nop. """ # [ck_prefetch] marks INST_PREFETCH sites, [ck_label] marks INST_PREFETCH_TARGET targets. prefetch_re = re.compile(rf";\s*\[ck_prefetch\].*\bname\s*=\s*{re.escape(label)}\b") target_re = re.compile(rf";\s*\[ck_label\].*\bname\s*=\s*{re.escape(label)}\b") either_re = re.compile(rf";\s*(?:\[ck_label\]|\[ck_prefetch\]).*\bname\s*=\s*{re.escape(label)}\b") klength_re = re.compile(r's_prefetch_inst_pc_rel\s+\S+\s*,\s*\S+\s*,\s*(\d+)') mode_re = re.compile(r'\bmode\s*=\s*(\d+)') dir_re = re.compile(r'\bdir\s*=\s*(\w+)') offset_re = re.compile(r'\boffset\s*=\s*(-?\d+)') results: dict[str, list[tuple[int, int]]] = {} for name, merged in merged_funcs.items(): if not any(either_re.search(line) for _, line in merged): continue # Determine function end address (for OOB clamping). func_end: int = 0 for addr, _ in reversed(merged): if addr is not None: func_end = addr + 8 # conservative: largest instruction is 8 bytes break # Classify markers from .s comments. prefetch_sites: list[PrefetchSite] = [] target_sites: list[TargetSite] = [] for idx, (_addr, line) in enumerate(merged): if prefetch_re.search(line): m_dir = dir_re.search(line) m_off = offset_re.search(line) prefetch_sites.append(PrefetchSite( idx=idx, direction=m_dir.group(1) if m_dir else DIR_FORWARD, offset_cl=int(m_off.group(1)) if m_off else 0, )) elif target_re.search(line): m_mode = mode_re.search(line) target_sites.append(TargetSite( idx=idx, mode=int(m_mode.group(1)) if m_mode else PLACE_MODE_DEFAULT, )) pairs: list[tuple[int, int]] = [] for pf in prefetch_sites: # Find the s_prefetch_inst_pc_rel instruction and parse its klength. pf_instr_idx: int | None = None orig_klength = 3 # default j = pf.idx + 1 while j < len(merged) and 's_prefetch_inst_pc_rel' not in merged[j][1]: j += 1 if j < len(merged): pf_instr_idx = j m_kl = klength_re.search(merged[j][1]) if m_kl: orig_klength = int(m_kl.group(1)) # pc_next = address of the instruction after s_prefetch_inst_pc_rel. pc_next: int | None = None if pf_instr_idx is not None: for k in range(pf_instr_idx + 1, len(merged)): if merged[k][0] is not None: pc_next = merged[k][0] break if pc_next is None: print(f"[warn] {name[:60]!r}: no pc_next for prefetch at merged[{pf.idx}], skipping") continue # Find the nearest INST_PREFETCH_TARGET after this INST_PREFETCH. tgt: TargetSite | None = None for t in target_sites: if t.idx > pf.idx: tgt = t break if tgt is None: # Unpaired prefetch — treat as forward with target=pc_next. print(f"[warn] {name[:60]!r}: unpaired prefetch label at merged[{pf.idx}] — " f"using koffset=0, clamping klength") result = _clamp_prefetch_region( name, len(pairs), pc_next, pc_next, orig_klength, DIR_FORWARD, 0, func_end) if result is not None: pairs.append(result) continue target = _resolve_target_address(merged, tgt.idx, tgt.mode, name) if target is None: print(f"[warn] {name[:60]!r}: no target address for label at merged[{tgt.idx}]") continue result = _clamp_prefetch_region( name, len(pairs), pc_next, target, orig_klength, pf.direction, pf.offset_cl, func_end) if result is not None: pairs.append(result) if pairs: results[name] = pairs if not results: print(f"[skip] Label '{label}' not found in any matching function block.") else: total = sum(len(v) for v in results.values()) print(f"[offsets] {len(results)} function(s), {total} pair(s) with koffset for '{label}'.") return results # --------------------------------------------------------------------------- # Patching # --------------------------------------------------------------------------- def patch_asm_s(asm_file: Path, func_koffsets: dict[str, list[tuple[int, int]]]) -> bool: """Patch s_prefetch_inst_pc_rel koffset and klength operands in the .s file. Returns True if any change was made.""" prefetch_re = re.compile( r'(s_prefetch_inst_pc_rel\s+)(?:0x[0-9a-fA-F]+|0|-?\d+)' r'(\s*,\s*null\s*,\s*)(?:\d+)') # Full-line regex to capture indentation and the entire prefetch instruction # (used for NOP replacement). prefetch_full_re = re.compile( r'^(\s*)s_prefetch_inst_pc_rel\s+\S+\s*,\s*\S+\s*,\s*\d+(.*)$') text = asm_file.read_text(encoding="utf-8", errors="replace") out_lines: list[str] = [] current_func = "" func_pf_idx: dict[str, int] = {} n_patched = 0 n_nopped = 0 for line in text.splitlines(keepends=True): m = FUNC_LABEL_RE.match(line.rstrip()) if m: current_func = m.group(1) if current_func in func_koffsets: pair_list = func_koffsets[current_func] idx = func_pf_idx.get(current_func, 0) if idx < len(pair_list): koffset, klength = pair_list[idx] if klength == NOP_KLENGTH_SENTINEL: # Replace 8-byte prefetch with 2× 4-byte s_nop 0 m_full = prefetch_full_re.match(line.rstrip('\n\r')) if m_full: indent = m_full.group(1) trailing = m_full.group(2) # e.g. comment eol = line[len(line.rstrip('\n\r')):] # preserve \n nop_lines = (f"{indent}s_nop 0{trailing}{eol}" f"{indent}s_nop 0{eol}") print(f"[patch-s] {current_func[:60]}: prefetch[{idx}] → " f"2× s_nop 0 (OOB)") func_pf_idx[current_func] = idx + 1 n_nopped += 1 n_patched += 1 out_lines.append(nop_lines) continue else: koffset_str = hex(koffset) new_line, n = prefetch_re.subn( rf'\g<1>{koffset_str}\g<2>{klength}', line) if n: print(f"[patch-s] {current_func[:60]}: prefetch[{idx}] → " f"koffset={koffset_str} klength={klength}") func_pf_idx[current_func] = idx + 1 n_patched += n out_lines.append(new_line) continue out_lines.append(line) new_text = ''.join(out_lines) print(f"[patch-s] {n_patched} operand(s) patched ({n_nopped} replaced with NOPs).") if new_text == text: print("[patch-s] No change.") return False asm_file.write_text(new_text, encoding="utf-8") print(f"[patch-s] Written: {asm_file}") return True # --------------------------------------------------------------------------- # ELF / fatbin helpers # --------------------------------------------------------------------------- def _find_elf_text_section(data: bytes | bytearray, base: int = 0) -> tuple[int, int, int] | None: """Find .text section in an ELF image starting at data[base:]. Returns (file_offset_from_base, size, vaddr) or None.""" import struct as _s d = data[base:] if len(d) < 64 or d[:4] != b'\x7fELF': return None ei_class, ei_data = d[4], d[5] endian = '<' if ei_data == 1 else '>' try: if ei_class == 2: e_shoff, = _s.unpack_from(f'{endian}Q', d, 40) e_shentsize, = _s.unpack_from(f'{endian}H', d, 58) e_shnum, = _s.unpack_from(f'{endian}H', d, 60) e_shstrndx, = _s.unpack_from(f'{endian}H', d, 62) addr_in_shdr, off_in_shdr, sz_in_shdr = 16, 24, 32 addr_fmt, off_fmt, sz_fmt = f'{endian}Q', f'{endian}Q', f'{endian}Q' else: e_shoff, = _s.unpack_from(f'{endian}I', d, 32) e_shentsize, = _s.unpack_from(f'{endian}H', d, 46) e_shnum, = _s.unpack_from(f'{endian}H', d, 48) e_shstrndx, = _s.unpack_from(f'{endian}H', d, 50) addr_in_shdr, off_in_shdr, sz_in_shdr = 12, 16, 20 addr_fmt, off_fmt, sz_fmt = f'{endian}I', f'{endian}I', f'{endian}I' shstr_sh = e_shoff + e_shstrndx * e_shentsize shstr_off, = _s.unpack_from(off_fmt, d, shstr_sh + off_in_shdr) for i in range(e_shnum): sh = e_shoff + i * e_shentsize name_idx, = _s.unpack_from(f'{endian}I', d, sh) ns = shstr_off + name_idx ne = d.index(b'\x00', ns) if d[ns:ne] == b'.text': sec_addr, = _s.unpack_from(addr_fmt, d, sh + addr_in_shdr) sec_off, = _s.unpack_from(off_fmt, d, sh + off_in_shdr) sec_sz, = _s.unpack_from(sz_fmt, d, sh + sz_in_shdr) return (sec_off, sec_sz, sec_addr) except (_s.error, ValueError): pass return None def _find_gpu_bundle(data: bytes | bytearray, tag: str = "fatbin" ) -> tuple[int, int, int, str] | None: """Locate the GPU bundle in a fat .o / fatbin. Returns (magic_idx, gpu_off, gpu_sz, gpu_triple) or None. *gpu_off* is relative to *magic_idx* (the absolute start of the GPU ELF in *data* is ``magic_idx + gpu_off``). """ import struct as _s MAGIC = b'__CLANG_OFFLOAD_BUNDLE__' magic_idx = data.find(MAGIC) if magic_idx < 0: print(f"[{tag}] __CLANG_OFFLOAD_BUNDLE__ magic not found") return None hdr = magic_idx + len(MAGIC) if hdr + 8 > len(data): print(f"[{tag}] Truncated fatbin header") return None num_bundles, = _s.unpack_from(' len(data): break off, sz, triple_sz = _s.unpack_from(' 512 or cur + triple_sz > len(data): break triple = data[cur:cur + triple_sz].decode('utf-8', errors='replace') cur += triple_sz if 'amdgcn' in triple or (triple.startswith('hip') and 'host' not in triple): gpu_off, gpu_sz, gpu_triple = off, sz, triple if not gpu_triple: print(f"[{tag}] No GPU entry found in fatbin header") return None return (magic_idx, gpu_off, gpu_sz, gpu_triple) def _objdump_gpu_elf(data: bytes | bytearray, abs_gpu_start: int, gpu_sz: int, mcpu: str, objdump_path: str, tmp_path: Path, tag: str = "fatbin") -> str | None: """Extract the GPU ELF from *data*, run objdump -d, return the text or None.""" try: tmp_path.write_bytes(bytes(data[abs_gpu_start:abs_gpu_start + gpu_sz])) result = subprocess.run( [objdump_path, f"--mcpu={mcpu}", "-d", str(tmp_path)], text=True, capture_output=True, ) if result.returncode != 0: print(f"[{tag}] objdump on GPU ELF failed: {result.stderr[:200]}") return None return result.stdout finally: tmp_path.unlink(missing_ok=True) def _patch_one_prefetch(fat_data: bytearray, instr_pos: int, instr_va: int, idx: int, new_koffset: int, new_klength: int) -> None: """Patch a single s_prefetch_inst_pc_rel at *instr_pos* in *fat_data*. If *new_klength* is NOP_KLENGTH_SENTINEL, replaces the 8-byte instruction with 2× s_nop 0. Otherwise patches koffset (dw1[23:0]) and klength (dw0[12:6]) in place. """ import struct as _struct old_dw0 = _struct.unpack_from('> KLENGTH_SHIFT) & 0x7F}") if new_klength == NOP_KLENGTH_SENTINEL: _struct.pack_into('> KLENGTH_SHIFT) & 0x7F}") print(f"[patch-obj] VA 0x{instr_va:x}: prefetch[{idx}] → " f"koffset={hex(new_koffset)} klength={new_klength}") def replace_gpu_in_fatobj(fat_obj: Path, mcpu: str, objdump_path: str, func_koffsets: dict[str, list[tuple[int, int]]]) -> bool: """Patch s_prefetch_inst_pc_rel koffsets and klengths directly in the GPU ELF embedded in the fat .o via direct binary patching. Returns True on success.""" fat_data = bytearray(fat_obj.read_bytes()) bundle = _find_gpu_bundle(fat_data, tag="patch-obj") if bundle is None: return False magic_idx, gpu_off, gpu_sz, gpu_triple = bundle abs_gpu_start = magic_idx + gpu_off print(f"[patch-obj] GPU bundle: '{gpu_triple}' abs=0x{abs_gpu_start:x} size={gpu_sz}") if fat_data[abs_gpu_start:abs_gpu_start + 4] != b'\x7fELF': print("[patch-obj] GPU data does not start with ELF magic") return False text_info = _find_elf_text_section(fat_data, abs_gpu_start) if text_info is None: print("[patch-obj] Cannot find .text in GPU ELF") return False text_foff, text_sz, text_va = text_info print(f"[patch-obj] .text: foff=0x{text_foff:x} size={text_sz} va=0x{text_va:x}") objdump_text = _objdump_gpu_elf(fat_data, abs_gpu_start, gpu_sz, mcpu, objdump_path, fat_obj.with_suffix(".ck_gpu_elf_tmp"), tag="patch-obj") if objdump_text is None: return False n_patched = 0 current_func: str | None = None func_pf_idx: dict[str, int] = {} for line in objdump_text.splitlines(): m = OBJDUMP_FUNC_RE.match(line) if m: current_func = m.group(1) continue if not current_func or current_func not in func_koffsets: continue if 's_prefetch_inst_pc_rel' not in line: continue m2 = OBJDUMP_ADDR_RE.search(line) if not m2: continue idx = func_pf_idx.get(current_func, 0) pair_list = func_koffsets[current_func] if idx >= len(pair_list): continue instr_va = int(m2.group(1), 16) new_koffset, new_klength = pair_list[idx] instr_pos = abs_gpu_start + text_foff + (instr_va - text_va) _patch_one_prefetch(fat_data, instr_pos, instr_va, idx, new_koffset, new_klength) func_pf_idx[current_func] = idx + 1 n_patched += 1 if n_patched == 0: print("[patch-obj] No s_prefetch_inst_pc_rel found to patch") return False fat_obj.write_bytes(bytes(fat_data)) # Sanity check: re-read and verify the write persisted import hashlib written_hash = hashlib.md5(fat_obj.read_bytes()).hexdigest() expected_hash = hashlib.md5(bytes(fat_data)).hexdigest() if written_hash != expected_hash: print(f"[patch-obj] WARNING: write verification failed! " f"expected={expected_hash} written={written_hash}") print(f"[patch-obj] Patched {n_patched} instruction(s) in {fat_obj.name} " f"md5={written_hash}") return True def verify_patched_obj(fat_obj: Path, mcpu: str, objdump_path: str, func_koffsets: dict[str, list[tuple[int, int]]]) -> bool: """Verify patched s_prefetch_inst_pc_rel koffsets and klengths. Returns True if all match. For NOP-replaced entries (klength == NOP_KLENGTH_SENTINEL), verification checks that the raw bytes at the original position are 2× s_nop encoding. """ import struct as _struct data = fat_obj.read_bytes() bundle = _find_gpu_bundle(data, tag="verify") if bundle is None: return False magic_idx, gpu_off, gpu_sz, _ = bundle abs_start = magic_idx + gpu_off # Locate .text for raw byte diagnostics text_info = _find_elf_text_section(data, abs_start) text_foff = text_va = 0 if text_info: text_foff, text_sz, text_va = text_info objdump_text = _objdump_gpu_elf(data, abs_start, gpu_sz, mcpu, objdump_path, fat_obj.with_suffix(".ck_verify_tmp"), tag="verify") if objdump_text is None: return False prefetch_re = re.compile( r's_prefetch_inst_pc_rel\s+(0x[0-9a-fA-F]+|\d+)\s*,\s*\S+\s*,\s*(\d+)') current_func: str | None = None func_pf_idx: dict[str, int] = {} # Track VAs already consumed as part of a NOP pair so the second s_nop # of a pair (or compiler-emitted s_nops) are not misidentified. consumed_nop_vas: set[int] = set() ok = True checked = 0 for line in objdump_text.splitlines(): m = OBJDUMP_FUNC_RE.match(line) if m: current_func = m.group(1) continue if current_func and current_func in func_koffsets: idx = func_pf_idx.get(current_func, 0) pair_list = func_koffsets[current_func] if idx >= len(pair_list): continue exp_koff, exp_klen = pair_list[idx] if exp_klen == NOP_KLENGTH_SENTINEL: # Expect 2× s_nop at this position. Match by checking whether # the raw bytes at this VA form a NOP pair (both dwords are # S_NOP_ENCODING). This avoids confusion with compiler-emitted # s_nop instructions that are not part of our patching. if 's_nop' in line: m_addr = OBJDUMP_ADDR_RE.search(line) if m_addr and text_info: va = int(m_addr.group(1), 16) if va in consumed_nop_vas: continue # second nop of an already-verified pair pos = abs_start + text_foff + (va - text_va) if 0 <= pos and pos + 8 <= len(data): dw0 = _struct.unpack_from('> KLENGTH_SHIFT) & 0x7F}") func_pf_idx[current_func] = idx + 1 checked += 1 if checked == 0: print("[verify] WARNING: no instructions found to verify") return False print(f"[verify] Checked {checked} instruction(s): {'ALL OK' if ok else 'FAILURES FOUND'}") return ok # --------------------------------------------------------------------------- # main # --------------------------------------------------------------------------- def main() -> None: ap = argparse.ArgumentParser(description=__doc__) ap.add_argument("--build-dir", required=True, type=Path, help="CMake build directory") ap.add_argument("--target", required=True, help="CMake target to build") ap.add_argument("--objdump-path", default=None, help="Path to llvm-objdump (auto-detected from PATH / /opt/rocm if omitted)") ap.add_argument("--objdump-mcpu", default="", help="--mcpu value for llvm-objdump/llvm-mc (auto-detected from .s if omitted)") ap.add_argument("--dry-run", action="store_true", help="Parse and print the koffset but do not write the file or rebuild") ap.add_argument("--skip-build-round1", action="store_true", help="Skip the round-1 cmake build (use when called from a CMake POST_BUILD " "command where round 1 was already performed by the normal build)") ap.add_argument("--jobs", type=int, default=None, help="Parallel jobs for cmake builds (default: all logical CPUs)") ap.add_argument("--log-file", type=Path, default=None, help="Tee all script output to this file (default: " "/prefetch_patch_.log; pass empty string to disable)") ap.add_argument("--label", default=None, help="Process only this label name (default: all labels discovered from .s)") ap.add_argument("--dump-intermediates", action="store_true", help="Write intermediate files (merged tables, objdump text) to the build dir") # Legacy: --source and --bundler-path are accepted but ignored. ap.add_argument("--source", type=Path, default=None, help=argparse.SUPPRESS) ap.add_argument("--bundler-path", default="", help=argparse.SUPPRESS) args = ap.parse_args() # Auto-detect llvm-objdump if not provided. if args.objdump_path is None: _candidates = [ shutil.which("llvm-objdump"), "/opt/rocm/llvm/bin/llvm-objdump", ] for _c in _candidates: if _c and Path(_c).is_file(): args.objdump_path = _c break if args.objdump_path is None: sys.exit( "Cannot find llvm-objdump. Pass --objdump-path explicitly or " "ensure llvm-objdump is on PATH or installed at /opt/rocm/llvm/bin/." ) print(f"[auto] Using llvm-objdump: {args.objdump_path}") # Log setup. log_path: Path | None if args.log_file is None: log_path = args.build_dir.resolve() / f"prefetch_patch_{args.target}.log" elif str(args.log_file) == "": log_path = None else: log_path = args.log_file tee: _Tee | None = None if log_path: log_path.parent.mkdir(parents=True, exist_ok=True) tee = _Tee(log_path) sys.stdout = tee # type: ignore[assignment] print(f"[log] Output mirrored to {log_path}") build_dir = args.build_dir.resolve() jobs = args.jobs if args.jobs is not None else multiprocessing.cpu_count() # ── Round 1 ─────────────────────────────────────────────────────────────── if args.skip_build_round1: print("=== Round 1: skipped (--skip-build-round1) ===") else: print("=== Round 1: building with koffset=0 ===") cmake_build(build_dir, args.target, jobs) # ── Locate .s and .o ───────────────────────────────────────────────────── obj_file = find_obj_file(build_dir, args.target) cpp_stem = Path(obj_file.stem).stem asm_file = find_asm_file(obj_file.parent.parent.parent, cpp_stem) print(f"[asm] Using {asm_file}") asm_text = asm_file.read_text(encoding="utf-8", errors="replace") # Auto-detect mcpu from .amdgcn_target directive if not provided. if not args.objdump_mcpu: args.objdump_mcpu = detect_mcpu_from_asm(asm_text) if args.objdump_mcpu: print(f"[mcpu] Auto-detected from .s: {args.objdump_mcpu}") else: sys.exit("Could not auto-detect GPU arch from .s file. " "Please pass --objdump-mcpu explicitly (e.g. --objdump-mcpu gfx1201).") # ── Discover labels from .s ────────────────────────────────────────────── labels = find_prefetch_labels_from_asm(asm_text) if not labels: print("[skip] No [ck_label]+s_prefetch_inst_pc_rel found in .s — nothing to patch.") if tee is not None: tee.close() return if args.label: labels = [l for l in labels if l == args.label] if not labels: print(f"[skip] Label '{args.label}' not found in .s.") if tee is not None: tee.close() return print(f"[labels] Discovered from .s: {labels}") # ── Assemble + objdump ─────────────────────────────────────────────────── gpu_obj = assemble_gpu_asm(asm_file, args.objdump_mcpu, args.objdump_path) print(f"[hybrid] Assembled GPU .s → {gpu_obj}") objdump_text = run_objdump(args.objdump_path, args.objdump_mcpu, gpu_obj) gpu_obj.unlink(missing_ok=True) if args.dump_intermediates: objdump_dump = build_dir / f"prefetch_patch_{args.target}_objdump.txt" objdump_dump.write_text(objdump_text, encoding="utf-8") print(f"[dump] Raw objdump written to {objdump_dump}") # ── Merge .s ↔ objdump (once per function) ────────────────────────────── dump_dir = build_dir if args.dump_intermediates else None merged_funcs = _merge_all_functions(asm_text, objdump_text, dump_dir=dump_dir) # ── Per-label: compute koffsets ────────────────────────────────────────── per_label: dict[str, dict[str, list[tuple[int, int]]]] = {} for label in labels: print(f"\n[label] Processing '{label}'") fk = find_best_koffset_hybrid(merged_funcs, label) if not fk: print(f"[label] '{label}': markers not matched — skipping.") continue for fname, pairs in fk.items(): for i, (k, kl) in enumerate(pairs): if kl == NOP_KLENGTH_SENTINEL: print(f"[offsets] {fname[:70]}: prefetch[{i}] → 2× s_nop (OOB)") else: print(f"[offsets] {fname[:70]}: prefetch[{i}] koffset={hex(k)} klength={kl}") per_label[label] = fk if args.dry_run: print("[dry-run] Stopping before Round 2.") if tee is not None: tee.close() return if not per_label: print("=== Round 2: skipped (no koffsets computed) ===") print("=== Done ===") if tee is not None: tee.close() return # ── Round 2: patch + rebuild ───────────────────────────────────────────── all_func_koffsets: dict[str, list[tuple[int, int]]] = {} for fk in per_label.values(): for fname, pairs in fk.items(): if fname in all_func_koffsets: all_func_koffsets[fname].extend(pairs) else: all_func_koffsets[fname] = list(pairs) print("=== Round 2: patching .s, replacing GPU in fat .o, relinking ===") # Save original mtimes so we can restore them after patching. # This prevents the patched files from appearing "newer" than the source, # which would cause CMake to skip recompilation on the next build. # Also save the .o.d dependency file mtime — if it drifts relative to # compiler_depend.internal, CMake re-evaluates dependencies unnecessarily. import os as _os obj_stat = _os.stat(obj_file) asm_stat = _os.stat(asm_file) obj_dep_file = obj_file.with_suffix(obj_file.suffix + ".d") obj_dep_stat = _os.stat(obj_dep_file) if obj_dep_file.exists() else None patch_asm_s(asm_file, all_func_koffsets) if not replace_gpu_in_fatobj(obj_file, args.objdump_mcpu, args.objdump_path, all_func_koffsets): sys.exit("[error] Direct binary patching of fat .o failed.") # Restore original mtimes so build-system dependency tracking is not disrupted. _os.utime(obj_file, (obj_stat.st_atime, obj_stat.st_mtime)) _os.utime(asm_file, (asm_stat.st_atime, asm_stat.st_mtime)) if obj_dep_stat is not None and obj_dep_file.exists(): _os.utime(obj_dep_file, (obj_dep_stat.st_atime, obj_dep_stat.st_mtime)) print("[patch-obj] GPU code replaced via direct binary patching.") # When invoked from PRE_LINK (--skip-build-round1), do NOT call cmake_build # again — the linker is about to run and will consume the patched .o. # PRE_LINK fires after compilation but before linking, which is exactly # what we need: the .o is patched in-place, then the linker embeds the # patched device ISA into the final executable. if args.skip_build_round1: print("=== Verifying patched object ===") if not verify_patched_obj(obj_file, args.objdump_mcpu, args.objdump_path, all_func_koffsets): print("[verify] WARNING: verification failed — koffsets may be stale") print("=== Done (PRE_LINK mode) ===") else: # Standalone mode: we need to relink so the final executable picks up # the patched .o. import hashlib as _hashlib obj_hash_before = _hashlib.md5(obj_file.read_bytes()).hexdigest() print(f"[diag] .o md5 BEFORE cmake_build: {obj_hash_before} ({obj_file.name})") cmake_build(build_dir, args.target, jobs) obj_hash_after = _hashlib.md5(obj_file.read_bytes()).hexdigest() print(f"[diag] .o md5 AFTER cmake_build: {obj_hash_after} ({obj_file.name})") if obj_hash_before != obj_hash_after: print("[diag] WARNING: .o was RECOMPILED by cmake_build — binary patches lost!") else: print("[diag] .o unchanged — binary patches preserved.") print("=== Done ===") if tee is not None: tee.close() if __name__ == "__main__": main()