[CK] Add FP8 per-tensor quantization support for FMHA V3 pipeline (#6051)

## Motivation

The existing FMHA V3 pipeline only supports fp16/bf16 data types. This
PR extends V3 to handle FP8 inputs with per-tensor descaling on gfx950,
enabling higher throughput for
  FP8 inference workloads using the assembly-optimized V3 code path.

  ## Technical Details

  **Warp GEMM:**
- Add FP8 32x32x32 warp gemm with C-transposed distribution
(`WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed`) and dispatcher entries

  **V3 Kernel (`fmha_fwd_v3_kernel.hpp`):**
- Add per-tensor descale support for Q, K, V tensors, passing descale
pointers through to pipeline kargs

  **V3 Pipeline (`block_fmha_fwd_v3_pipeline.hpp`):**
  - Add FP8 data path with dtype-aware type selection
  - Add asm volatile P matrix conversion from f32 to fp8
  - Add FP8-aware instruction scheduling in `CoreLoopScheduler`

**V3 Pipeline Policy
(`block_fmha_fwd_v3_pipeline_default_policy.hpp`):**
- Add FP8 QK warp gemm selection (SwizzleB variant for V tile
distribution compatibility)

  **Codegen (`fmha_fwd.py`):**
  - Add gfx950 FP8BF16 V3 tile size (256x64x128x128x64x128)
- Add FP8BF16 V3 pipeline variants (mask: no/causal, qscale:
no/pertensor)
  - Extend `can_dispatch_v3` condition for fp8bf16 + pertensor

  **Misc:**
- Add LLVM scheduler `TRANS` mask to `LLVMSchedGroupMask` enum
(`arch.hpp`)
- Fix `mask_info` default initialization for `no_mask` case (`mask.hpp`)

V3 dispatch for FP8 is disabled by default (`F_is_v3_enabled=false`)
pending further validation.

## Performance: fmha_fwd V3 FP8 (avg runs 2-6, stock ROCm 7.1.1, gfx950)

  | Problem | Regular (TFlops) | Varlen (TFlops) |
  |---|---:|---:|
  | batch=1 heads=6/1 seqlen=1024 causal | 48.9 | 47.6 |
  | batch=1 heads=6/1 seqlen=2048 causal | 119.8 | 117.4 |
  | batch=1 heads=6/1 seqlen=4096 causal | 263.7 | 259.2 |
  | batch=1 heads=6/1 seqlen=8192 causal | 548.9 | 543.6 |
  | batch=1 heads=6/1 seqlen=16384 causal | 1043.0 | 1063.7 |
  | batch=1 heads=6/1 seqlen=32768 causal | 1237.2 | 1279.6 |
  | batch=1 heads=6/1 seqlen=65536 causal | 1315.4 | 1382.7 |
  | batch=1 heads=6/1 seqlen=131072 causal | 1326.3 | 1402.2 |
  | batch=1 heads=16/1 seqlen=65536 causal | 1298.7 | 1388.4 |
  | batch=1 heads=40/40 seqlen=37200 non-causal | 1248.9 | 1326.1 |

## Test Plan

Tested with aiter's `test_mha_fp8.py` test suite (176 cases) covering
batch sizes (1-2), sequence lengths (113-4096), head counts (5/8/32/40),
GQA ratios (1:1, 1:8), and
causal/non-causal modes. Verified all cases dispatch to the V3 pipeline
by enabling `F_is_v3_enabled` and confirming kernel names contain
`qr_async_trload_v3`.

  ## Test Result

176/176 tests passed with V3 enabled. All cases correctly dispatched to
V3 pipeline with `pertensor` quantization.

  ## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Po Yen Chen
2026-04-07 22:19:28 +08:00
committed by GitHub
parent 449844e3d3
commit 6dc44114ba
10 changed files with 564 additions and 524 deletions

View File

@@ -22,6 +22,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
* Added microscaling (MX) FP8/FP4 support on gfx950 for FMHA forward kernel ("qr" pipeline only).
* Added FP8 per-tensor quantization support for FMHA forward V3 pipeline on gfx950.
### Changed

View File

@@ -206,22 +206,14 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw
"""
FMHA_FWD_API_FOOTER_TEMPLATE = """
float fmha_fwd(fmha_fwd_traits traits, fmha_fwd_args args, const ck_tile::stream_config& config) {{
const std::string device_name = ck_tile::get_device_name();
const bool is_swa = (traits.mask_type != mask_enum::no_mask) and
((0 < args.window_size_left) or (0 < args.window_size_right));
const bool can_dispatch_v3 =
(device_name.compare(0, 6, "gfx950") == 0) and
(traits.data_type.compare("fp16") == 0 or traits.data_type.compare("bf16") == 0) and
traits.is_v_rowmajor and (traits.bias_type == bias_enum::no_bias) and
(not traits.has_lse) and (not traits.has_dropout) and
(traits.qscale_type == quant_scale_enum::no_scale) and (not is_swa) and
(args.nhead_q % args.nhead_k == 0) and (args.hdim_q == 128) and (args.hdim_v == 128);
if ({F_is_v3_enabled} and can_dispatch_v3) {{
return fmha_fwd_v3(traits, args, config);
}} else {{
return fmha_fwd_v2(traits, args, config);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunreachable-code"
if ({F_is_v3_enabled}) {{
float r = fmha_fwd_v3(traits, args, config);
if (r >= 0) return r;
}}
#pragma clang diagnostic pop
return fmha_fwd_v2(traits, args, config);
}}
"""
@@ -1059,10 +1051,11 @@ class KernelComponentFactoryGfx950(
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
result = KernelComponentFactoryGfx9.get_hdim_tile_size_dict(dtype)
if dtype in cls._DT_FP16_BF16:
# add tile for qr_async_trload_v3
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
# # add tile for qr_async_trload_v3 (bf16/fp16 V3 not ready)
# if (128, 128) in result.keys():
# result[(128, 128)].append(
# FmhaFwdTileSize(256, 32, 128, 128, 32, 128, 8, 1, 1, 8, 1, 1, 32, 32, 16, 32, 32, 16, -1)) # fmt: skip
pass
elif dtype in cls._DT_MXFP8:
return {
# bm0, bn0, bk0, bn1, bk1,
@@ -1075,6 +1068,10 @@ class KernelComponentFactoryGfx950(
(128, 128) : [FmhaFwdTileSize(128, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 32, 32, 64, 32, 32, 64, -1)],
(256, 256) : [FmhaFwdTileSize(128, 128, 128, 256, 128, 256, 4, 1, 1, 4, 1, 1, 16, 16, 128, 16, 16, 128, -1)],
} # fmt: skip
elif dtype in cls._DT_FP8BF16:
if (128, 128) in result.keys():
result[(128, 128)].append(
FmhaFwdTileSize(256, 64, 128, 128, 64, 128, 8, 1, 1, 8, 1, 1, 32, 32, 32, 32, 32, 32, -1)) # fmt: skip
return result
@classmethod
@@ -1105,12 +1102,19 @@ class KernelComponentFactoryGfx950(
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr_async_trload", "row", "f", "f", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "t", sink)) # fmt: skip
# qr_async_trload_v3 only supports hdim=hdim_v=128 for now
if (hdim, hdim_v) == (128, 128):
# qr_async_trload_v3 only supports (generic) causal mask
for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
# # qr_async_trload_v3 bf16/fp16 not ready
# if (hdim, hdim_v) == (128, 128):
# for logits, mask in itertools.product(["t", "f"], ["no", "causal"]):
# pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f",
# F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
elif dtype in cls._DT_FP8BF16:
# qr_async_trload_v3 only supports (generic) causal mask
for logits, qscale, mask in itertools.product(
["t", "f"],
["no", "pertensor"],
["no", "causal"],
):
pipelines.append(FmhaFwdPipeline("qr_async_trload_v3", "row", "t", "t", "f", "f", F_logits=logits, F_bias="no", F_lse="f", F_dropout="f", F_qscale=qscale, F_mask=mask, F_skip="f", F_trload="t", F_sink="f")) # fmt: skip
elif dtype in cls._DT_MXFP8 or dtype in cls._DT_MXFP4:
# no need dropout kernels
@@ -1494,8 +1498,8 @@ def write_fwd_api(
FMHA_FWD_API_FOOTER_TEMPLATE.format(
F_is_v3_enabled=BOOL_MAP[
# NOTE: enable v3 pipelines when ready
# 0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
False
0 < api_pool.get_num_traits(filter_fn=accept_only_v3)
# False
]
),
]

View File

@@ -844,6 +844,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqstart_q_ptr,
@@ -877,6 +880,9 @@ auto fmha_fwd_v3_create_kargs_and_grids(fmha_fwd_args args)
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.q_descale_ptr,
args.k_descale_ptr,
args.v_descale_ptr,
nullptr, // lse_ptr
args.o_ptr,
args.seqlen_q,

View File

@@ -1209,7 +1209,8 @@ enum LLVMSchedGroupMask : int32_t
DS = 1 << 7,
DS_READ = 1 << 8,
DS_WRITE = 1 << 9,
ALL = (DS_WRITE << 1) - 1,
TRANS = 1 << 10,
ALL = (TRANS << 1) - 1,
};
CK_TILE_HOST_DEVICE static constexpr auto get_max_mem_vec_inst_width()

View File

@@ -27,6 +27,8 @@ inline constexpr bool
kattr_no_packed_fp32_ops_v<T, std::void_t<decltype(T::kattr_no_packed_fp32_ops)>> =
T::kattr_no_packed_fp32_ops;
// TODO: rename to something more specific (e.g. kernel_attr_no_packed_fp32) since
// kernel_attr<bool> only controls the no-packed-fp32-ops flag, not a general attribute bag.
template <bool no_packed_fp32_ops>
struct kernel_attr
{
@@ -35,6 +37,32 @@ struct kernel_attr
static constexpr bool kattr_no_packed_fp32_ops = no_packed_fp32_ops;
};
// Compose an architecture tag with kernel attributes.
// Inherits ArchTag for symbol mangling and adds attribute flags.
// kernel_attr_for<gfx950_t> -> gfx950_t (identity)
// kernel_attr_for<gfx950_t, kernel_attr<true>> -> unique type with attribute
namespace detail {
template <typename ArchTag, typename... Attrs>
struct kernel_attr_for_impl : ArchTag, Attrs...
{
};
template <typename ArchTag, typename... Attrs>
struct kernel_attr_for_helper
{
using type = kernel_attr_for_impl<ArchTag, Attrs...>;
};
template <typename ArchTag>
struct kernel_attr_for_helper<ArchTag>
{
using type = ArchTag;
};
} // namespace detail
template <typename ArchTag, typename... Attrs>
using kernel_attr_for = typename detail::kernel_attr_for_helper<ArchTag, Attrs...>::type;
#if CK_TILE_USE_LAUNCH_BOUNDS
#define KENTRY_LAUNCH_BOUNDS __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
#else

View File

@@ -27,6 +27,7 @@ struct FmhaFwdV3Kernel
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using PDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::PDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
@@ -38,6 +39,7 @@ struct FmhaFwdV3Kernel
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
@@ -118,11 +120,21 @@ struct FmhaFwdV3Kernel
float logits_soft_cap_rcp;
};
struct FmhaFwdCommonQScaleKargs
{
const void* q_descale_ptr = nullptr;
const void* k_descale_ptr = nullptr;
const void* v_descale_ptr = nullptr;
};
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -139,7 +151,10 @@ struct FmhaFwdV3Kernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<2>>
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
FmhaFwdCommonQScaleKargs,
FmhaFwdEmptyKargs<2>>,
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<3>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -166,6 +181,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
ck_tile::index_t seqlen_q,
@@ -218,6 +236,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
batch_stride_q,
batch_stride_k,
@@ -237,6 +256,12 @@ struct FmhaFwdV3Kernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -252,6 +277,9 @@ struct FmhaFwdV3Kernel
MakeKargs(const void* q_ptr,
const void* k_ptr,
const void* v_ptr,
const void* q_descale_ptr,
const void* k_descale_ptr,
const void* v_descale_ptr,
void* lse_ptr,
void* o_ptr,
const void* seqstart_q_ptr,
@@ -301,6 +329,7 @@ struct FmhaFwdV3Kernel
nhead_stride_o}, // args for common karg
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for qscale
{}, // placeholder for logits_soft_cap
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
@@ -319,6 +348,12 @@ struct FmhaFwdV3Kernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
kargs.q_descale_ptr = q_descale_ptr;
kargs.k_descale_ptr = k_descale_ptr;
kargs.v_descale_ptr = v_descale_ptr;
}
if constexpr(kHasLogitsSoftCap)
{
kargs.init_logits_soft_cap(logits_soft_cap);
@@ -437,8 +472,19 @@ struct FmhaFwdV3Kernel
{
using namespace ck_tile;
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// Notice: When using double buffering, make sure both buffers are in the same array.
// This prevents the compiler from using separate VGPRs to store the base address
// and enables the use of immediate offsets in load/store instructions.
constexpr auto smem_size_kv =
FmhaPipeline::Policy::template GetSmemSizeKV<typename FmhaPipeline::Problem>();
__shared__ char smem_k[2][smem_size_kv];
__shared__ char smem_v[2][smem_size_kv];
auto* smem_k0 = reinterpret_cast<KDataType*>(smem_k[0]);
auto* smem_k1 = reinterpret_cast<KDataType*>(smem_k[1]);
auto* smem_v0 = reinterpret_cast<VDataType*>(smem_v[0]);
auto* smem_v1 = reinterpret_cast<VDataType*>(smem_v[1]);
;
// divide problem
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
@@ -640,32 +686,88 @@ struct FmhaFwdV3Kernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
const float scale_s = [&] {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
return kargs.scale_s * q_descale * k_descale;
}
else
{
return kargs.scale_s;
}
}();
AttentionVariant variant;
const auto variant_params = [&] {
if constexpr(kHasLogitsSoftCap)
{
return ck_tile::LogitsSoftCapParams<FmhaMask, CK_TILE_FMHA_FWD_FAST_EXP2>{
mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
mask, scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
}
else
{
return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
return ck_tile::StandardAttentionParams<FmhaMask>{mask, scale_s};
}
}();
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
auto o_acc_tile = [&]() {
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
kargs.scale_s,
variant,
variant_params,
block_indices,
smem_ptr);
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
{
float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
float scale_p = ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
float scale_o = v_descale / scale_p;
auto o_acc_element_func = [&]() {
if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
return make_composes(
ck_tile::saturates<ck_tile::fp8_t>{},
ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o});
else
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
}();
return FmhaPipeline{}(
q_dram_window,
identity{}, // q_element_func
k_dram_window,
identity{}, // k_element_func
v_dram_window,
identity{}, // v_element_func
lse_dram_window,
identity{}, // lse_element_func
identity{}, // s_acc_element_func
scales<remove_cvref_t<decltype(scale_p)>>{scale_p}, // p_compute_element_func
o_acc_element_func,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_k0,
smem_k1,
smem_v0,
smem_v1);
}
else
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
lse_dram_window,
mask,
scale_s,
variant,
variant_params,
block_indices,
smem_k0,
smem_k1,
smem_v0,
smem_v1);
}
}();
// O DRAM and O DRAM window

View File

@@ -239,10 +239,18 @@ struct BlockFmhaV3PipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
constexpr auto warp_gemm = [] {
if constexpr(std::is_same_v<typename Problem::QDataType, fp8_t> &&
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
// Use SwizzleB variant to get 8 contiguous K positions per lane,
// matching the V tile distribution for PV GEMM
return WarpGemmMfmaFp8Fp8F32M32N32K32SwizzleBTransposedCDistribution<>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
/// NOTICE: in order to use load_tile_transpose() later for V tile, we cannot use
/// WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution here
@@ -310,9 +318,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
static constexpr ck_tile::index_t kKLdsPadInBytes = 4 * 4; // 4 dwords
static constexpr ck_tile::index_t kVLdsPadInBytes = 4 * 16; // 16 dwords
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeKLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKLdsStoreBlockDescriptor()
{
using namespace ck_tile;
@@ -323,7 +330,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemKPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentK<Problem>(); // this is for global load
constexpr index_t kPad =
kKLdsPadInBytes /
@@ -339,31 +345,28 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<IBuf * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
// CRITICAL: Must match Load descriptor merge pattern (NumIssues, LaneGroups, NumWarps)
constexpr auto k_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
make_tuple(make_merge_transform(make_tuple(
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc_issues_warps_lanes;
}
@@ -458,9 +461,8 @@ struct BlockFmhaV3PipelineDefaultPolicy
return max(SingleKSize, SingleVSize);
}
template <typename Problem, ck_tile::index_t IBuf = 0>
CK_TILE_DEVICE static constexpr auto
MakeVLdsStoreBlockDescriptor(ck_tile::number<IBuf> = ck_tile::number<0>{})
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVLdsStoreBlockDescriptor()
{
using namespace ck_tile;
@@ -471,7 +473,6 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumWarps = Problem::BlockFmhaShape::NumWarps;
constexpr index_t WarpSize = ck_tile::get_warp_size();
[[maybe_unused]] constexpr index_t KPack = GetSmemVPackK<Problem>(); // this is for lds
constexpr index_t KVector = GetAlignmentV<Problem>(); // this is for global load
constexpr index_t kPad =
kVLdsPadInBytes /
@@ -487,31 +488,27 @@ struct BlockFmhaV3PipelineDefaultPolicy
constexpr index_t NumIssues = kNPerBlock / (LaneGroups * NumWarps);
static_assert(NumIssues == kNPerBlock * kKPerBlock / (kBlockSize * KVector));
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor_with_offset(
make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<(IBuf + 2) * GetSingleSmemElementSpaceSize<Problem>()>{},
number<KVector>{},
number<1>{});
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumIssues>{}, // n0
number<LaneGroups>{}, // n1
number<NumWarps>{}, // n2
number<LanesPerK>{}, // k0
number<KVector>{}), // k1
make_tuple(number<NumWarps*(WarpSize * KVector + kPad)>{},
number<kKPerBlock>{},
number<WarpSize * KVector + kPad>{},
number<KVector>{},
number<1>{}),
number<KVector>{},
number<1>{});
// TODO this layout is hard coded, and will be used in async copy buffer view load
// in LDS the real layout is (bufs, N0, N2, N1*K0*K1)
constexpr auto v_lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<NumIssues>{}),
make_pass_through_transform(number<NumWarps>{}),
make_merge_transform(make_tuple(
number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
make_tuple(make_merge_transform(make_tuple(
number<NumIssues>{}, number<LaneGroups>{}, number<NumWarps>{})),
make_merge_transform(make_tuple(number<LanesPerK>{}, number<KVector>{}))),
make_tuple(sequence<0, 1, 2>{}, sequence<3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc_issues_warps_lanes;
}

View File

@@ -369,6 +369,13 @@ using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>,
2,
AttrNumAccess>>;
using WarpGemmMfma_f32_32x32x32_fp8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>,
2>>;

View File

@@ -170,6 +170,8 @@ template<WGAttrNumAccessEnum I> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32,
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8<EDouble>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<>; };
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, true, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8_CTransposed<EDouble>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false, false, false, EDouble> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8<EDouble>; };