CK-UA: fuse FP8 cvt + cross-lane swap to hide ds_bpermute latency

Previously the 32x32x16 FP8 P-tile cvt and the QK-C -> PV-A cross-lane
swap ran in two separate static_for loops back-to-back inside fmha_alu1:
the whole tile was cvt'd into p.thread_buf_ first, then a second pass
issued one ds_bpermute_b32 per 8-fp8 K-chunk and read/wrote the same
buffer to swap the "bad" 4-byte halves between paired lanes.

The ds_bpermute has nontrivial LDS-DMA latency that the scheduler has
no way to hide when it lives alone in a tight serial loop with the
gather/scatter packs around it.

Fuse the two into one 8-fp8-per-iter loop:
  1. cvt 8 fp32 -> 2 packed uint32 (lo_pack=slot[0..3], hi_pack=slot[4..7])
     using the chained cvt_pk_fp8_f32 pattern matching cast_tile_pk_fp8_fp32.
  2. Pick own_bad = (sub==0 ? hi_pack : lo_pack) and issue ds_bpermute on it.
  3. Write back all 8 fp8 bytes; the "good" half lands first so its byte
     stores can overlap with the in-flight ds_bpermute, and the next
     iter's cvts can begin while the swap is still pending.

The 16x16x32 LDS-roundtrip branch keeps the original separated cvt
loop (no swap latency to hide there since the relayout goes through
LDS, not ds_bpermute).

Single-shape FP8 perf on gfx950 GPU 2 (CUDA graph, 50 iters):
  decode d=128 b=4 sq=8 sk=4096:  0.2106 -> 0.1951 ms  (-7.4%)
  decode d=64  b=4 sq=8 sk=4096:  0.1464 -> 0.1208 ms  (-17.5%)
  prefill d=128 b=2 sq=512 sk=4k: 0.2558 -> 0.2220 ms  (-13.2%)

BF16 unchanged (0.2046 -> 0.2039 ms, within noise).

Correctness: pytest UA correctness suite 405 passed / 80 skipped
(245 BF16/FP16 + 160 FP8), unchanged from before.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
juuso-oskari
2026-05-18 15:48:01 +00:00
parent 9d7cc3ee9e
commit 3431615ff0

View File

