diff --git a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp index 9c7cebfce0..9fbd2c4a40 100644 --- a/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp +++ b/include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp @@ -833,29 +833,6 @@ struct UnifiedAttentionPipeline "fp8 P conversion expects packs of 4 fp32 lanes per " "thread; widen the warp gemm M distribution if this " "trips."); -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wuninitialized" - int dummy_old; - static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) { - const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]); - const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); - const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]); - const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]); - - uint32_t lo = - __builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false); - uint32_t packed = - __builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true); - sp(sp_reg_idx).p.thread_buf_[idx + 0] = - bit_cast(static_cast((packed >> 0) & 0xFFu)); - sp(sp_reg_idx).p.thread_buf_[idx + 1] = - bit_cast(static_cast((packed >> 8) & 0xFFu)); - sp(sp_reg_idx).p.thread_buf_[idx + 2] = - bit_cast(static_cast((packed >> 16) & 0xFFu)); - sp(sp_reg_idx).p.thread_buf_[idx + 3] = - bit_cast(static_cast((packed >> 24) & 0xFFu)); - }); -#pragma clang diagnostic pop // --------------------------------------------------------- // FP8 P-tile QK-C -> PV-A re-layout. @@ -900,22 +877,33 @@ struct UnifiedAttentionPipeline // tiny-decode tier where (A) doesn't apply. This // keeps the previously-tuned 32x32x16 perf intact // while enabling FP8 on the m16 tier. 
+ // + // For strategy (A) the cvt and the cross-lane swap are + // fused into a single 8-fp8-per-iter loop so that the + // ds_bpermute_b32 latency overlaps with subsequent + // cvt_pk_fp8_f32 calls (instead of running serially + // after the whole cvt phase finishes). using PVWarpTile = typename UnifiedAttentionShape::Gemm1WarpTile; if constexpr(PVWarpTile::at(number<0>{}) == 32 && PVWarpTile::at(number<1>{}) == 32 && PVWarpTile::at(number<2>{}) == 16) { - // ---- (A) Cross-lane in-register swap (32x32x16). ---- + // ---- (A) Fused cvt + cross-lane swap (32x32x16). ---- // - // For each 8-fp8 PV K-chunk (one warp-gemm K iter), - // the slot decomposition is: + // Per 8-fp8 K-chunk: + // 1. cvt 8 fp32 -> 2 packed uint32 (lo_pack = slot[0..3], + // hi_pack = slot[4..7]) using the chained-`old` pattern + // that matches `cast_tile_pk_fp8_fp32`. + // 2. ds_bpermute the "bad" pack to the paired lane (lane^32). + // 3. Write back both packs as 8 fp8 bytes: each lane keeps + // its "good" pack and stores the pack received from the + // paired lane in place of its own "bad" pack. + // + // Slot decomposition (per fmha_alu1 doc above): // sub=0 | slot[0..3] | N=0..3 | K=0..3 OK // sub=0 | slot[4..7] | N=8..11 | K=4..7 BAD // sub=1 | slot[0..3] | N=4..7 | K=8..11 BAD // sub=1 | slot[4..7] | N=12..15 | K=12..15 OK - // Each sub's *bad* 4-fp8 chunk holds exactly the data - // the *paired* sub (lane ^ 32) needs for its bad - // slot. One ds_bpermute_b32 per K-chunk fixes it. 
static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 8 == 0, "FP8 32x32x16 + Single cross-lane permute " "expects PV per-thread buffer in chunks of 8 " @@ -925,38 +913,59 @@ struct UnifiedAttentionPipeline const int paired_addr = (lane_id ^ 32) << 2; // bytes const bool is_sub_0 = (lane_id & 32) == 0; - auto pack4 = [](fp8_t a, fp8_t b, fp8_t c, fp8_t d) -> uint32_t { - return (static_cast(bit_cast(a)) << 0) | - (static_cast(bit_cast(b)) << 8) | - (static_cast(bit_cast(c)) << 16) | - (static_cast(bit_cast(d)) << 24); - }; - auto unpack4 = [](uint32_t v, fp8_t& a, fp8_t& b, fp8_t& c, fp8_t& d) { - a = bit_cast(static_cast((v >> 0) & 0xFFu)); - b = bit_cast(static_cast((v >> 8) & 0xFFu)); - c = bit_cast(static_cast((v >> 16) & 0xFFu)); - d = bit_cast(static_cast((v >> 24) & 0xFFu)); - }; - +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + int dummy_old; static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 8>{}([&](auto k_base) { - auto& p = sp(sp_reg_idx).p; - const uint32_t own_bad = - is_sub_0 - ? 
pack4(p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5], - p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7]) - : pack4(p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1], - p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]); - const uint32_t recv = - __builtin_amdgcn_ds_bpermute(paired_addr, static_cast(own_bad)); - if(is_sub_0) - unpack4(recv, - p.thread_buf_[k_base + 4], p.thread_buf_[k_base + 5], - p.thread_buf_[k_base + 6], p.thread_buf_[k_base + 7]); - else - unpack4(recv, - p.thread_buf_[k_base + 0], p.thread_buf_[k_base + 1], - p.thread_buf_[k_base + 2], p.thread_buf_[k_base + 3]); + auto& p = sp(sp_reg_idx).p; + auto& sc = sp(sp_reg_idx).sp_compute; + + const float a = p_compute_element_func(sc.thread_buf_[k_base + 0]); + const float b = p_compute_element_func(sc.thread_buf_[k_base + 1]); + const float c = p_compute_element_func(sc.thread_buf_[k_base + 2]); + const float d = p_compute_element_func(sc.thread_buf_[k_base + 3]); + const float e = p_compute_element_func(sc.thread_buf_[k_base + 4]); + const float f = p_compute_element_func(sc.thread_buf_[k_base + 5]); + const float g = p_compute_element_func(sc.thread_buf_[k_base + 6]); + const float h = p_compute_element_func(sc.thread_buf_[k_base + 7]); + + const uint32_t lo_tmp = + __builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false); + const uint32_t lo_pack = + __builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo_tmp, /*hi=*/true); + const uint32_t hi_tmp = + __builtin_amdgcn_cvt_pk_fp8_f32(e, f, dummy_old, /*hi=*/false); + const uint32_t hi_pack = + __builtin_amdgcn_cvt_pk_fp8_f32(g, h, hi_tmp, /*hi=*/true); + + // Issue ds_bpermute as early as possible so its cross-lane + // latency can overlap with independent work, such as the + // next K-chunk's cvts once this loop is unrolled. + const uint32_t own_bad = is_sub_0 ? hi_pack : lo_pack; + const uint32_t recv = __builtin_amdgcn_ds_bpermute( + paired_addr, static_cast(own_bad)); + + const uint32_t out_lo = is_sub_0 ? 
lo_pack : recv; + const uint32_t out_hi = is_sub_0 ? recv : hi_pack; + + p.thread_buf_[k_base + 0] = + bit_cast(static_cast((out_lo >> 0) & 0xFFu)); + p.thread_buf_[k_base + 1] = + bit_cast(static_cast((out_lo >> 8) & 0xFFu)); + p.thread_buf_[k_base + 2] = + bit_cast(static_cast((out_lo >> 16) & 0xFFu)); + p.thread_buf_[k_base + 3] = + bit_cast(static_cast((out_lo >> 24) & 0xFFu)); + p.thread_buf_[k_base + 4] = + bit_cast(static_cast((out_hi >> 0) & 0xFFu)); + p.thread_buf_[k_base + 5] = + bit_cast(static_cast((out_hi >> 8) & 0xFFu)); + p.thread_buf_[k_base + 6] = + bit_cast(static_cast((out_hi >> 16) & 0xFFu)); + p.thread_buf_[k_base + 7] = + bit_cast(static_cast((out_hi >> 24) & 0xFFu)); }); +#pragma clang diagnostic pop } else { @@ -964,22 +973,52 @@ struct UnifiedAttentionPipeline // future MFMA shape that doesn't fit the // paired-lane swap pattern). ---- // - // 1. `p_qkc` is a static_distributed_tensor + // The cvt phase is separated out here (vs fused into + // the swap as in branch (A)) because the relayout + // travels through LDS, not through `ds_bpermute`, so + // there's no swap-latency to hide. + // + // 1. cvt_pk_fp8_f32 chain into `sp(idx).p.thread_buf_`. + // 2. `p_qkc` is a static_distributed_tensor // whose distribution metadata says "QK-C // layout". Its register bytes are populated // from `sp(idx).p.thread_buf_`, which is // exactly where the cvt_pk_fp8_f32 chain just // wrote the FP8 bytes (the union has them at // QK-C-layout register offsets). - // 2. `store_tile` writes `p_qkc` to LDS at + // 3. `store_tile` writes `p_qkc` to LDS at // canonical (M, N) order. - // 3. Block-level barrier. - // 4. `load_tile` reads from the same LDS region + // 4. Block-level barrier. + // 5. `load_tile` reads from the same LDS region // with the PV-A distribution. - // 5. Copy `p_pva.thread_buf_` back into + // 6. Copy `p_pva.thread_buf_` back into // `sp(idx).p` so the gemm_1 call site reads // correctly-laid-out data with no further // changes. 
+#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wuninitialized" + int dummy_old; + static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 4>{}([&](auto idx) { + const float a = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 0]); + const float b = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]); + const float c = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 2]); + const float d = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 3]); + + const uint32_t lo = + __builtin_amdgcn_cvt_pk_fp8_f32(a, b, dummy_old, /*hi=*/false); + const uint32_t packed = + __builtin_amdgcn_cvt_pk_fp8_f32(c, d, lo, /*hi=*/true); + sp(sp_reg_idx).p.thread_buf_[idx + 0] = + bit_cast(static_cast((packed >> 0) & 0xFFu)); + sp(sp_reg_idx).p.thread_buf_[idx + 1] = + bit_cast(static_cast((packed >> 8) & 0xFFu)); + sp(sp_reg_idx).p.thread_buf_[idx + 2] = + bit_cast(static_cast((packed >> 16) & 0xFFu)); + sp(sp_reg_idx).p.thread_buf_[idx + 3] = + bit_cast(static_cast((packed >> 24) & 0xFFu)); + }); +#pragma clang diagnostic pop + auto p_qkc = make_static_distributed_tensor( sp(sp_reg_idx).sp_compute.get_tile_distribution()); static_assert(