mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
CK-UA: gate dwordx3/x4 global_load_lds builtin on clang≥21, inline-asm fallback
The size=12 and size=16 ImmArg overloads of __builtin_amdgcn_global_load_lds
for gfx950 only landed in AMD clang ~21 (present in ROCm ≥ 7.11 / clang 22,
absent in ROCm 7.1.1 / clang 20). Building this CK branch on the older
toolchain failed during semantic analysis of amd_buffer_addressing_builtins.hpp:
error: invalid size value
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, ...);
note: size must be 1, 2, or 4
The error is unavoidable as soon as the unified_attention pipeline is built —
its `if (cache_ptr_int32_overflow_possible)` dispatch is a runtime branch,
not `if constexpr`, so the `bytes ∈ {12, 16}` instantiations are compiled
regardless of whether any workload at runtime takes that path.
Fix: introduce CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN, gated on
__clang_major__ >= 21 (overridable with -D). When 0, emit
`global_load_lds_dword{,x3,x4}` via inline asm, with M0 set explicitly
through `s_mov_b32` from the addrspace(3) `lptr`, narrowed to its 32-bit
LDS byte offset and made wave-uniform via `readfirstlane`. The assembler
accepts the mnemonic and emits the same HW instruction the builtin
would lower to (verified zero perf delta vs. the builtin path across
the full decode regression sweep — all 8 (b, d, dtype) configs match
to within ≤ 1.5% run-to-run noise when the fallback is force-on).
Two simpler "issue N× size=4" decompositions were tried and rejected:
stepping INST.OFFSET by 4 does not reproduce the dwordx4 layout for any shape;
stepping by 256 with `gptr += 4` per issue happens to pass on one
big-cache decode shape (b=1 / sk=1M) but fails on b=128 / sk=16384 /
d=128 / bf16. The native dwordx4's in-LDS sub-issue ordering doesn't
reduce to any combination of dword INST.OFFSET steps we could find that
survives all decode shapes; asking the assembler for the literal
instruction sidesteps the question.
The dormant amd_buffer_addressing.hpp copy (used only when CK_TILE_USE_
BUFFER_ADDRESSING_BUILTIN is forced to 0, which doesn't happen on clang
≥ 20) gets the same treatment so toggling the macro doesn't reintroduce
the bug.
Allows building jukorhon/unified-attention-ck on ROCm 7.1.1 unchanged;
upgrading to a newer ROCm container remains the recommended option.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -2864,14 +2864,11 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
// =============================================================================
|
||||
// global_load_lds path — direct DRAM->LDS load via per-lane 64-bit base pointer.
|
||||
//
|
||||
// Equivalent of `amd_async_buffer_load_with_oob_raw` but bypasses the SRD
|
||||
// (`int32x4_t` resource descriptor) entirely:
|
||||
// - SRD's `size` field is uint32_t (max ~4 GB pool). Caches above that wrap.
|
||||
// - `buffer_load_*` voffset is 32-bit. Per-lane offsets above 4 GB wrap.
|
||||
// Replacing the underlying HW instruction with `global_load_lds` (per-lane
|
||||
// 64-bit VGPR-pair base + 13-bit signed immediate offset) lifts both limits.
|
||||
// Required for paged-KV caches whose `num_blocks * page_size * row_stride *
|
||||
// sizeof(T)` exceeds INT32_MAX (e.g. very-long-context decode pools).
|
||||
// See amd_buffer_addressing_builtins.hpp for the full rationale (this file
|
||||
// is the dormant copy used only when `CK_TILE_USE_BUFFER_ADDRESSING_BUILTIN`
|
||||
// is forced to 0; clang ≥ 20 routes to the _builtins.hpp variant). Kept in
|
||||
// lockstep so toggling that macro doesn't reintroduce the >4 GB-cache path
|
||||
// or the size=12/16 ImmArg compile failure on older toolchains.
|
||||
//
|
||||
// Caveats:
|
||||
// - Loses the SRD's free OOB clamp. Caller must ensure the per-lane pointer
|
||||
@@ -2879,6 +2876,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`.
|
||||
// Older arches would need a `global_load + ds_write` fallback.
|
||||
// =============================================================================
|
||||
#ifndef CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
|
||||
#if __clang_major__ >= 21
|
||||
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
index_t byte_offset_imm = 0, // 13-bit signed
|
||||
@@ -2912,12 +2917,46 @@ CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
|
||||
// literals. A switch on the constexpr `bytes` value lets each branch
|
||||
// pass the literal directly.
|
||||
constexpr int kCoherence = static_cast<int>(coherence);
|
||||
#if CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
|
||||
if constexpr(bytes == 16)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
|
||||
else if constexpr(bytes == 12)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
|
||||
else /* bytes == 4 */
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
|
||||
#else
|
||||
// Old-toolchain fallback — see amd_buffer_addressing_builtins.hpp for
|
||||
// the full rationale. Emits the dwordx{1,3,4} instruction via inline
|
||||
// asm so the ImmArg size literal check is never performed; M0 is set
|
||||
// explicitly from `lptr` (the 32-bit LDS byte offset).
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
const uint32_t lds_byte_offset = (uint32_t)((uintptr_t)lptr);
|
||||
#pragma clang diagnostic pop
|
||||
const uint32_t lds_byte_offset_u =
|
||||
__builtin_amdgcn_readfirstlane(lds_byte_offset);
|
||||
uint32_t m0_dep;
|
||||
asm volatile("s_mov_b32 m0, %1"
|
||||
: "=s"(m0_dep)
|
||||
: "s"(lds_byte_offset_u)
|
||||
: "memory");
|
||||
|
||||
if constexpr(bytes == 16)
|
||||
asm volatile("global_load_lds_dwordx4 %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
else if constexpr(bytes == 12)
|
||||
asm volatile("global_load_lds_dwordx3 %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
else /* bytes == 4 */
|
||||
asm volatile("global_load_lds_dword %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
|
||||
// This version supports a buffer resource as an input argument
|
||||
|
||||
@@ -2705,7 +2705,27 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
|
||||
// is valid (in our pipeline use, the page_table lookup guarantees this).
|
||||
// - gfx9.4+ / gfx950 only — uses `__builtin_amdgcn_global_load_lds`. Older
|
||||
// arches would need a `global_load + ds_write` fallback.
|
||||
//
|
||||
// Toolchain note (`CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN`):
|
||||
// The size=12/size=16 ImmArg overloads of `__builtin_amdgcn_global_load_lds`
|
||||
// for gfx950 only landed in AMD clang ~21+ (verified absent in ROCm 7.1.1
|
||||
// / clang 20, present in ROCm 7.11.0 / clang 22). On older toolchains the
|
||||
// front-end rejects the size literal during semantic analysis — no flag fixes this.
|
||||
// The macro below gates on `__clang_major__ >= 21`; when 0 we fall back to
|
||||
// emitting `global_load_lds_dword{,x3,x4}` via inline asm, which bypasses
|
||||
// the ImmArg check entirely and produces the exact same HW instruction
|
||||
// (verified zero perf delta vs. the builtin path across the decode
|
||||
// regression suite). Override the heuristic manually with
|
||||
// `-DCK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN=0/1`.
|
||||
// =============================================================================
|
||||
#ifndef CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
|
||||
#if __clang_major__ >= 21
|
||||
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 1
|
||||
#else
|
||||
#define CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
template <typename T,
|
||||
index_t N,
|
||||
index_t byte_offset_imm = 0, // 13-bit signed
|
||||
@@ -2739,12 +2759,66 @@ CK_TILE_DEVICE void amd_async_global_load_lds_raw(T* smem,
|
||||
// literals. A switch on the constexpr `bytes` value lets each branch
|
||||
// pass the literal directly.
|
||||
constexpr int kCoherence = static_cast<int>(coherence);
|
||||
#if CK_TILE_HAS_GLOBAL_LOAD_LDS_DWORDX4_BUILTIN
|
||||
if constexpr(bytes == 16)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 16, byte_offset_imm, kCoherence);
|
||||
else if constexpr(bytes == 12)
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 12, byte_offset_imm, kCoherence);
|
||||
else /* bytes == 4 */
|
||||
__builtin_amdgcn_global_load_lds(gptr, lptr, 4, byte_offset_imm, kCoherence);
|
||||
#else
|
||||
// Old-toolchain fallback (ROCm ≤ 7.1.1 / AMD clang ≤ 20).
|
||||
//
|
||||
// The size=12/16 ImmArg overloads of `__builtin_amdgcn_global_load_lds`
|
||||
// are rejected during semantic analysis on these compilers, so we emit
|
||||
// the dwordx{1,3,4} instruction via inline asm instead — the assembler
|
||||
// happily accepts the mnemonic and stamps an identical HW instruction
|
||||
// to the one the newer builtin would lower to. (Decomposing into N×
|
||||
// size=4 builtin calls *looks* equivalent but isn't: the in-LDS layout
|
||||
// of a native `dwordx4` doesn't reduce to any combination of dword
|
||||
// INST.OFFSET steps we could find that survives all decode shapes —
|
||||
// observed FAIL on b=128 / sk=16384 / d=128 / bf16. Easier to just
|
||||
// ask the assembler for the real instruction.)
|
||||
//
|
||||
// Operand contract:
|
||||
// - M0 (LDS dest base): set explicitly by us via `s_mov_b32`. The
|
||||
// addrspace(3) `lptr` narrows to a 32-bit LDS byte offset on cast.
|
||||
// `readfirstlane` guarantees the value lands in an SGPR even if
|
||||
// LLVM lost sight of its wave-uniformity. The "s" constraints
|
||||
// enforce SALU placement; `m0_dep` plumbs an SSA edge between the
|
||||
// m0 setter and the load asm so LLVM cannot reorder the two.
|
||||
// - `gptr` (per-lane 64-bit base): VGPR pair via "v".
|
||||
// - `byte_offset_imm` (compile-time INST.OFFSET literal): "n".
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wold-style-cast"
|
||||
const uint32_t lds_byte_offset = (uint32_t)((uintptr_t)lptr);
|
||||
#pragma clang diagnostic pop
|
||||
// Wave-uniform readfirstlane keeps `m0` an SGPR even if optimizer didn't
|
||||
// see the wave-uniformity (our caller does pass a wave-uniform value).
|
||||
const uint32_t lds_byte_offset_u =
|
||||
__builtin_amdgcn_readfirstlane(lds_byte_offset);
|
||||
uint32_t m0_dep;
|
||||
asm volatile("s_mov_b32 m0, %1"
|
||||
: "=s"(m0_dep) // SSA tie-back into the load asm's input
|
||||
: "s"(lds_byte_offset_u)
|
||||
: "memory");
|
||||
|
||||
if constexpr(bytes == 16)
|
||||
asm volatile("global_load_lds_dwordx4 %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
else if constexpr(bytes == 12)
|
||||
asm volatile("global_load_lds_dwordx3 %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
else /* bytes == 4 */
|
||||
asm volatile("global_load_lds_dword %0, off offset:%c1"
|
||||
:
|
||||
: "v"(gptr), "n"(byte_offset_imm), "s"(m0_dep)
|
||||
: "memory");
|
||||
#endif
|
||||
}
|
||||
|
||||
// This version supports a buffer resource as an input argument
|
||||
|
||||
Reference in New Issue
Block a user