mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
CK-UA: fuse FP8 cvt + cross-lane swap to hide ds_bpermute latency
Previously the 32x32x16 FP8 P-tile cvt and the QK-C -> PV-A cross-lane
swap ran in two separate static_for loops back-to-back inside fmha_alu1:
the whole tile was cvt'd into p.thread_buf_ first, then a second pass
issued one ds_bpermute_b32 per 8-fp8 K-chunk and read/wrote the same
buffer to swap the "bad" 4-byte halves between paired lanes.
ds_bpermute_b32 is routed through the LDS hardware (no LDS memory is
actually read or written), so it has nontrivial latency that the
scheduler has no way to hide when it lives alone in a tight serial loop
with the pack4/unpack4 gather/scatter around it.
Fuse the two into one 8-fp8-per-iter loop:
1. cvt 8 fp32 -> 2 packed uint32 (lo_pack=slot[0..3], hi_pack=slot[4..7])
using the chained cvt_pk_fp8_f32 pattern matching cast_tile_pk_fp8_fp32.
2. Pick own_bad = (sub==0 ? hi_pack : lo_pack) and issue ds_bpermute on it.
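3. Write back all 8 fp8 bytes in slot order (out_lo = sub==0 ? lo_pack : recv,
then out_hi = sub==0 ? recv : hi_pack); on sub==0 lanes the "good" half
lands first so its byte stores can overlap with the in-flight ds_bpermute,
and the next iter's cvts can begin while the swap is still pending.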
The 16x16x32 LDS-roundtrip branch keeps the original separated cvt
loop (no swap latency to hide there since the relayout goes through
LDS, not ds_bpermute).
Single-shape FP8 perf on gfx950 GPU 2 (CUDA graph, 50 iters):
decode d=128 b=4 sq=8 sk=4096: 0.2106 -> 0.1951 ms (-7.4%)
decode d=64 b=4 sq=8 sk=4096: 0.1464 -> 0.1208 ms (-17.5%)
prefill d=128 b=2 sq=512 sk=4k: 0.2558 -> 0.2220 ms (-13.2%)
BF16 unchanged (0.2046 -> 0.2039 ms, within noise).
Correctness: pytest UA correctness suite 405 passed / 80 skipped
(245 BF16/FP16 + 160 FP8), unchanged from before.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -833,29 +833,6 @@ struct UnifiedAttentionPipeline
|
|||||||
"fp8 P conversion expects packs of 4 fp32 lanes per "
|
"fp8 P conversion expects packs of 4 fp32 lanes per "
|
||||||
"thread; widen the warp gemm M distribution if this "
|
"thread; widen the warp gemm M distribution if this "
|
||||||
"trips.");
|
"trips.");
|
||||||
#pragma clang diagnostic push
|
|
||||||
#pragma clang diagnostic ignored "-Wuninitialized"
|
|
||||||
int dummy_old;
|
|
||||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
|
||||||
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
|
||||||
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
|
||||||
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
|
|
||||||
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
|
|
||||||
|
|
||||||
uint32_t lo =
|
|
||||||
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
|
||||||
uint32_t packed =
|
|
||||||
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
|
|
||||||
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
|
|
||||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
|
|
||||||
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
|
|
||||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
|
|
||||||
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
|
|
||||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
|
|
||||||
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
|
|
||||||
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
|
|
||||||
});
|
|
||||||
#pragma clang diagnostic pop
|
|
||||||
|
|
||||||
// ---------------------------------------------------------
|
// ---------------------------------------------------------
|
||||||
// FP8 P-tile QK-C -> PV-A re-layout.
|
// FP8 P-tile QK-C -> PV-A re-layout.
|
||||||
@@ -900,22 +877,33 @@ struct UnifiedAttentionPipeline
|
|||||||
// tiny-decode tier where (A) doesn't apply. This
|
// tiny-decode tier where (A) doesn't apply. This
|
||||||
// keeps the previously-tuned 32x32x16 perf intact
|
// keeps the previously-tuned 32x32x16 perf intact
|
||||||
// while enabling FP8 on the m16 tier.
|
// while enabling FP8 on the m16 tier.
|
||||||
|
//
|
||||||
|
// For strategy (A) the cvt and the cross-lane swap are
|
||||||
|
// fused into a single 8-fp8-per-iter loop so that the
|
||||||
|
// ds_bpermute_b32 latency overlaps with subsequent
|
||||||
|
// cvt_pk_fp8_f32 calls (instead of running serially
|
||||||
|
// after the whole cvt phase finishes).
|
||||||
using PVWarpTile = typename UnifiedAttentionShape::Gemm1WarpTile;
|
using PVWarpTile = typename UnifiedAttentionShape::Gemm1WarpTile;
|
||||||
if constexpr(PVWarpTile::at(number<0>{}) == 32 &&
|
if constexpr(PVWarpTile::at(number<0>{}) == 32 &&
|
||||||
PVWarpTile::at(number<1>{}) == 32 &&
|
PVWarpTile::at(number<1>{}) == 32 &&
|
||||||
PVWarpTile::at(number<2>{}) == 16)
|
PVWarpTile::at(number<2>{}) == 16)
|
||||||
{
|
{
|
||||||
// ---- (A) Cross-lane in-register swap (32x32x16). ----
|
// ---- (A) Fused cvt + cross-lane swap (32x32x16). ----
|
||||||
//
|
//
|
||||||
// For each 8-fp8 PV K-chunk (one warp-gemm K iter),
|
// Per 8-fp8 K-chunk:
|
||||||
// the slot decomposition is:
|
// 1. cvt 8 fp32 -> 2 packed uint32 (lo_pack = slot[0..3],
|
||||||
|
// hi_pack = slot[4..7]) using the chained-`old` pattern
|
||||||
|
// that matches `cast_tile_pk_fp8_fp32`.
|
||||||
|
// 2. ds_bpermute the "bad" pack to the paired lane (lane^32).
|
||||||
|
// 3. Write back both packs as 8 fp8 bytes in slot order; on
|
||||||
|
// sub=0 lanes the "good" half (out_lo) is written first so its
|
||||||
|
// byte stores can overlap with the ds_bpermute latency.
|
||||||
|
//
|
||||||
|
// Slot decomposition (per fmha_alu1 doc above):
|
||||||
// sub=0 | slot[0..3] | N=0..3 | K=0..3 OK
|
// sub=0 | slot[0..3] | N=0..3 | K=0..3 OK
|
||||||
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
|
// sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD
|
||||||
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
|
// sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD
|
||||||
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
|
// sub=1 | slot[4..7] | N=12..15 | K=12..15 OK
|
||||||
// Each sub's *bad* 4-fp8 chunk holds exactly the data
|
|
||||||
// the *paired* sub (lane ^ 32) needs for its bad
|
|
||||||
// slot. One ds_bpermute_b32 per K-chunk fixes it.
|
|
||||||
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
|
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0,
|
||||||
"FP8 32x32x16 + Single cross-lane permute "
|
"FP8 32x32x16 + Single cross-lane permute "
|
||||||
"expects PV per-thread buffer in chunks of 8 "
|
"expects PV per-thread buffer in chunks of 8 "
|
||||||
@@ -925,38 +913,59 @@ struct UnifiedAttentionPipeline
|
|||||||
const int paired_addr = (lane_id ^ 32) << 2; // bytes
|
const int paired_addr = (lane_id ^ 32) << 2; // bytes
|
||||||
const bool is_sub_0 = (lane_id & 32) == 0;
|
const bool is_sub_0 = (lane_id & 32) == 0;
|
||||||
|
|
||||||
auto pack4 = [](fp8_t a, fp8_t b, fp8_t c, fp8_t d) -> uint32_t {
|
#pragma clang diagnostic push
|
||||||
return (static_cast<uint32_t>(bit_cast<fp8_raw_t>(a)) << 0) |
|
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(b)) << 8) |
|
int dummy_old;
|
||||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(c)) << 16) |
|
|
||||||
(static_cast<uint32_t>(bit_cast<fp8_raw_t>(d)) << 24);
|
|
||||||
};
|
|
||||||
auto unpack4 = [](uint32_t v, fp8_t& a, fp8_t& b, fp8_t& c, fp8_t& d) {
|
|
||||||
a = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 0) & 0xFFu));
|
|
||||||
b = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 8) & 0xFFu));
|
|
||||||
c = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 16) & 0xFFu));
|
|
||||||
d = bit_cast<fp8_t>(static_cast<fp8_raw_t>((v >> 24) & 0xFFu));
|
|
||||||
};
|
|
||||||
|
|
||||||
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
|
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) {
|
||||||
auto& p = sp(sp_reg_idx).p;
|
auto& p = sp(sp_reg_idx).p;
|
||||||
const uint32_t own_bad =
|
auto& sc = sp(sp_reg_idx).sp_compute;
|
||||||
is_sub_0
|
|
||||||
? pack4(p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
|
const float a = p_compute_element_func(sc.thread_buf_[k_base + 0]);
|
||||||
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7])
|
const float b = p_compute_element_func(sc.thread_buf_[k_base + 1]);
|
||||||
: pack4(p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
|
const float c = p_compute_element_func(sc.thread_buf_[k_base + 2]);
|
||||||
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
|
const float d = p_compute_element_func(sc.thread_buf_[k_base + 3]);
|
||||||
const uint32_t recv =
|
const float e = p_compute_element_func(sc.thread_buf_[k_base + 4]);
|
||||||
__builtin_amdgcn_ds_bpermute(paired_addr, static_cast<int>(own_bad));
|
const float f = p_compute_element_func(sc.thread_buf_[k_base + 5]);
|
||||||
if(is_sub_0)
|
const float g = p_compute_element_func(sc.thread_buf_[k_base + 6]);
|
||||||
unpack4(recv,
|
const float h = p_compute_element_func(sc.thread_buf_[k_base + 7]);
|
||||||
p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5],
|
|
||||||
p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7]);
|
const uint32_t lo_tmp =
|
||||||
else
|
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
||||||
unpack4(recv,
|
const uint32_t lo_pack =
|
||||||
p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1],
|
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo_tmp, /*hi=*/true);
|
||||||
p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]);
|
const uint32_t hi_tmp =
|
||||||
|
__builtin_amdgcn_cvt_pk_fp8_f32(e, f, dummy_old, /*hi=*/false);
|
||||||
|
const uint32_t hi_pack =
|
||||||
|
__builtin_amdgcn_cvt_pk_fp8_f32(g, h, hi_tmp, /*hi=*/true);
|
||||||
|
|
||||||
|
// Issue ds_bpermute as early as possible so its LDS-DMA
|
||||||
|
// latency overlaps with the byte writes below (and with
|
||||||
|
// the next K-chunk's cvts after this iter unrolls).
|
||||||
|
const uint32_t own_bad = is_sub_0 ? hi_pack : lo_pack;
|
||||||
|
const uint32_t recv = __builtin_amdgcn_ds_bpermute(
|
||||||
|
paired_addr, static_cast<int>(own_bad));
|
||||||
|
|
||||||
|
const uint32_t out_lo = is_sub_0 ? lo_pack : recv;
|
||||||
|
const uint32_t out_hi = is_sub_0 ? recv : hi_pack;
|
||||||
|
|
||||||
|
p.thread_buf_[k_base + 0] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 0) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 1] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 8) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 2] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 16) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 3] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_lo >> 24) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 4] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 0) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 5] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 8) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 6] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 16) & 0xFFu));
|
||||||
|
p.thread_buf_[k_base + 7] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((out_hi >> 24) & 0xFFu));
|
||||||
});
|
});
|
||||||
|
#pragma clang diagnostic pop
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
@@ -964,22 +973,52 @@ struct UnifiedAttentionPipeline
|
|||||||
// future MFMA shape that doesn't fit the
|
// future MFMA shape that doesn't fit the
|
||||||
// paired-lane swap pattern). ----
|
// paired-lane swap pattern). ----
|
||||||
//
|
//
|
||||||
// 1. `p_qkc` is a static_distributed_tensor<fp8>
|
// The cvt phase is separated out here (vs fused into
|
||||||
|
// the swap as in branch (A)) because the relayout
|
||||||
|
// travels through LDS, not through `ds_bpermute`, so
|
||||||
|
// there's no swap-latency to hide.
|
||||||
|
//
|
||||||
|
// 1. cvt_pk_fp8_f32 chain into `sp(idx).p.thread_buf_`.
|
||||||
|
// 2. `p_qkc` is a static_distributed_tensor<fp8>
|
||||||
// whose distribution metadata says "QK-C
|
// whose distribution metadata says "QK-C
|
||||||
// layout". Its register bytes are populated
|
// layout". Its register bytes are populated
|
||||||
// from `sp(idx).p.thread_buf_`, which is
|
// from `sp(idx).p.thread_buf_`, which is
|
||||||
// exactly where the cvt_pk_fp8_f32 chain just
|
// exactly where the cvt_pk_fp8_f32 chain just
|
||||||
// wrote the FP8 bytes (the union has them at
|
// wrote the FP8 bytes (the union has them at
|
||||||
// QK-C-layout register offsets).
|
// QK-C-layout register offsets).
|
||||||
// 2. `store_tile` writes `p_qkc` to LDS at
|
// 3. `store_tile` writes `p_qkc` to LDS at
|
||||||
// canonical (M, N) order.
|
// canonical (M, N) order.
|
||||||
// 3. Block-level barrier.
|
// 4. Block-level barrier.
|
||||||
// 4. `load_tile` reads from the same LDS region
|
// 5. `load_tile` reads from the same LDS region
|
||||||
// with the PV-A distribution.
|
// with the PV-A distribution.
|
||||||
// 5. Copy `p_pva.thread_buf_` back into
|
// 6. Copy `p_pva.thread_buf_` back into
|
||||||
// `sp(idx).p` so the gemm_1 call site reads
|
// `sp(idx).p` so the gemm_1 call site reads
|
||||||
// correctly-laid-out data with no further
|
// correctly-laid-out data with no further
|
||||||
// changes.
|
// changes.
|
||||||
|
#pragma clang diagnostic push
|
||||||
|
#pragma clang diagnostic ignored "-Wuninitialized"
|
||||||
|
int dummy_old;
|
||||||
|
static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) {
|
||||||
|
const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]);
|
||||||
|
const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
|
||||||
|
const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]);
|
||||||
|
const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]);
|
||||||
|
|
||||||
|
const uint32_t lo =
|
||||||
|
__builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false);
|
||||||
|
const uint32_t packed =
|
||||||
|
__builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true);
|
||||||
|
sp(sp_reg_idx).p.thread_buf_[idx + 0] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 0) & 0xFFu));
|
||||||
|
sp(sp_reg_idx).p.thread_buf_[idx + 1] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 8) & 0xFFu));
|
||||||
|
sp(sp_reg_idx).p.thread_buf_[idx + 2] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 16) & 0xFFu));
|
||||||
|
sp(sp_reg_idx).p.thread_buf_[idx + 3] =
|
||||||
|
bit_cast<fp8_t>(static_cast<fp8_raw_t>((packed >> 24) & 0xFFu));
|
||||||
|
});
|
||||||
|
#pragma clang diagnostic pop
|
||||||
|
|
||||||
auto p_qkc = make_static_distributed_tensor<PDataType>(
|
auto p_qkc = make_static_distributed_tensor<PDataType>(
|
||||||
sp(sp_reg_idx).sp_compute.get_tile_distribution());
|
sp(sp_reg_idx).sp_compute.get_tile_distribution());
|
||||||
static_assert(
|
static_assert(
|
||||||
|
|||||||
Reference in New Issue
Block a user