mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
CK-UA: fuse FP8 cvt + cross-lane swap to hide ds_bpermute latency
Previously the 32x32x16 FP8 P-tile cvt and the QK-C -> PV-A cross-lane
swap ran in two separate static_for loops back-to-back inside fmha_alu1:
the whole tile was cvt'd into p.thread_buf_ first, then a second pass
issued one ds_bpermute_b32 per 8-fp8 K-chunk and read/wrote the same
buffer to swap the "bad" 4-byte halves between paired lanes.
The ds_bpermute has nontrivial LDS-DMA latency that the scheduler has
no way to hide when it lives alone in a tight serial loop with the
gather/scatter packs around it.
Fuse the two into one 8-fp8-per-iter loop:
1. cvt 8 fp32 -> 2 packed uint32 (lo_pack=slot[0..3], hi_pack=slot[4..7])
using the chained cvt_pk_fp8_f32 pattern matching cast_tile_pk_fp8_fp32.
2. Pick own_bad = (sub==0 ? hi_pack : lo_pack) and issue ds_bpermute on it.
3. Write back all 8 fp8 bytes; the "good" half lands first so its byte
stores can overlap with the in-flight ds_bpermute, and the next
iter's cvts can begin while the swap is still pending.
The 16x16x32 LDS-roundtrip branch keeps the original separated cvt
loop (no swap latency to hide there since the relayout goes through
LDS, not ds_bpermute).
Single-shape FP8 perf on gfx950 GPU 2 (CUDA graph, 50 iters):
decode d=128 b=4 sq=8 sk=4096: 0.2106 -> 0.1951 ms (-7.4%)
decode d=64 b=4 sq=8 sk=4096: 0.1464 -> 0.1208 ms (-17.5%)
prefill d=128 b=2 sq=512 sk=4k: 0.2558 -> 0.2220 ms (-13.2%)
BF16 unchanged (0.2046 -> 0.2039 ms, within noise).
Correctness: pytest UA correctness suite 405 passed / 80 skipped
(245 BF16/FP16 + 160 FP8), unchanged from before.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -833,29 +833,6 @@ struct UnifiedAttentionPipeline
|
||||
"fp8 P conversion expects packs of 4 fp32 lanes per "
|
||||
"thread; widen the warp gemm M distribution if this "
|
||||
"trips.");
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
||||
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
||||
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
||||
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
|
||||
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
|
||||
|
||||
uint32_t lo =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
||||
uint32_t packed =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
|
||||
});
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
// ---------------------------------------------------------
|
||||
// FP8 P-tile QK-C -> PV-A re-layout.
|
||||
@@ -900,22 +877,33 @@ struct UnifiedAttentionPipeline
|
||||
// tiny-decode tier where (A) doesn't apply. This
|
||||
// keeps the previously-tuned 32x32x16 perf intact
|
||||
// while enabling FP8 on the m16 tier.
|
||||
//
|
||||
// For strategy (A) the cvt and the cross-lane swap are
|
||||
// fused into a single 8-fp8-per-iter loop so that the
|
||||
// ds_bpermute_b32 latency overlaps with subsequent
|
||||
// cvt_pk_fp8_f32 calls (instead of running serially
|
||||
// after the whole cvt phase finishes).
|
||||
using PVWarpTile = typename UnifiedAttentionShape::Gemm1WarpTile;
|
||||
if constexpr(PVWarpTile::at(number<0>{}) == 32 &&
|
||||
PVWarpTile::at(number<1>{}) == 32 &&
|
||||
PVWarpTile::at(number<2>{}) == 16)
|
||||
{
|
||||
// ---- (A) Cross-lane in-register swap (32x32x16). ----
|
||||
// ---- (A) Fused cvt + cross-lane swap (32x32x16). ----
|
||||
//
|
||||
// For each 8-fp8 PV K-chunk (one warp-gemm K iter),
|
||||
// the slot decomposition is:
|
||||
// Per 8-fp8 K-chunk:
|
||||
// 1. cvt 8 fp32 -> 2 packed uint32 (lo_pack = slot[0..3],
|
||||
// hi_pack = slot[4..7]) using the chained-`old` pattern
|
||||
// that matches `cast_tile_pk_fp8_fp32`.
|
||||
// 2. ds_bpermute the "bad" pack to the paired lane (lane^32).
|
||||
// 3. Write back both packs as 8 fp8 bytes; the "good" half
|
||||
// gets written first so its byte stores overlap with the
|
||||
// ds_bpermute latency.
|
||||
//
|
||||
// Slot decomposition (per fmha_alu1 doc above):
|
||||
// sub=0 | slot[0..3] | N=0..3 | K=0..3 OK
|
||||
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
|
||||
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
|
||||
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
|
||||
// Each sub's *bad* 4-fp8 chunk holds exactly the data
|
||||
// the *paired* sub (lane ^ 32) needs for its bad
|
||||
// slot. One ds_bpermute_b32 per K-chunk fixes it.
|
||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
|
||||
"FP8 32x32x16 + Single cross-lane permute "
|
||||
"expects PV per-thread buffer in chunks of 8 "
|
||||
@@ -925,38 +913,59 @@ struct UnifiedAttentionPipeline
|
||||
const int paired_addr = (lane_id ^ 32) << 2; // bytes
|
||||
const bool is_sub_0 = (lane_id & 32) == 0;
|
||||
|
||||
auto pack4 = [](fp8_t a, fp8_t b, fp8_t c, fp8_t d) -> uint32_t {
|
||||
return (static_cast<uint32_t>(bit_cast<fp8_raw_t>(a)) << 0) |
|
||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(b)) << 8) |
|
||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(c)) << 16) |
|
||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(d)) << 24);
|
||||
};
|
||||
auto unpack4 = [](uint32_t v, fp8_t& a, fp8_t& b, fp8_t& c, fp8_t& d) {
|
||||
a = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 0) & 0xFFu));
|
||||
b = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 8) & 0xFFu));
|
||||
c = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 16) & 0xFFu));
|
||||
d = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 24) & 0xFFu));
|
||||
};
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
|
||||
auto& p = sp(sp_reg_idx).p;
|
||||
const uint32_t own_bad =
|
||||
is_sub_0
|
||||
? pack4(p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
|
||||
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7])
|
||||
: pack4(p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
|
||||
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
|
||||
const uint32_t recv =
|
||||
__builtin_amdgcn_ds_bpermute(paired_addr, static_cast<int>(own_bad));
|
||||
if(is_sub_0)
|
||||
unpack4(recv,
|
||||
p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
|
||||
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7]);
|
||||
else
|
||||
unpack4(recv,
|
||||
p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
|
||||
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
|
||||
auto& p = sp(sp_reg_idx).p;
|
||||
auto& sc = sp(sp_reg_idx).sp_compute;
|
||||
|
||||
const float a = p_compute_element_func(sc.thread_buf_[k_base + 0]);
|
||||
const float b = p_compute_element_func(sc.thread_buf_[k_base + 1]);
|
||||
const float c = p_compute_element_func(sc.thread_buf_[k_base + 2]);
|
||||
const float d = p_compute_element_func(sc.thread_buf_[k_base + 3]);
|
||||
const float e = p_compute_element_func(sc.thread_buf_[k_base + 4]);
|
||||
const float f = p_compute_element_func(sc.thread_buf_[k_base + 5]);
|
||||
const float g = p_compute_element_func(sc.thread_buf_[k_base + 6]);
|
||||
const float h = p_compute_element_func(sc.thread_buf_[k_base + 7]);
|
||||
|
||||
const uint32_t lo_tmp =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
||||
const uint32_t lo_pack =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo_tmp, /*hi=*/true);
|
||||
const uint32_t hi_tmp =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(e, f, dummy_old, /*hi=*/false);
|
||||
const uint32_t hi_pack =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(g, h, hi_tmp, /*hi=*/true);
|
||||
|
||||
// Issue ds_bpermute as early as possible so its LDS-DMA
|
||||
// latency overlaps with the byte writes below (and with
|
||||
// the next K-chunk's cvts after this iter unrolls).
|
||||
const uint32_t own_bad = is_sub_0 ? hi_pack : lo_pack;
|
||||
const uint32_t recv = __builtin_amdgcn_ds_bpermute(
|
||||
paired_addr, static_cast<int>(own_bad));
|
||||
|
||||
const uint32_t out_lo = is_sub_0 ? lo_pack : recv;
|
||||
const uint32_t out_hi = is_sub_0 ? recv : hi_pack;
|
||||
|
||||
p.thread_buf_[k_base + 0] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 0) & 0xFFu));
|
||||
p.thread_buf_[k_base + 1] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 8) & 0xFFu));
|
||||
p.thread_buf_[k_base + 2] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 16) & 0xFFu));
|
||||
p.thread_buf_[k_base + 3] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 24) & 0xFFu));
|
||||
p.thread_buf_[k_base + 4] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 0) & 0xFFu));
|
||||
p.thread_buf_[k_base + 5] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 8) & 0xFFu));
|
||||
p.thread_buf_[k_base + 6] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 16) & 0xFFu));
|
||||
p.thread_buf_[k_base + 7] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 24) & 0xFFu));
|
||||
});
|
||||
#pragma clang diagnostic pop
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -964,22 +973,52 @@ struct UnifiedAttentionPipeline
|
||||
// future MFMA shape that doesn't fit the
|
||||
// paired-lane swap pattern). ----
|
||||
//
|
||||
// 1. `p_qkc` is a static_distributed_tensor<fp8>
|
||||
// The cvt phase is separated out here (vs fused into
|
||||
// the swap as in branch (A)) because the relayout
|
||||
// travels through LDS, not through `ds_bpermute`, so
|
||||
// there's no swap-latency to hide.
|
||||
//
|
||||
// 1. cvt_pk_fp8_f32 chain into `sp(idx).p.thread_buf_`.
|
||||
// 2. `p_qkc` is a static_distributed_tensor<fp8>
|
||||
// whose distribution metadata says "QK-C
|
||||
// layout". Its register bytes are populated
|
||||
// from `sp(idx).p.thread_buf_`, which is
|
||||
// exactly where the cvt_pk_fp8_f32 chain just
|
||||
// wrote the FP8 bytes (the union has them at
|
||||
// QK-C-layout register offsets).
|
||||
// 2. `store_tile` writes `p_qkc` to LDS at
|
||||
// 3. `store_tile` writes `p_qkc` to LDS at
|
||||
// canonical (M, N) order.
|
||||
// 3. Block-level barrier.
|
||||
// 4. `load_tile` reads from the same LDS region
|
||||
// 4. Block-level barrier.
|
||||
// 5. `load_tile` reads from the same LDS region
|
||||
// with the PV-A distribution.
|
||||
// 5. Copy `p_pva.thread_buf_` back into
|
||||
// 6. Copy `p_pva.thread_buf_` back into
|
||||
// `sp(idx).p` so the gemm_1 call site reads
|
||||
// correctly-laid-out data with no further
|
||||
// changes.
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||
int dummy_old;
|
||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
||||
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
||||
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
||||
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
|
||||
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
|
||||
|
||||
const uint32_t lo =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
||||
const uint32_t packed =
|
||||
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
|
||||
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
|
||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
|
||||
});
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
auto p_qkc = make_static_distributed_tensor<PDataType>(
|
||||
sp(sp_reg_idx).sp_compute.get_tile_distribution());
|
||||
static_assert(
|
||||
|
||||
Reference in New Issue
Block a user