CK-UA: gate dwordx3/x4 global_load_lds builtin on clang≥21, inline-asm fallback

The size=12 and size=16 ImmArg overloads of __builtin_amdgcn_global_load_lds
for gfx950 only landed in AMD clang ~21 (present in ROCm ≥ 7.11 / clang 22,
absent in ROCm 7.1.1 / clang 20). Building this CK branch on the older
toolchain failed during semantic analysis of amd_buffer_addressing_builtins.hpp:

    error: invalid size value
       __builtin_amdgcn_global_load_lds(gptr, lptr, 16, ...);
    note: size must be 1, 2, or 4

The error is unavoidable as soon as the unified_attention pipeline is built —
its `if (cache_ptr_int32_overflow_possible)` dispatch is a runtime branch,
not `if constexpr`, so the `bytes ∈ {12, 16}` instantiations are compiled
regardless of whether any workload at runtime takes that path.

Fix: introduce CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN, gated on
__clang_major__ >= 21 (overridable). When 0, emit
`global_load_lds_dwordx{1,3,4}` via inline asm, with M0 set explicitly
through `s_mov_b32` from the addrspace(3) `lptr` narrowed to its 32-bit
LDS byte offset and wave-uniformed via `readfirstlane`. The assembler
accepts the mnemonic and emits the same HW instruction the builtin
would lower to (verified zero perf delta vs. the builtin path across
the full decode regression sweep — all 8 (b, d, dtype) configs match
to within ≤ 1.5% run-to-run noise when the fallback is force-on).

Two simpler "issue N× size=4" decompositions were tried and rejected:
INST.OFFSET stepping by 4 reproduces the dwordx4 layout for no shape;
stepping by 256 with `gptr += 4` per issue happens to pass on one
big-cache decode shape (b=1 / sk=1M) but fails on b=128 / sk=16384 /
d=128 / bf16. The native dwordx4's in-LDS sub-issue ordering doesn't
reduce to any combination of dword INST.OFFSET steps we could find that
survives all decode shapes; asking the assembler for the literal
instruction sidesteps the question.

The dormant amd_buffer_addressing.hpp copy (used only when CK_TILE_USE_
BUFFER_ADDRESSING_BUILTIN is forced to 0, which doesn't happen on clang
≥ 20) gets the same treatment so toggling the macro doesn't reintroduce
the bug.

Allows building jukorhon/unified-attention-ck on ROCm 7.1.1 unchanged;
upgrading to a newer ROCm container remains the recommended option.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-27 12:45:18 +00:00
parent 2645149bbf
commit 46e6225397
2 changed files with 121 additions and 8 deletions

View File

