[rocm-libraries] ROCm/rocm-libraries#6051 (commit f0838b2)

[CK] Add FP8 per-tensor quantization support for FMHA V3
 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This
PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950,
enabling higher throughput for
  FP8 inference workloads using the assembly-optimized V3 code path.

  ## Technical Details

  **Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution
(`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

  **V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale
pointers through to pipeline kargs

  **V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
  - Add FP8 data path with dtype-aware type selection
  - Add asm volatile P matrix conversion from f32 to fp8
  - Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy
(`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile
distribution compatibility)

  **Codegen (`fmha_fwd.py`):**
  - Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale:
no/pertensor)
  - Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

  **Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum
(`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`)
pending further validation.

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

  | Problem | Regular (TFlops) | Varlen (TFlops) |
  |---|---:|---:|
  | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
  | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
  | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
  | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
  | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
  | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
  | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
  | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
  | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
  | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering
batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40),
GQA ratios (1:1, 1:8), and
causal/non-causal modes. Verified all cases dispatch to the V3 pipeline
by enabling `F_is_v3_enabled` and confirming kernel names contain
`qr_async_trload_v3`.

  ## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to
V3 pipeline with `pertensor` quantization.

  ## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Po Yen Chen
2026-04-07 14:20:43 +00:00
committed by assistant-librarian[bot]
parent 020b6f435e
commit c2ac7aa7b0
10 changed files with 564 additions and 524 deletions

View File

@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
constexpr auto warp_gemm = [] {
if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
// Use SwizzleB variant to get 8 contiguous K positions per lane,
// matching the V tile distribution for PV GEMM
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
@@ -310,9 +318,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
{
using namespace ck_tile;
@@ -323,7 +330,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
@@ -339,31 +345,28 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
// CRITICAL: Must match Load descriptor merge pattern (NumIssues, LaneGroups, NumWarps)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
make_tuple(make_merge_transform(make_tuple(
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc_issues_warps_lanes;
}
@@ -458,9 +461,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
return max(SingleKSize, SingleVSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor()
{
using namespace ck_tile;
@@ -471,7 +473,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
@@ -487,31 +488,27 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
make_tuple(make_merge_transform(make_tuple(
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc_issues_warps_lanes;
}