[CK_TILE] Skip padded k/n fragment work in qr_hpad FMHA fwd (#6450)

## Motivation

`qr_hpad` currently executes full work for padded head-dim fragments even
when only a subset of each fragment's values is valid. This adds
unnecessary computation for head dimensions that require padding, such as
`hdim=72` and `hdim=80`, and hurts FMHA forward performance.

The goal of this PR is to make the padded-head-dim path skip invalid
work based on the actual valid fragment count, while preserving the
existing behavior for the non-padded path.

## Technical Details

This PR improves the `qr_hpad` FMHA forward path in three parts:

- Skip padded `k`/`n` fragments in the GEMM/pipeline path when only part
of the fragment is valid.
- Add partial GEMM0 tail handling for `qr_hpad` so the kernel uses the
valid fragment range instead of always computing over the padded extent
(the fragment arithmetic is worked through in the sketch after this list).
- Retune the gfx11 `qr_hpad` kernel configuration after enabling the
partial-fragment path.
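
For intuition, the sketch below works the fragment arithmetic through for `hdim=72`. It mirrors the kernel-side computation in the diffs further down; the `kQKHeaddim`/`kK0` constants are illustrative assumptions, not values read from a particular kernel instance.

```cpp
// Minimal sketch of the valid-fragment arithmetic (assumed constants).
#include <cassert>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    constexpr int kQKHeaddim = 128; // padded head-dim extent (assumed)
    constexpr int kK0        = 32;  // k0 fragment size (assumed)
    constexpr int k0_loops   = kQKHeaddim / kK0; // 4 fragments over the padded extent

    const int hdim_q = 72; // actual head dim, as in -d=72

    // Only the first valid_k0_loops fragments hold real data; the last of
    // them is only partially valid.
    const int valid_k0_loops       = integer_divide_ceil(hdim_q, kK0);    // 3
    const int valid_last_k0_length = hdim_q - (valid_k0_loops - 1) * kK0; // 8

    assert(valid_k0_loops == 3 && valid_last_k0_length == 8);
    // Previously all 4 fragments were computed. With this PR, fragments 0-1
    // run in full, fragment 2 runs as a partial GEMM0 tail sized from the 8
    // valid values (rounded up to warp-GEMM K granularity), and the fully
    // padded fragment 3 is skipped.
    return 0;
}
```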

To keep the existing path stable, the implementation adds overloads for
the updated GEMM/pipeline interfaces. This allows existing full-tile
callers to keep using the previous form, while the `qr_hpad` path can
pass valid fragment counts when needed.
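
A minimal sketch of that overload pattern, with heavily simplified signatures (the real pipeline `operator()` takes the full window/functor argument list shown in the diffs below; the constants are assumed values):

```cpp
// Sketch of the interface-stability overloads: new callers pass valid
// fragment counts, old callers hit a forwarding overload with full-extent
// defaults. Names and constants here are illustrative assumptions.
#include <cstdint>

using index_t = std::int32_t;

struct PipelineSketch
{
    static constexpr index_t kQKHeaddim = 128; // padded head-dim extent (assumed)
    static constexpr index_t kK0        = 32;  // k0 fragment size (assumed)
    static constexpr index_t kN1        = 128; // n1 (v head-dim) tile extent (assumed)

    // New form: the qr_hpad caller passes the valid fragment counts explicitly.
    void operator()(float sink_v,
                    index_t valid_k0_loops,
                    index_t valid_last_k0_length,
                    index_t valid_n1_length) const
    {
        // ... hot loop bounds its k0/n1 work by the valid_* values ...
        (void)sink_v; (void)valid_k0_loops; (void)valid_last_k0_length; (void)valid_n1_length;
    }

    // Old form kept as an overload: forwards full-extent defaults, so existing
    // full-tile callers keep the previous call form unchanged.
    void operator()(float sink_v) const
    {
        operator()(sink_v,
                   kQKHeaddim / kK0, // every fragment valid
                   kK0,              // last fragment full-length
                   kN1);             // full n1 extent
    }
};
```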

## Test Plan

```
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16 \
    -d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
```
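
For example, one concrete point in that sweep (sequence length chosen arbitrarily):

```
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode=0 -b=1 -h=16 -d=72 -s=1024 -s_k=1024 -lse=0 -iperm=0 -operm=0
```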

## Test Result

- On gfx11 and gfx12, `tile_example_fmha_fwd` shows roughly a 20-30%
performance improvement for the head dimensions that require padding
(`hdim=72/80`).

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Hosang Yoon committed 2026-04-18 02:44:46 -04:00 (committed by GitHub)
parent 907c6e94ae · commit f5e00ec904
4 changed files with 478 additions and 128 deletions


```diff
@@ -1194,18 +1194,15 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
             if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
                 return True
-            is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
-            pads_hdim = (
-                kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
-            )
-            exact_hdim = (
-                kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
-            )
-            if is_64x32_tile:
-                return pads_hdim
-            return exact_hdim
+            # For (128, 128) head dims, partial-fragment support in qr_hpad removes the need
+            # for the previous qr_hpad-specific handling that was added to avoid register spill.
+            # qr_hpad now reuses the regular 128x64 tile choice.
+            # The 64x64 tile remains disabled for qr_hpad because it is consistently slower
+            # in our measurements.
+            if kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 64:
+                return kernel_ctx.pipeline.tag != "qr_hpad"
+            return True

         rules.append(check_d128_tile_pipeline)
         return rules
@@ -1218,8 +1215,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
     ( 32, 32)  : [FmhaFwdTileSize( 64, 64, 16,  32, 32,  32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
     ( 64, 64)  : [FmhaFwdTileSize( 64, 64, 32,  64, 32,  64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
                   FmhaFwdTileSize(128, 64, 32,  64, 32,  64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
-    (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16,  6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
-                  FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
+    (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")),
                   FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16,  6)],
     (192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
     (256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16,  6)]
```


```diff
@@ -39,6 +39,9 @@ struct FmhaFwdKernel
     using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
     static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;

+    template <typename T>
+    using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);
+
     static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
     static_assert(kBlockPerCu > 0);
     static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
@@ -1891,6 +1894,35 @@ struct FmhaFwdKernel
         }();

         BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};

+        constexpr bool kPassHdimTailArgs = [] {
+            if constexpr(ck_tile::is_detected<has_hdim_tail_args, FmhaPipeline>::value)
+                return static_cast<bool>(FmhaPipeline::kUseHdimTailArgs);
+            else
+                return false;
+        }();
+
+        auto invoke_fmha_pipeline = [&](auto&&... args) -> decltype(auto) {
+            if constexpr(kPassHdimTailArgs)
+            {
+                const ck_tile::index_t valid_k0_loops =
+                    ck_tile::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0);
+                const ck_tile::index_t valid_last_k0_length =
+                    kargs.hdim_q - (valid_k0_loops - 1) * FmhaPipeline::kK0;
+                const ck_tile::index_t valid_n1_length = [&]() {
+                    const ck_tile::index_t remaining_n1 = kargs.hdim_v - i_n1;
+                    return ck_tile::min(remaining_n1,
+                                        static_cast<ck_tile::index_t>(FmhaPipeline::kN1));
+                }();
+                return FmhaPipeline{}(static_cast<decltype(args)&&>(args)...,
+                                      sink_value,
+                                      valid_k0_loops,
+                                      valid_last_k0_length,
+                                      valid_n1_length);
+            }
+            else
+            {
+                return FmhaPipeline{}(static_cast<decltype(args)&&>(args)..., sink_value);
+            }
+        };
+
         auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() {
             if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
@@ -1910,36 +1942,35 @@ struct FmhaFwdKernel
                 else
                     return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
             }();

-            return FmhaPipeline{}(q_dram_window,
-                                  identity{}, // q_element_func
-                                  k_dram_window,
-                                  identity{}, // k_element_func
-                                  v_dram_window,
-                                  identity{}, // v_element_func
-                                  bias_dram_window,
-                                  identity{}, // bias_element_func
-                                  randval_dram_window,
-                                  lse_dram_window,
-                                  identity{}, // lse_element_func
-                                  identity{}, // s_acc_element_func
-                                  scales<remove_cvref_t<decltype(scale_p)>>{
-                                      scale_p}, // p_compute_element_func
-                                  o_acc_element_func, // o_acc_element_func
-                                  mask,
-                                  position_encoding,
-                                  variant_params.sm_scale,
-                                  variant,
-                                  variant_params,
-                                  block_indices,
-                                  smem_ptr,
-                                  dropout,
-                                  nullptr,
-                                  nullptr,
-                                  1,
-                                  make_null_tile_window(make_tuple()),
-                                  make_null_tile_window(make_tuple()),
-                                  make_null_tile_window(make_tuple()),
-                                  sink_value);
+            return invoke_fmha_pipeline(q_dram_window,
+                                        identity{}, // q_element_func
+                                        k_dram_window,
+                                        identity{}, // k_element_func
+                                        v_dram_window,
+                                        identity{}, // v_element_func
+                                        bias_dram_window,
+                                        identity{}, // bias_element_func
+                                        randval_dram_window,
+                                        lse_dram_window,
+                                        identity{}, // lse_element_func
+                                        identity{}, // s_acc_element_func
+                                        scales<remove_cvref_t<decltype(scale_p)>>{
+                                            scale_p}, // p_compute_element_func
+                                        o_acc_element_func, // o_acc_element_func
+                                        mask,
+                                        position_encoding,
+                                        variant_params.sm_scale,
+                                        variant,
+                                        variant_params,
+                                        block_indices,
+                                        smem_ptr,
+                                        dropout,
+                                        nullptr,
+                                        nullptr,
+                                        1,
+                                        make_null_tile_window(make_tuple()),
+                                        make_null_tile_window(make_tuple()),
+                                        make_null_tile_window(make_tuple()));
         }
         else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
         {
@@ -1964,7 +1995,7 @@ struct FmhaFwdKernel
             // Both P and rowsum are scaled by 2^shift, canceling in normalization
             // No additional scaling needed in p_compute_element_func or o_acc_element_func
-            return FmhaPipeline{}(
+            return invoke_fmha_pipeline(
                 q_dram_window,
                 identity{}, // q_element_func
                 k_dram_window,
@@ -1992,8 +2023,7 @@ struct FmhaFwdKernel
                 kargs.block_scale_size_kv,
                 make_null_tile_window(make_tuple()),
                 make_null_tile_window(make_tuple()),
-                make_null_tile_window(make_tuple()),
-                sink_value);
+                make_null_tile_window(make_tuple()));
         }
         else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
         {
@@ -2098,53 +2128,51 @@ struct FmhaFwdKernel
                     number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
                 {i_n1, 0});

-            return FmhaPipeline{}(q_dram_window,
-                                  identity{}, // q_element_func
-                                  k_dram_window,
-                                  identity{}, // k_element_func
-                                  v_dram_window,
-                                  identity{}, // v_element_func
-                                  bias_dram_window,
-                                  identity{}, // bias_element_func
-                                  randval_dram_window,
-                                  lse_dram_window,
-                                  identity{}, // lse_element_func
-                                  identity{}, // s_acc_element_func
-                                  identity{}, // p_compute_element_func
-                                  identity{}, // o_acc_element_func
-                                  mask,
-                                  position_encoding,
-                                  kargs.scale_s,
-                                  variant,
-                                  variant_params,
-                                  block_indices,
-                                  smem_ptr,
-                                  dropout,
-                                  nullptr,
-                                  nullptr,
-                                  1,
-                                  q_scale_dram_window,
-                                  k_scale_dram_window,
-                                  v_scale_dram_window,
-                                  sink_value);
+            return invoke_fmha_pipeline(q_dram_window,
+                                        identity{}, // q_element_func
+                                        k_dram_window,
+                                        identity{}, // k_element_func
+                                        v_dram_window,
+                                        identity{}, // v_element_func
+                                        bias_dram_window,
+                                        identity{}, // bias_element_func
+                                        randval_dram_window,
+                                        lse_dram_window,
+                                        identity{}, // lse_element_func
+                                        identity{}, // s_acc_element_func
+                                        identity{}, // p_compute_element_func
+                                        identity{}, // o_acc_element_func
+                                        mask,
+                                        position_encoding,
+                                        kargs.scale_s,
+                                        variant,
+                                        variant_params,
+                                        block_indices,
+                                        smem_ptr,
+                                        dropout,
+                                        nullptr,
+                                        nullptr,
+                                        1,
+                                        q_scale_dram_window,
+                                        k_scale_dram_window,
+                                        v_scale_dram_window);
         }
         else
         {
-            return FmhaPipeline{}(q_dram_window,
-                                  k_dram_window,
-                                  v_dram_window,
-                                  bias_dram_window,
-                                  randval_dram_window,
-                                  lse_dram_window,
-                                  mask,
-                                  position_encoding,
-                                  variant_params.sm_scale,
-                                  variant,
-                                  variant_params,
-                                  block_indices,
-                                  smem_ptr,
-                                  dropout,
-                                  sink_value);
+            return invoke_fmha_pipeline(q_dram_window,
+                                        k_dram_window,
+                                        v_dram_window,
+                                        bias_dram_window,
+                                        randval_dram_window,
+                                        lse_dram_window,
+                                        mask,
+                                        position_encoding,
+                                        variant_params.sm_scale,
+                                        variant,
+                                        variant_params,
+                                        block_indices,
+                                        smem_ptr,
+                                        dropout);
         }
     }();
```

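A note on the dispatch above: the kernel only appends the tail arguments when the pipeline type advertises `kUseHdimTailArgs`, so pre-existing pipelines keep the original call form with no source changes. Below is a minimal standalone sketch of that detection idiom; it substitutes a `std::void_t`-based detector for `ck_tile::is_detected` (which plays the same role), and the pipeline structs are toy stand-ins, not the real FMHA pipelines.

```cpp
// Compile-time feature probe: does Pipeline declare kUseHdimTailArgs?
#include <type_traits>

template <template <typename> class, typename, typename = void>
struct is_detected : std::false_type {};

template <template <typename> class Op, typename T>
struct is_detected<Op, T, std::void_t<Op<T>>> : std::true_type {};

template <typename T>
using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);

struct OldPipeline {};                                                  // no flag: old call form
struct HpadPipeline { static constexpr bool kUseHdimTailArgs = true; }; // qr_hpad-style pipeline

template <typename Pipeline>
constexpr bool pass_hdim_tail_args()
{
    if constexpr(is_detected<has_hdim_tail_args, Pipeline>::value)
        return static_cast<bool>(Pipeline::kUseHdimTailArgs);
    else
        return false; // old pipelines keep the original call form, untouched
}

static_assert(!pass_hdim_tail_args<OldPipeline>());
static_assert(pass_hdim_tail_args<HpadPipeline>());
```

Because the probe resolves at compile time, `if constexpr` discards the unused branch and old pipelines pay no runtime or interface cost.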

```diff
@@ -39,6 +39,11 @@ struct BlockFmhaPipelineQRKSVS
     using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
     using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;

+    template <typename T>
+    using has_partial_k_support = decltype(T::kSupportsPartialK);
+    template <typename T>
+    using has_partial_n_support = decltype(T::kSupportsPartialN);
+
     using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
     using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
     static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
@@ -68,6 +73,7 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr auto QScaleEnum = Problem::QScaleEnum;
     static constexpr bool kHasSink = Problem::kHasSink;
     static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
+    static constexpr bool kUseHdimTailArgs = kPadHeadDimQ || kPadHeadDimV;

     static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
     static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -203,7 +209,10 @@ struct BlockFmhaPipelineQRKSVS
                    k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
                const VScaleDramBlockWindowTmp&
                    v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
-               const float sink_v) const
+               const float sink_v,
+               const index_t valid_k0_loops,
+               const index_t valid_last_k0_length,
+               const index_t valid_n1_length) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -261,8 +270,30 @@ struct BlockFmhaPipelineQRKSVS
             v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

         // Block GEMM
-        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
-        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
+        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
+        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
+        using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
+        using BlockGemm1 = remove_cvref_t<decltype(gemm_1)>;
+
+        constexpr bool kBlockGemm0SupportsPartialK = [] {
+            if constexpr(ck_tile::is_detected<has_partial_k_support, BlockGemm0>::value)
+                return static_cast<bool>(BlockGemm0::kSupportsPartialK);
+            else
+                return false;
+        }();
+        constexpr bool kBlockGemm1SupportsPartialN = [] {
+            if constexpr(ck_tile::is_detected<has_partial_n_support, BlockGemm1>::value)
+                return static_cast<bool>(BlockGemm1::kSupportsPartialN);
+            else
+                return false;
+        }();
+        constexpr auto gemm_0_config =
+            BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+        using Gemm0WarpGemm = remove_cvref_t<decltype(gemm_0_config.template at<0>())>;
+        constexpr index_t kGemm0WarpK = Gemm0WarpGemm::kK;
+        constexpr index_t kGemm0KItersPerBlock = kK0 / kGemm0WarpK;
+        constexpr bool kUsePartialKForGemm0Tail =
+            kPadHeadDimQ && kBlockGemm0SupportsPartialK && (kGemm0KItersPerBlock > 1);

         auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                               q_dram_block_window_tmp.get_window_lengths(),
@@ -428,10 +459,26 @@ struct BlockFmhaPipelineQRKSVS
         index_t i_total_loops = 0;
         constexpr index_t k0_loops = kQKHeaddim / kK0;
         constexpr index_t k1_loops = kN0 / kK1;
+        // Number of k0 iterations prefetched ahead of the current compute iteration.
+        // The skip decision must be made this many iterations before the last k0 loop.
+        constexpr index_t kK0PrefetchDepth = 2;
+        const index_t gemm0_tail_k_iters = [&]() {
+            if constexpr(kUsePartialKForGemm0Tail)
+            {
+                return ck_tile::integer_divide_ceil(valid_last_k0_length, kGemm0WarpK);
+            }
+            return static_cast<index_t>(kGemm0KItersPerBlock);
+        }();
+        const bool skip_last_k0_loop = [&]() {
+            if constexpr(kPadHeadDimQ)
+            {
+                return valid_k0_loops == (k0_loops - 1);
+            }
+            return false;
+        }();

         // Use compile-time conditional for group barrier sequence
         // (No runtime lambda selection)
         auto schedule_gemm_0 = [] {
-            using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
             constexpr auto WarpGemmConfig =
                 BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
             using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
@@ -456,7 +503,7 @@ struct BlockFmhaPipelineQRKSVS
             }
         };

-        static_assert(2 <= k0_loops);
+        static_assert(kK0PrefetchDepth <= k0_loops);
         static_assert(1 <= k1_loops);
         do
         {
@@ -523,6 +570,46 @@ struct BlockFmhaPipelineQRKSVS
             }

             auto run_gemm_0 = [&](auto i_k0) {
+                if constexpr(kUsePartialKForGemm0Tail)
+                {
+                    if(static_cast<index_t>(i_k0.value) == (valid_k0_loops - 1) &&
+                       gemm0_tail_k_iters < kGemm0KItersPerBlock)
+                    {
+                        static_for<1, kGemm0KItersPerBlock, 1>{}([&](auto i_tail_k_iter) {
+                            constexpr index_t kTailKIters = i_tail_k_iter;
+                            constexpr index_t kTailK0 = kTailKIters * kGemm0WarpK;
+                            if(gemm0_tail_k_iters == kTailKIters)
+                            {
+                                using Gemm0TailProblem = BlockGemmProblem<
+                                    QDataType,
+                                    KDataType,
+                                    SaccDataType,
+                                    Problem::kNumGemm0Warps * get_warp_size(),
+                                    TileGemmShape<
+                                        sequence<kM0, kN0, kTailK0>,
+                                        typename BlockFmhaShape::Gemm0BlockWarps,
+                                        sequence<BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
+                                                 BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
+                                                 kGemm0WarpK>>>;
+                                constexpr auto gemm_0_tail =
+                                    BlockGemmARegBSmemCRegV2<Gemm0TailProblem,
+                                                             typename BlockGemm0::Policy>{};
+                                auto q_slice =
+                                    get_slice_tile(q_tile,
+                                                   sequence<0, i_k0 * kK0>{},
+                                                   sequence<kM0, i_k0 * kK0 + kTailK0>{});
+                                auto k_tail_window = make_tile_window(
+                                    k_lds, make_tuple(number<kN0>{}, number<kTailK0>{}), {0, 0});
+                                gemm_0_tail(s_acc, q_slice, k_tail_window);
+                            }
+                        });
+                        return;
+                    }
+                }
+
                 auto q_slice = get_slice_tile(
                     q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{});
                 if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
@@ -540,19 +627,37 @@ struct BlockFmhaPipelineQRKSVS
                 }
             };

-            if constexpr(k0_loops > 2)
+            if constexpr(k0_loops > kK0PrefetchDepth)
             {
-                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
+                static_for<0, k0_loops - kK0PrefetchDepth, 1>{}([&](auto i_k0) {
                     block_sync_lds();
                     run_gemm_0(number<i_k0>{});
                     block_sync_lds();

-                    move_tile_window(k_dram_window, {0, kK0});
-                    store_tile(
-                        k_lds_window,
-                        tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
-                    k_block_tile = load_tile(k_dram_window); // global read i + 2
+                    if constexpr(kPadHeadDimQ && i_k0 == (k0_loops - 1 - kK0PrefetchDepth))
+                    {
+                        if(!skip_last_k0_loop)
+                        {
+                            move_tile_window(k_dram_window, {0, kK0});
+                        }
+                        store_tile(
+                            k_lds_window,
+                            tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
+                        if(!skip_last_k0_loop)
+                        {
+                            k_block_tile = load_tile(k_dram_window); // global read i + 2
+                        }
+                    }
+                    else
+                    {
+                        move_tile_window(k_dram_window, {0, kK0});
+                        store_tile(
+                            k_lds_window,
+                            tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
+                        k_block_tile = load_tile(k_dram_window); // global read i + 2
+                    }

                     k_scale_block_tile = load_k_scale_block_tile();
                 });
             }
@@ -577,16 +682,19 @@ struct BlockFmhaPipelineQRKSVS
             }

             { // tail
                 block_sync_lds();
-                run_gemm_0(number<k0_loops - 2>{});
-                block_sync_lds();
+                run_gemm_0(number<k0_loops - kK0PrefetchDepth>{});

-                store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
-                k_scale_block_tile = load_k_scale_block_tile();
-                block_sync_lds();
-
-                run_gemm_0(number<k0_loops - 1>{});
+                if(!skip_last_k0_loop)
+                {
+                    block_sync_lds();
+                    store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
+                    k_scale_block_tile = load_k_scale_block_tile();
+                    block_sync_lds();
+
+                    run_gemm_0(number<k0_loops - 1>{});
+                }
             }

             if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
             {
@@ -933,6 +1041,31 @@ struct BlockFmhaPipelineQRKSVS
             auto o_acc0 = decltype(o_acc){};
             clear_tile(o_acc0);

+            constexpr auto gemm_1_config =
+                BlockGemm1::Policy::template GetWarpGemmMWarpNWarp<Problem>();
+            using Gemm1WarpGemm = remove_cvref_t<decltype(gemm_1_config.template at<0>())>;
+            constexpr index_t kGemm1NWarp = gemm_1_config.template at<2>();
+            constexpr index_t kGemm1NPerIter = kGemm1NWarp * Gemm1WarpGemm::kN;
+            const index_t valid_n_iters = [&]() {
+                if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
+                {
+                    return ck_tile::integer_divide_ceil(valid_n1_length, kGemm1NPerIter);
+                }
+                return static_cast<index_t>(0);
+            }();
+
+            auto run_gemm_1_impl =
+                [&](auto& o_acc_tensor, const auto& p_slice, const auto&... gemm_1_args) {
+                    if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
+                    {
+                        gemm_1(o_acc_tensor, p_slice, gemm_1_args..., valid_n_iters);
+                    }
+                    else
+                    {
+                        gemm_1(o_acc_tensor, p_slice, gemm_1_args...);
+                    }
+                };
+
             auto run_gemm_1 = [&](auto i_k1) {
                 auto p_slice =
                     get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{});
@@ -942,15 +1075,19 @@ struct BlockFmhaPipelineQRKSVS
                     get_slice_tile(p_scale,
                                    sequence<0, i_k1*(kK1 / kVScaleGranularity)>{},
                                    sequence<kM0, (i_k1 + 1) * (kK1 / kVScaleGranularity)>{});
-                    gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
-                }
-                else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
-                {
-                    gemm_1(o_acc0, p_slice, v_lds_window);
-                }
-                else
-                {
-                    gemm_1(o_acc, p_slice, v_lds_window);
-                }
+                    run_gemm_1_impl(
+                        o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
+                }
+                else
+                {
+                    if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
+                    {
+                        run_gemm_1_impl(o_acc0, p_slice, v_lds_window);
+                    }
+                    else
+                    {
+                        run_gemm_1_impl(o_acc, p_slice, v_lds_window);
+                    }
+                }
             };
@@ -1075,6 +1212,94 @@ struct BlockFmhaPipelineQRKSVS
         return o_acc;
     }

+    template <typename QDramBlockWindowTmp,
+              typename KDramBlockWindowTmp,
+              typename VDramBlockWindowTmp,
+              typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
+              typename LSEDramBlockWindowTmp,
+              typename QElementFunction,
+              typename KElementFunction,
+              typename VElementFunction,
+              typename BiasElementFunction,
+              typename LSEElementFunction,
+              typename SAccElementFunction,
+              typename PComputeElementFunction,
+              typename OAccElementFunction,
+              typename PositionEncoding,
+              typename AttentionVariantParams,
+              typename BlockIndices,
+              typename QScaleDramBlockWindowTmp,
+              typename KScaleDramBlockWindowTmp,
+              typename VScaleDramBlockWindowTmp>
+    CK_TILE_HOST_DEVICE auto
+    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
+               const QElementFunction& q_element_func,
+               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
+               const KElementFunction& k_element_func,
+               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               const VElementFunction& v_element_func,
+               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
+               const BiasElementFunction& bias_element_func,
+               RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
+               LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
+               const LSEElementFunction& lse_element_func,
+               const SAccElementFunction& s_acc_element_func,
+               const PComputeElementFunction& p_compute_element_func,
+               const OAccElementFunction& o_acc_element_func,
+               FmhaMask mask,
+               PositionEncoding position_encoding,
+               float scale_s,
+               const AttentionVariant& variant,
+               const AttentionVariantParams& variant_params,
+               const BlockIndices& block_indices,
+               void* smem_ptr,
+               DropoutType& dropout,
+               const float* k_descale_ptr,
+               const float* v_descale_ptr,
+               const index_t block_scale_size_kv,
+               const QScaleDramBlockWindowTmp&
+                   q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile
+               const KScaleDramBlockWindowTmp&
+                   k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
+               const VScaleDramBlockWindowTmp&
+                   v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
+               const float sink_v) const
+    {
+        return operator()(q_dram_block_window_tmp,
+                          q_element_func,
+                          k_dram_block_window_tmp,
+                          k_element_func,
+                          v_dram_block_window_tmp,
+                          v_element_func,
+                          bias_dram_block_window_tmp,
+                          bias_element_func,
+                          randval_dram_block_window_tmp,
+                          lse_dram_window_tmp,
+                          lse_element_func,
+                          s_acc_element_func,
+                          p_compute_element_func,
+                          o_acc_element_func,
+                          mask,
+                          position_encoding,
+                          scale_s,
+                          variant,
+                          variant_params,
+                          block_indices,
+                          smem_ptr,
+                          dropout,
+                          k_descale_ptr,
+                          v_descale_ptr,
+                          block_scale_size_kv,
+                          q_scale_dram_block_window_tmp,
+                          k_scale_dram_block_window_tmp,
+                          v_scale_dram_block_window_tmp,
+                          sink_v,
+                          kQKHeaddim / kK0,
+                          kK0,
+                          kN1);
+    }
+
     template <typename QDramBlockWindowTmp,
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
@@ -1099,7 +1324,10 @@ struct BlockFmhaPipelineQRKSVS
                const BlockIndices& block_indices,
                void* smem_ptr,
                DropoutType& dropout,
-               const float sink_v) const
+               const float sink_v,
+               const index_t valid_k0_loops,
+               const index_t valid_last_k0_length,
+               const index_t valid_n1_length) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},
@@ -1129,7 +1357,56 @@ struct BlockFmhaPipelineQRKSVS
                           make_null_tile_window(make_tuple()),
                           make_null_tile_window(make_tuple()),
                           make_null_tile_window(make_tuple()),
-                          sink_v);
+                          sink_v,
+                          valid_k0_loops,
+                          valid_last_k0_length,
+                          valid_n1_length);
     }

+    template <typename QDramBlockWindowTmp,
+              typename KDramBlockWindowTmp,
+              typename VDramBlockWindowTmp,
+              typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
+              typename LSEDramBlockWindowTmp,
+              typename PositionEncoding,
+              typename AttentionVariantParams,
+              typename BlockIndices>
+    CK_TILE_HOST_DEVICE auto
+    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
+               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
+               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
+               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
+               LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
+               FmhaMask mask,
+               PositionEncoding position_encoding,
+               float scale_s,
+               const AttentionVariant& variant,
+               const AttentionVariantParams& variant_params,
+               const BlockIndices& block_indices,
+               void* smem_ptr,
+               DropoutType& dropout,
+               const float sink_v) const
+    {
+        return operator()(q_dram_block_window_tmp,
+                          k_dram_block_window_tmp,
+                          v_dram_block_window_tmp,
+                          bias_dram_block_window_tmp,
+                          randval_dram_block_window_tmp,
+                          lse_dram_block_window_tmp,
+                          mask,
+                          position_encoding,
+                          scale_s,
+                          variant,
+                          variant_params,
+                          block_indices,
+                          smem_ptr,
+                          dropout,
+                          sink_v,
+                          kQKHeaddim / kK0,
+                          kK0,
+                          kN1);
+    }
 };
```


```diff
@@ -21,13 +21,18 @@ struct BlockGemmARegBSmemCRegV2
     using CDataType = remove_cvref_t<typename Problem::CDataType>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

-    static constexpr index_t kBlockSize = Problem::kBlockSize;
+    static constexpr index_t kBlockSize = Problem::kBlockSize;
+    static constexpr bool kSupportsPartialK = true;
+    static constexpr bool kSupportsPartialN = true;

     // C += A * B
-    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
-    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
-                                   const ABlockTensorTmp& a_block_tensor_tmp,
-                                   const BBlockWindowTmp& b_block_window_tmp) const
+    template <bool UsePartialN,
+              typename CBlockTensor,
+              typename ABlockTensorTmp,
+              typename BBlockWindowTmp>
+    CK_TILE_DEVICE void Impl(CBlockTensor& c_block_tensor,
+                             const ABlockTensorTmp& a_block_tensor_tmp,
+                             const BBlockWindowTmp& b_block_window_tmp,
+                             [[maybe_unused]] const index_t valid_n_iters) const
     {
         static_assert(
             std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
@@ -134,10 +139,7 @@ struct BlockGemmARegBSmemCRegV2
         constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
         constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

-        // hot loop:
-        static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
-            constexpr auto kIter = number<kn[number<0>{}]>{};
-            constexpr auto nIter = number<kn[number<1>{}]>{};
+        auto run_n_iter = [&](auto kIter, auto nIter) {
             // read B warp tensor from B Block window
             const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
@@ -166,7 +168,44 @@ struct BlockGemmARegBSmemCRegV2
                     merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                     c_warp_tensor.get_thread_buffer());
             });
-        });
+        };
+
+        // hot loop:
+        if constexpr(UsePartialN)
+        {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
+                    if(static_cast<index_t>(nIter.value) < valid_n_iters)
+                    {
+                        run_n_iter(kIter, nIter);
+                    }
+                });
+            });
+        }
+        else
+        {
+            static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
+                static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { run_n_iter(kIter, nIter); });
+            });
+        }
     }

+    // C += A * B (executing only the first valid_n_iters N sub-iterations)
+    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
+    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                   const ABlockTensorTmp& a_block_tensor_tmp,
+                                   const BBlockWindowTmp& b_block_window_tmp,
+                                   const index_t valid_n_iters) const
+    {
+        Impl<true>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
+    }
+
+    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
+    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
+                                   const ABlockTensorTmp& a_block_tensor_tmp,
+                                   const BBlockWindowTmp& b_block_window_tmp) const
+    {
+        Impl<false>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, 0);
+    }
+
     template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
@@ -227,7 +266,17 @@ struct BlockGemmARegBSmemCRegV2
         return c_block_tensor;
     }

-    // C = A * B
+    // C = A * B (executing only the first valid_n_iters N sub-iterations)
+    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
+    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
+                                   const BBlockWindowTmp& b_block_window_tmp,
+                                   const index_t valid_n_iters) const
+    {
+        auto c_block_tensor = MakeCBlockTile();
+        operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
+        return c_block_tensor;
+    }
+
+    // C = A * B
     template <typename ABlockTensorTmp, typename BBlockWindowTmp>
     CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                    const BBlockWindowTmp& b_block_window_tmp) const
```