@@ -833,29 +833,6 @@ struct UnifiedAttentionPipeline
"fp8 P conversion expects packs of 4 fp32 lanes per "
"thread; widen the warp gemm M distribution if this "
"trips.");
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
uint32_t lo =
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
uint32_t packed =
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
});
#pragma clang diagnostic pop
// ---------------------------------------------------------
// FP8 P-tile QK-C -> PV-A re-layout.
@@ -900,22 +877,33 @@ struct UnifiedAttentionPipeline
// tiny-decode tier where (A) doesn't apply. This
// keeps the previously-tuned 32x32x16 perf intact
// while enabling FP8 on the m16 tier.
//
// For strategy (A) the cvt and the cross-lane swap are
// fused into a single 8-fp8-per-iter loop so that the
// ds_bpermute_b32 latency overlaps with subsequent
// cvt_pk_fp8_f32 calls (instead of running serially
// after the whole cvt phase finishes).
using PVWarpTile = typename UnifiedAttentionShape::Gemm1WarpTile;
if constexpr(PVWarpTile::at(number<0>{}) == 32 &&
PVWarpTile::at(number<1>{}) == 32 &&
PVWarpTile::at(number<2>{}) == 16)
{
// ---- (A) Cross-lane in-register swap (32x32x16). ----
// ---- (A) Fused cvt + cross-lane swap (32x32x16). ----
//
// For each 8-fp8 PV K-chunk (one warp-gemm K iter),
// the slot decomposition is:
// Per 8-fp8 K-chunk:
// 1. cvt 8 fp32 -> 2 packed uint32 (lo_pack = slot[0..3],
// hi_pack = slot[4..7]) using the chained-`old` pattern
// that matches `cast_tile_pk_fp8_fp32`.
// 2. ds_bpermute the "bad" pack to the paired lane (lane^32).
// 3. Write back both packs as 8 fp8 bytes; the "good" half
// gets written first so its byte stores overlap with the
// ds_bpermute latency.
//
// Slot decomposition (per fmha_alu1 doc above):
// sub=0 | slot[0..3] | N=0..3 | K=0..3 OK
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
// Each sub's *bad* 4-fp8 chunk holds exactly the data
// the *paired* sub (lane ^ 32) needs for its bad
// slot. One ds_bpermute_b32 per K-chunk fixes it.
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
"FP8 32x32x16 + Single cross-lane permute "
"expects PV per-thread buffer in chunks of 8 "
@@ -925,38 +913,59 @@ struct UnifiedAttentionPipeline
const int paired_addr = (lane_id ^ 32) << 2; // bytes
const bool is_sub_0 = (lane_id & 32) == 0;
auto pack4 = [](fp8_t a, fp8_t b, fp8_t c, fp8_t d) -> uint32_t {
return (static_cast<uint32_t>(bit_cast<fp8_raw_t>(a)) << 0) |
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(b)) << 8) |
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(c)) << 16) |
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(d)) << 24);
};
auto unpack4 = [](uint32_t v, fp8_t& a, fp8_t& b, fp8_t& c, fp8_t& d) {
a = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 0) & 0xFFu));
b = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 8) & 0xFFu));
c = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 16) & 0xFFu));
d = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 24) & 0xFFu));
};
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
auto& p = sp(sp_reg_idx).p;
const uint32_t own_bad =
is_sub_0
? pack4(p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7])
: pack4(p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
const uint32_t recv =
__builtin_amdgcn_ds_bpermute(paired_addr, static_cast<int>(own_bad));
if(is_sub_0)
unpack4(recv,
p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7]);
else
unpack4(recv,
p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
auto& p = sp(sp_reg_idx).p;
auto& sc = sp(sp_reg_idx).sp_compute;
const float a = p_compute_element_func(sc.thread_buf_[k_base + 0]);
const float b = p_compute_element_func(sc.thread_buf_[k_base + 1]);
const float c = p_compute_element_func(sc.thread_buf_[k_base + 2]);
const float d = p_compute_element_func(sc.thread_buf_[k_base + 3]);
const float e = p_compute_element_func(sc.thread_buf_[k_base + 4]);
const float f = p_compute_element_func(sc.thread_buf_[k_base + 5]);
const float g = p_compute_element_func(sc.thread_buf_[k_base + 6]);
const float h = p_compute_element_func(sc.thread_buf_[k_base + 7]);
const uint32_t lo_tmp =
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
const uint32_t lo_pack =
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo_tmp, /*hi=*/true);
const uint32_t hi_tmp =
__builtin_amdgcn_cvt_pk_fp8_f32(e, f, dummy_old, /*hi=*/false);
const uint32_t hi_pack =
__builtin_amdgcn_cvt_pk_fp8_f32(g, h, hi_tmp, /*hi=*/true);
// Issue ds_bpermute as early as possible so its LDS-DMA
// latency overlaps with the byte writes below (and with
// the next K-chunk's cvts after this iter unrolls).
const uint32_t own_bad = is_sub_0 ? hi_pack : lo_pack;
const uint32_t recv = __builtin_amdgcn_ds_bpermute(
paired_addr, static_cast<int>(own_bad));
const uint32_t out_lo = is_sub_0 ? lo_pack : recv;
const uint32_t out_hi = is_sub_0 ? recv : hi_pack;
p.thread_buf_[k_base + 0] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 0) & 0xFFu));
p.thread_buf_[k_base + 1] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 8) & 0xFFu));
p.thread_buf_[k_base + 2] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 16) & 0xFFu));
p.thread_buf_[k_base + 3] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 24) & 0xFFu));
p.thread_buf_[k_base + 4] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 0) & 0xFFu));
p.thread_buf_[k_base + 5] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 8) & 0xFFu));
p.thread_buf_[k_base + 6] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 16) & 0xFFu));
p.thread_buf_[k_base + 7] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 24) & 0xFFu));
});
#pragma clang diagnostic pop
}
else
{
@@ -964,22 +973,52 @@ struct UnifiedAttentionPipeline
// future MFMA shape that doesn't fit the
// paired-lane swap pattern). ----
//
// 1. `p_qkc` is a static_distributed_tensor<fp8>
// The cvt phase is separated out here (vs fused into
// the swap as in branch (A)) because the relayout
// travels through LDS, not through `ds_bpermute`, so
// there's no swap-latency to hide.
//
// 1. cvt_pk_fp8_f32 chain into `sp(idx).p.thread_buf_`.
// 2. `p_qkc` is a static_distributed_tensor<fp8>
// whose distribution metadata says "QK-C
// layout". Its register bytes are populated
// from `sp(idx).p.thread_buf_`, which is
// exactly where the cvt_pk_fp8_f32 chain just
// wrote the FP8 bytes (the union has them at
// QK-C-layout register offsets).
// 2. `store_tile` writes `p_qkc` to LDS at
// 3. `store_tile` writes `p_qkc` to LDS at
// canonical (M, N) order.
// 3. Block-level barrier.
// 4. `load_tile` reads from the same LDS region
// 4. Block-level barrier.
// 5. `load_tile` reads from the same LDS region
// with the PV-A distribution.
// 5. Copy `p_pva.thread_buf_` back into
// 6. Copy `p_pva.thread_buf_` back into
// `sp(idx).p` so the gemm_1 call site reads
// correctly-laid-out data with no further
// changes.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
int dummy_old;
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
const uint32_t lo =
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
const uint32_t packed =
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
});
#pragma clang diagnostic pop
auto p_qkc = make_static_distributed_tensor<PDataType>(
sp(sp_reg_idx).sp_compute.get_tile_distribution());
static_assert(