[rocm-libraries] ROCm/rocm-libraries#6051 (commit f0838b2)

[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This
PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950,
enabling higher throughput for FP8 inference workloads using the
assembly-optimized V3 code path.

## Technical Details

**Warp GEMM:**
- Add FP8 32x32x32 warp GEMM with C-transposed distribution
(`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries;
the selection is sketched below
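
As a rough illustration only (every name below except
`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed` is hypothetical, not the
actual CK dispatcher API), the new entry conceptually maps FP8 A/B operands
and a 32x32x32 warp tile with a transposed C distribution to the new warp
GEMM type:

```cpp
// Hypothetical, simplified dispatcher; only the FP8 warp-gemm type name
// below comes from this PR, everything else is illustrative.
struct fp8_t {}; // stand-in for ck_tile::fp8_t
struct WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed {};

template <typename ABType, int M, int N, int K, bool TransposedC>
struct WarpGemmDispatcherSketch; // primary template left undefined

// New entry: FP8 A/B operands, f32 accumulator, 32x32x32 warp tile,
// C distribution transposed so the result can feed the P x V gemm directly.
template <>
struct WarpGemmDispatcherSketch<fp8_t, 32, 32, 32, true>
{
    using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed;
};
```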

**V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale
pointers through to pipeline kargs; the descale folding is sketched below
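
As a rough host-side sketch (not the kernel code itself; the helper name is
illustrative), the per-tensor Q/K descales fold into the softmax scale exactly
as in the `scale_s` computation visible in the kernel diff further down:

```cpp
// Host-side illustration of the scale folding done in the V3 kernel:
//   S = (Q_fp8 * q_descale) x (K_fp8 * k_descale)^T * scale_s
//     = (Q_fp8 x K_fp8^T) * (scale_s * q_descale * k_descale)
// so both per-tensor descales can be pre-multiplied into a single scalar.
inline float fold_qk_descale(float scale_s,   // base softmax scale, e.g. 1/sqrt(head_dim)
                             float q_descale, // value read from q_descale_ptr
                             float k_descale) // value read from k_descale_ptr
{
    return scale_s * q_descale * k_descale;
}
```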

**V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
- Add FP8 data path with dtype-aware type selection
- Add asm volatile P-matrix conversion from f32 to fp8; the surrounding
quantization/compensation scales are sketched below
- Add FP8-aware instruction scheduling in `CoreLoopScheduler`
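
A minimal sketch (illustrative C++, not the pipeline code; the struct and
function names are made up) of the scale pair around the f32-to-fp8 P
conversion, mirroring the `scale_p`/`scale_o` logic in the kernel diff below:

```cpp
// Illustrative only; mirrors the scale_p / scale_o computation in
// fmha_fwd_v3_kernel.hpp. For e4m3 fp8, fp8_max is 448.
struct PQuantScales
{
    float scale_p; // multiply P (in [0, 1]) by this before the f32 -> fp8 convert
    float scale_o; // multiply the O accumulator by this afterwards
};

inline PQuantScales make_p_quant_scales(float v_descale, // value read from v_descale_ptr
                                        float fp8_max)   // numeric<PDataType>::max()
{
    // O = P x (V_fp8 * v_descale) = (P * fp8_max) x V_fp8 * (v_descale / fp8_max),
    // so quantizing P with the full fp8 range preserves the result as long as
    // the accumulator is rescaled by v_descale / fp8_max.
    return {fp8_max, v_descale / fp8_max};
}
```

The intent, as far as the diff shows, is to use the full fp8 dynamic range for
the softmax probabilities and fold the compensation into the O-accumulator
scaling together with the V per-tensor descale.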

**V3 Pipeline Policy (`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp GEMM selection (SwizzleB variant for V tile
distribution compatibility)

**Codegen (`fmha_fwd.py`):**
- Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale: no/pertensor)
- Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

**Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum (`arch.hpp`)
- Fix `mask_info` default initialization for the `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`)
pending further validation.

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

| Problem | Regular (TFlops) | Varlen (TFlops) |
|---|---:|---:|
| batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
| batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
| batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
| batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
| batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
| batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
| batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
| batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
| batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
| batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering
batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40),
GQA ratios (1:1, 1:8), and causal/non-causal modes. Verified all cases
dispatch to the V3 pipeline by enabling `F_is_v3_enabled` and confirming
kernel names contain `qr_async_trload_v3`.

## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to
the V3 pipeline with `pertensor` quantization.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit c2ac7aa7b0 (parent 020b6f435e), authored by Po Yen Chen on
2026-04-07 14:20:43 +00:00, committed by assistant-librarian[bot].
10 changed files with 564 additions and 524 deletions.

Diff excerpt from `fmha_fwd_v3_kernel.hpp`:

@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
float logits_soft_cap_rcp;
};
struct FmhaFwdCommonQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
@@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
@@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
@@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -437,8 +472,19 @@ struct FmhaFwdV3Kernel
{
using namespace ck_tile;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// Notice: When using double buffering, make sure both buffers are in the same array.
// This prevents the compiler from using separate VGPRs to store the base address
// and enables the use of immediate offsets in load/store instructions.
constexpr auto smem_size_kv =
FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>();
__shared__ char smem_k[2][smem_size_kv];
__shared__ char smem_v[2][smem_size_kv];
auto* smem_k0 = reinterpret_cast<KDataType*>(smem_k[0]);
auto* smem_k1 = reinterpret_cast<KDataType*>(smem_k[1]);
auto* smem_v0 = reinterpret_cast<VDataType*>(smem_v[0]);
auto* smem_v1 = reinterpret_cast<VDataType*>(smem_v[1]);
;
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
@@ -640,32 +686,88 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
return kargs.scale_s * q_descale * k_descale;
}
else
{
return kargs.scale_s;
}
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
float scale_o = v_descale / scale_p;
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return make_composes(
ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
o_acc_element_func,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_k0,
smem_k1,
smem_v0,
smem_v1);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_k0,
smem_k1,
smem_v0,
smem_v1);
}
}();
// O DRAM and O DRAM window