@@ -2864,14 +2864,11 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
// =============================================================================
// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer.
//
// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD
// (`int32x4_t` resource descriptor) entirely:
// - SRD's `size` field is uint32_t (max ~4 GB pool). Caches above that wrap.
// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap.
// Replacing the underlying HW instruction with `global_load_lds` (per-lane
// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits.
// Required for paged-KV caches whose `num_blocks * page_size * row_stride *
// sizeof(T)` exceeds INT32_MAX (e.g. very-long-context decode pools).
// See amd_buffer_addressing_builtins.hpp for the full rationale (this file
// is the dormant copy used only when `CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN`
// is forced to 0; clang ≥ 20 routes to the _builtins.hpp variant). Kept in
// lockstep so toggling that macro doesn't reintroduce the >4 GB-cache path
// or the size=12/16 ImmArg compile failure on older toolchains.
//
// Caveats:
// - Loses the SRD's free OOB clamp. Caller must ensure the per-lane pointer
@@ -2879,6 +2876,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`.
// Older arches would need a `global_load + ds_write` fallback.
// =============================================================================
#ifndef CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
#if __clang_major__ >= 21
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 1
#else
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 0
#endif
#endif
template <typename T,
index_t N,
index_t byte_offset_imm = 0, // 13-bit signed
@@ -2912,12 +2917,46 @@ CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
// literals. A switch on the constexpr `bytes` value lets each branch
// pass the literal directly.
constexpr int kCoherence = static_cast<int>(coherence);
#if CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
if constexpr(bytes == 16)
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
else if constexpr(bytes == 12)
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
else /* bytes == 4 */
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
#else
// Old-toolchain fallback — see amd_buffer_addressing_builtins.hpp for
// the full rationale. Emits the dwordx{1,3,4} instruction via inline
// asm so the ImmArg size literal check is never performed; M0 is set
// explicitly from `lptr` (the 32-bit LDS byte offset).
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
const uint32_t lds_byte_offset = (uint32_t)((uintptr_t)lptr);
#pragma clang diagnostic pop
const uint32_t lds_byte_offset_u =
__builtin_amdgcn_readfirstlane(lds_byte_offset);
uint32_t m0_dep;
asm volatile("s_mov_b32 m0, %1"
: "=s"(m0_dep)
: "s"(lds_byte_offset_u)
: "memory");
if constexpr(bytes == 16)
asm volatile("global_load_lds_dwordx4 %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
else if constexpr(bytes == 12)
asm volatile("global_load_lds_dwordx3 %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
else /* bytes == 4 */
asm volatile("global_load_lds_dword %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
#endif
}
// This version support buffer resource as input arg

View File

@@ -2705,7 +2705,27 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
// is valid (in our pipeline use, the page_table lookup guarantees this).
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`. Older
// arches would need a `global_load + ds_write` fallback.
//
// Toolchain note (`CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN`):
// The size=12/size=16 ImmArg overloads of `__builtin_amdgcn_global_load_lds`
// for gfx950 only landed in AMD clang ~21+ (verified absent in ROCm 7.1.1
// / clang 20, present in ROCm 7.11.0 / clang 22). On older toolchains the
// front-end rejects the size literal at parse time — no flag fixes this.
// The macro below gates on `__clang_major__ >= 21`; when 0 we fall back to
// emitting `global_load_lds_dwordx{1,3,4}` via inline asm, which bypasses
// the ImmArg check entirely and produces the exact same HW instruction
// (verified zero perf delta vs. the builtin path across the decode
// regression suite). Override the heuristic manually with
// `-DCK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN=0/1`.
// =============================================================================
#ifndef CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
#if __clang_major__ >= 21
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 1
#else
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 0
#endif
#endif
template <typename T,
index_t N,
index_t byte_offset_imm = 0, // 13-bit signed
@@ -2739,12 +2759,66 @@ CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
// literals. A switch on the constexpr `bytes` value lets each branch
// pass the literal directly.
constexpr int kCoherence = static_cast<int>(coherence);
#if CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
if constexpr(bytes == 16)
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
else if constexpr(bytes == 12)
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
else /* bytes == 4 */
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
#else
// Old-toolchain fallback (ROCm ≤ 7.1.1 / AMD clang ≤ 20).
//
// The size=12/16 ImmArg overloads of `__builtin_amdgcn_global_load_lds`
// are rejected during semantic analysis on these compilers, so we emit
// the dwordx{1,3,4} instruction via inline asm instead — the assembler
// happily accepts the mnemonic and stamps an identical HW instruction
// to the one the newer builtin would lower to. (Decomposing into N×
// size=4 builtin calls *looks* equivalent but isn't: the in-LDS layout
// of a native `dwordx4` doesn't reduce to any combination of dword
// INST.OFFSET steps we could find that survives all decode shapes —
// observed FAIL on b=128 / sk=16384 / d=128 / bf16. Easier to just
// ask the assembler for the real instruction.)
//
// Operand contract:
// - M0 (LDS dest base): set explicitly by us via `s_mov_b32`. The
// addrspace(3) `lptr` narrows to a 32-bit LDS byte offset on cast.
// `readfirstlane` guarantees the value lands in an SGPR even if
// LLVM lost sight of its wave-uniformity. The "s" constraints
// enforce SALU placement; `m0_dep` plumbs an SSA edge between the
// m0 setter and the load asm so LLVM cannot reorder the two.
// - `gptr` (per-lane 64-bit base): VGPR pair via "v".
// - `byte_offset_imm` (compile-time INST.OFFSET literal): "n".
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
const uint32_t lds_byte_offset = (uint32_t)((uintptr_t)lptr);
#pragma clang diagnostic pop
// Wave-uniform readfirstlane keeps `m0` an SGPR even if optimizer didn't
// see the wave-uniformity (our caller does pass a wave-uniform value).
const uint32_t lds_byte_offset_u =
__builtin_amdgcn_readfirstlane(lds_byte_offset);
uint32_t m0_dep;
asm volatile("s_mov_b32 m0, %1"
: "=s"(m0_dep) // SSA tie-back into the load asm's input
: "s"(lds_byte_offset_u)
: "memory");
if constexpr(bytes == 16)
asm volatile("global_load_lds_dwordx4 %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
else if constexpr(bytes == 12)
asm volatile("global_load_lds_dwordx3 %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
else /* bytes == 4 */
asm volatile("global_load_lds_dword %0, off offset:%c1"
:
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
: "memory");
#endif
}
// This version support buffer resource as input arg