mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] Skip padded k/n fragment work in qr_hpad FMHA fwd (#6450)
## Motivation
`qr_hpad` currently executes work for padded head-dim fragments even
when only a subset of the values are valid. This adds unnecessary
computation for head dimensions that require padding, such as `hdim=72`
and `hdim=80`, and hurts FMHA forward performance.
The goal of this PR is to make the padded-head-dim path skip invalid
work based on the actual valid fragment count, while preserving the
existing behavior for the non-padded path.
## Technical Details
This PR improves the `qr_hpad` FMHA forward path in three parts:
- Skip padded `k`/`n` fragments in the GEMM/pipeline path when only part
of the fragment is valid.
- Add partial GEMM0 tail handling for `qr_hpad` so the kernel uses the
valid fragment range instead of always computing over the padded extent.
- Retune the gfx11 `qr_hpad` kernel configuration after enabling the
partial-fragment path.
To keep the existing path stable, the implementation adds overloads for
the updated GEMM/pipeline interfaces. This allows existing full-tile
callers to keep using the previous form, while the `qr_hpad` path can
pass valid fragment counts when needed.
## Test Plan
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
## Test Result
- On gfx11 and gfx12, for head dimensions that require padding,
`tile_example_fmha_fwd` shows about 20-30% performance improvement at
`hdim=72/80`.
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
@@ -1194,18 +1194,15 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
|
||||
return True
|
||||
|
||||
is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
|
||||
pads_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
|
||||
)
|
||||
exact_hdim = (
|
||||
kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
|
||||
)
|
||||
# For (128, 128) head dims, partial-fragment support in qr_hpad removes the need
|
||||
# for the previous qr_hpad-specific handling that was added to avoid register spill.
|
||||
# qr_hpad now reuses the regular 128x64 tile choice.
|
||||
# The 64x64 tile remains disabled for qr_hpad because it is consistently slower
|
||||
# in our measurements.
|
||||
if kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 64:
|
||||
return kernel_ctx.pipeline.tag != "qr_hpad"
|
||||
|
||||
if is_64x32_tile:
|
||||
return pads_hdim
|
||||
|
||||
return exact_hdim
|
||||
return True
|
||||
|
||||
rules.append(check_d128_tile_pipeline)
|
||||
return rules
|
||||
@@ -1218,8 +1215,7 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
|
||||
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
|
||||
(128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 2048")),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
|
||||
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
|
||||
|
||||
@@ -39,6 +39,9 @@ struct FmhaFwdKernel
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
|
||||
// Detection alias: names the type of T::kUseHdimTailArgs and is ill-formed when T has no
// member of that name, so it can be probed with ck_tile::is_detected.
template <typename T>
|
||||
using has_hdim_tail_args = decltype(T::kUseHdimTailArgs);
|
||||
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
|
||||
@@ -1891,6 +1894,35 @@ struct FmhaFwdKernel
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
|
||||
// Compile-time switch: true only when FmhaPipeline declares kUseHdimTailArgs and that member
// converts to true. Pipelines that do not declare the member evaluate to false.
constexpr bool kPassHdimTailArgs = [] {
|
||||
if constexpr(ck_tile::is_detected<has_hdim_tail_args, FmhaPipeline>::value)
|
||||
return static_cast<bool>(FmhaPipeline::kUseHdimTailArgs);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
// Single call site for FmhaPipeline{}: forwards the caller's arguments unchanged and appends
// sink_value. When kPassHdimTailArgs is set, three runtime values describing the valid
// (non-padded) head-dim extent are appended after sink_value as well.
auto invoke_fmha_pipeline = [&](auto&&... args) -> decltype(auto) {
|
||||
if constexpr(kPassHdimTailArgs)
|
||||
{
|
||||
// Number of kK0-wide steps needed to cover hdim_q (rounded up).
const ck_tile::index_t valid_k0_loops =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_q, FmhaPipeline::kK0);
|
||||
// Elements of hdim_q that fall into the last of those steps.
// NOTE(review): with hdim_q == 0 this evaluates to kK0 because valid_k0_loops would be 0;
// presumably hdim_q > 0 is guaranteed by the caller -- confirm.
const ck_tile::index_t valid_last_k0_length =
|
||||
kargs.hdim_q - (valid_k0_loops - 1) * FmhaPipeline::kK0;
|
||||
// Part of hdim_v remaining from offset i_n1, capped at kN1.
const ck_tile::index_t valid_n1_length = [&]() {
|
||||
const ck_tile::index_t remaining_n1 = kargs.hdim_v - i_n1;
|
||||
return ck_tile::min(remaining_n1,
|
||||
static_cast<ck_tile::index_t>(FmhaPipeline::kN1));
|
||||
}();
|
||||
// static_cast<decltype(args)&&> preserves the value category of each forwarded argument.
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)...,
|
||||
sink_value,
|
||||
valid_k0_loops,
|
||||
valid_last_k0_length,
|
||||
valid_n1_length);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(static_cast<decltype(args)&&>(args)..., sink_value);
|
||||
}
|
||||
};
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
@@ -1910,36 +1942,35 @@ struct FmhaFwdKernel
|
||||
else
|
||||
return ck_tile::scales<remove_cvref_t<decltype(scale_o)>>{scale_o};
|
||||
}();
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
scales<remove_cvref_t<decltype(scale_p)>>{
|
||||
scale_p}, // p_compute_element_func
|
||||
o_acc_element_func, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()));
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
@@ -1964,7 +1995,7 @@ struct FmhaFwdKernel
|
||||
// Both P and rowsum are scaled by 2^shift, canceling in normalization
|
||||
// No additional scaling needed in p_compute_element_func or o_acc_element_func
|
||||
|
||||
return FmhaPipeline{}(
|
||||
return invoke_fmha_pipeline(
|
||||
q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
@@ -1992,8 +2023,7 @@ struct FmhaFwdKernel
|
||||
kargs.block_scale_size_kv,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
make_null_tile_window(make_tuple()));
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
@@ -2098,53 +2128,51 @@ struct FmhaFwdKernel
|
||||
number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
|
||||
{i_n1, 0});
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
q_scale_dram_window,
|
||||
k_scale_dram_window,
|
||||
v_scale_dram_window,
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
q_scale_dram_window,
|
||||
k_scale_dram_window,
|
||||
v_scale_dram_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
sink_value);
|
||||
return invoke_fmha_pipeline(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
variant_params.sm_scale,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -39,6 +39,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
using AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant>;
|
||||
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
|
||||
|
||||
// Detection alias: well-formed only when T has a member named kSupportsPartialK.
template <typename T>
|
||||
using has_partial_k_support = decltype(T::kSupportsPartialK);
|
||||
// Detection alias: well-formed only when T has a member named kSupportsPartialN.
template <typename T>
|
||||
using has_partial_n_support = decltype(T::kSupportsPartialN);
|
||||
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true; // if q_tile load whole block length (hdim) at once
|
||||
@@ -68,6 +73,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
static constexpr auto QScaleEnum = Problem::QScaleEnum;
|
||||
static constexpr bool kHasSink = Problem::kHasSink;
|
||||
static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
|
||||
static constexpr bool kUseHdimTailArgs = kPadHeadDimQ || kPadHeadDimV;
|
||||
|
||||
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
|
||||
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
|
||||
@@ -203,7 +209,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
|
||||
const VScaleDramBlockWindowTmp&
|
||||
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
const index_t valid_k0_loops,
|
||||
const index_t valid_last_k0_length,
|
||||
const index_t valid_n1_length) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
@@ -261,8 +270,30 @@ struct BlockFmhaPipelineQRKSVS
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
using BlockGemm1 = remove_cvref_t<decltype(gemm_1)>;
|
||||
// True only when BlockGemm0 declares kSupportsPartialK and that member converts to true;
// block GEMM types without the member evaluate to false.
constexpr bool kBlockGemm0SupportsPartialK = [] {
|
||||
if constexpr(ck_tile::is_detected<has_partial_k_support, BlockGemm0>::value)
|
||||
return static_cast<bool>(BlockGemm0::kSupportsPartialK);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
// True only when BlockGemm1 declares kSupportsPartialN and that member converts to true;
// block GEMM types without the member evaluate to false.
constexpr bool kBlockGemm1SupportsPartialN = [] {
|
||||
if constexpr(ck_tile::is_detected<has_partial_n_support, BlockGemm1>::value)
|
||||
return static_cast<bool>(BlockGemm1::kSupportsPartialN);
|
||||
else
|
||||
return false;
|
||||
}();
|
||||
|
||||
constexpr auto gemm_0_config =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using Gemm0WarpGemm = remove_cvref_t<decltype(gemm_0_config.template at<0>())>;
|
||||
constexpr index_t kGemm0WarpK = Gemm0WarpGemm::kK;
|
||||
constexpr index_t kGemm0KItersPerBlock = kK0 / kGemm0WarpK;
|
||||
constexpr bool kUsePartialKForGemm0Tail =
|
||||
kPadHeadDimQ && kBlockGemm0SupportsPartialK && (kGemm0KItersPerBlock > 1);
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
@@ -428,10 +459,26 @@ struct BlockFmhaPipelineQRKSVS
|
||||
index_t i_total_loops = 0;
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
// Number of k0 iterations prefetched ahead of the current compute iteration.
|
||||
// The skip decision must be made this many iterations before the last k0 loop.
|
||||
constexpr index_t kK0PrefetchDepth = 2;
|
||||
// Number of warp-GEMM K steps (each kGemm0WarpK wide) needed to cover valid_last_k0_length.
// When the partial-K tail is disabled at compile time, this is the full per-block count.
const index_t gemm0_tail_k_iters = [&]() {
|
||||
if constexpr(kUsePartialKForGemm0Tail)
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(valid_last_k0_length, kGemm0WarpK);
|
||||
}
|
||||
return static_cast<index_t>(kGemm0KItersPerBlock);
|
||||
}();
|
||||
// Runtime flag: true when head-dim-Q padding is enabled and the valid k0 step count is exactly
// one less than the compile-time k0_loops. Always false without kPadHeadDimQ.
// NOTE(review): only the "exactly one fewer" case is detected; valid_k0_loops smaller than
// k0_loops - 1 does not set this flag -- confirm that case cannot occur for supported shapes.
const bool skip_last_k0_loop = [&]() {
|
||||
if constexpr(kPadHeadDimQ)
|
||||
{
|
||||
return valid_k0_loops == (k0_loops - 1);
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
// Use compile-time conditional for group barrier sequence
|
||||
// (No runtime lambda selection)
|
||||
auto schedule_gemm_0 = [] {
|
||||
using BlockGemm0 = remove_cvref_t<decltype(gemm_0)>;
|
||||
constexpr auto WarpGemmConfig =
|
||||
BlockGemm0::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm0 = remove_cvref_t<decltype(WarpGemmConfig.template at<0>())>;
|
||||
@@ -456,7 +503,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(kK0PrefetchDepth <= k0_loops);
|
||||
static_assert(1 <= k1_loops);
|
||||
do
|
||||
{
|
||||
@@ -523,6 +570,46 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
|
||||
auto run_gemm_0 = [&](auto i_k0) {
|
||||
if constexpr(kUsePartialKForGemm0Tail)
|
||||
{
|
||||
if(static_cast<index_t>(i_k0.value) == (valid_k0_loops - 1) &&
|
||||
gemm0_tail_k_iters < kGemm0KItersPerBlock)
|
||||
{
|
||||
static_for<1, kGemm0KItersPerBlock, 1>{}([&](auto i_tail_k_iter) {
|
||||
constexpr index_t kTailKIters = i_tail_k_iter;
|
||||
constexpr index_t kTailK0 = kTailKIters * kGemm0WarpK;
|
||||
|
||||
if(gemm0_tail_k_iters == kTailKIters)
|
||||
{
|
||||
using Gemm0TailProblem = BlockGemmProblem<
|
||||
QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<
|
||||
sequence<kM0, kN0, kTailK0>,
|
||||
typename BlockFmhaShape::Gemm0BlockWarps,
|
||||
sequence<BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
kGemm0WarpK>>>;
|
||||
constexpr auto gemm_0_tail =
|
||||
BlockGemmARegBSmemCRegV2<Gemm0TailProblem,
|
||||
typename BlockGemm0::Policy>{};
|
||||
|
||||
auto q_slice =
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, i_k0 * kK0>{},
|
||||
sequence<kM0, i_k0 * kK0 + kTailK0>{});
|
||||
auto k_tail_window = make_tile_window(
|
||||
k_lds, make_tuple(number<kN0>{}, number<kTailK0>{}), {0, 0});
|
||||
|
||||
gemm_0_tail(s_acc, q_slice, k_tail_window);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto q_slice = get_slice_tile(
|
||||
q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{});
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
@@ -540,19 +627,37 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(k0_loops > 2)
|
||||
if constexpr(k0_loops > kK0PrefetchDepth)
|
||||
{
|
||||
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
|
||||
static_for<0, k0_loops - kK0PrefetchDepth, 1>{}([&](auto i_k0) {
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<i_k0>{});
|
||||
block_sync_lds();
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
if constexpr(kPadHeadDimQ && i_k0 == (k0_loops - 1 - kK0PrefetchDepth))
|
||||
{
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
}
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
store_tile(
|
||||
k_lds_window,
|
||||
tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
|
||||
k_block_tile = load_tile(k_dram_window); // global read i + 2
|
||||
}
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
});
|
||||
}
|
||||
@@ -577,16 +682,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
}
|
||||
{ // tail
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - 2>{});
|
||||
block_sync_lds();
|
||||
run_gemm_0(number<k0_loops - kK0PrefetchDepth>{});
|
||||
if(!skip_last_k0_loop)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
|
||||
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
k_scale_block_tile = load_k_scale_block_tile();
|
||||
|
||||
block_sync_lds();
|
||||
block_sync_lds();
|
||||
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
run_gemm_0(number<k0_loops - 1>{});
|
||||
}
|
||||
}
|
||||
if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
|
||||
{
|
||||
@@ -933,6 +1041,31 @@ struct BlockFmhaPipelineQRKSVS
|
||||
auto o_acc0 = decltype(o_acc){};
|
||||
clear_tile(o_acc0);
|
||||
|
||||
constexpr auto gemm_1_config =
|
||||
BlockGemm1::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using Gemm1WarpGemm = remove_cvref_t<decltype(gemm_1_config.template at<0>())>;
|
||||
constexpr index_t kGemm1NWarp = gemm_1_config.template at<2>();
|
||||
constexpr index_t kGemm1NPerIter = kGemm1NWarp * Gemm1WarpGemm::kN;
|
||||
// Number of N sub-iterations (each kGemm1NPerIter wide) needed to cover valid_n1_length.
// Evaluates to 0 when head-dim-V padding or partial-N support is compiled out; in that
// configuration run_gemm_1_impl does not pass this value to gemm_1.
const index_t valid_n_iters = [&]() {
|
||||
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
|
||||
{
|
||||
return ck_tile::integer_divide_ceil(valid_n1_length, kGemm1NPerIter);
|
||||
}
|
||||
return static_cast<index_t>(0);
|
||||
}();
|
||||
|
||||
// Wrapper around gemm_1: forwards the accumulator, the P slice and any extra arguments, and
// appends valid_n_iters as a trailing argument only when head-dim-V padding and partial-N
// support are both enabled at compile time.
// NOTE(review): every gemm_1 argument list routed through here needs a matching overload that
// accepts the trailing valid_n_iters -- confirm for the variants with extra scale arguments.
auto run_gemm_1_impl =
|
||||
[&](auto& o_acc_tensor, const auto& p_slice, const auto&... gemm_1_args) {
|
||||
if constexpr(kPadHeadDimV && kBlockGemm1SupportsPartialN)
|
||||
{
|
||||
gemm_1(o_acc_tensor, p_slice, gemm_1_args..., valid_n_iters);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc_tensor, p_slice, gemm_1_args...);
|
||||
}
|
||||
};
|
||||
|
||||
auto run_gemm_1 = [&](auto i_k1) {
|
||||
auto p_slice =
|
||||
get_slice_tile(p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{});
|
||||
@@ -942,15 +1075,19 @@ struct BlockFmhaPipelineQRKSVS
|
||||
get_slice_tile(p_scale,
|
||||
sequence<0, i_k1*(kK1 / kVScaleGranularity)>{},
|
||||
sequence<kM0, (i_k1 + 1) * (kK1 / kVScaleGranularity)>{});
|
||||
gemm_1(o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
gemm_1(o_acc0, p_slice, v_lds_window);
|
||||
run_gemm_1_impl(
|
||||
o_acc, p_slice, p_scale_slice, v_lds_window, v_scale_block_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
gemm_1(o_acc, p_slice, v_lds_window);
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
run_gemm_1_impl(o_acc0, p_slice, v_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_gemm_1_impl(o_acc, p_slice, v_lds_window);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1075,6 +1212,94 @@ struct BlockFmhaPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
// Overload without the three valid-fragment arguments. It forwards every parameter unchanged
// to the tail-aware operator() and supplies kQKHeaddim / kK0, kK0 and kN1 for the trailing
// valid_k0_loops, valid_last_k0_length and valid_n1_length parameters.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename LSEElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices,
|
||||
typename QScaleDramBlockWindowTmp,
|
||||
typename KScaleDramBlockWindowTmp,
|
||||
typename VScaleDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
|
||||
LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float* k_descale_ptr,
|
||||
const float* v_descale_ptr,
|
||||
const index_t block_scale_size_kv,
|
||||
const QScaleDramBlockWindowTmp&
|
||||
q_scale_dram_block_window_tmp, // M0*(K0/kQKScaleGranularity) tile
|
||||
const KScaleDramBlockWindowTmp&
|
||||
k_scale_dram_block_window_tmp, // N0*(K0/kQKScaleGranularity) tile
|
||||
const VScaleDramBlockWindowTmp&
|
||||
v_scale_dram_block_window_tmp, // N1*(K1/kVScaleGranularity) tile
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
q_element_func,
|
||||
k_dram_block_window_tmp,
|
||||
k_element_func,
|
||||
v_dram_block_window_tmp,
|
||||
v_element_func,
|
||||
bias_dram_block_window_tmp,
|
||||
bias_element_func,
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_window_tmp,
|
||||
lse_element_func,
|
||||
s_acc_element_func,
|
||||
p_compute_element_func,
|
||||
o_acc_element_func,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
block_scale_size_kv,
|
||||
q_scale_dram_block_window_tmp,
|
||||
k_scale_dram_block_window_tmp,
|
||||
v_scale_dram_block_window_tmp,
|
||||
sink_v,
|
||||
// Defaults for the trailing tail arguments: step count, last-step length, N1 length.
kQKHeaddim / kK0,
|
||||
kK0,
|
||||
kN1);
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
@@ -1099,7 +1324,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
const float sink_v,
|
||||
const index_t valid_k0_loops,
|
||||
const index_t valid_last_k0_length,
|
||||
const index_t valid_n1_length) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
@@ -1129,7 +1357,56 @@ struct BlockFmhaPipelineQRKSVS
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_v);
|
||||
sink_v,
|
||||
valid_k0_loops,
|
||||
valid_last_k0_length,
|
||||
valid_n1_length);
|
||||
}
|
||||
|
||||
// Short overload (no element functions, no scale windows) without the three valid-fragment
// arguments. It forwards to the overload that takes the same leading parameters plus
// valid_k0_loops, valid_last_k0_length and valid_n1_length, passing kQKHeaddim / kK0, kK0 and
// kN1 for them.
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename RandValDramBlockWindowTmp,
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename PositionEncoding,
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
|
||||
LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float scale_s,
|
||||
const AttentionVariant& variant,
|
||||
const AttentionVariantParams& variant_params,
|
||||
const BlockIndices& block_indices,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout,
|
||||
const float sink_v) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
k_dram_block_window_tmp,
|
||||
v_dram_block_window_tmp,
|
||||
bias_dram_block_window_tmp,
|
||||
randval_dram_block_window_tmp,
|
||||
lse_dram_block_window_tmp,
|
||||
mask,
|
||||
position_encoding,
|
||||
scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
sink_v,
|
||||
// Defaults for the trailing tail arguments: step count, last-step length, N1 length.
kQKHeaddim / kK0,
|
||||
kK0,
|
||||
kN1);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -21,13 +21,18 @@ struct BlockGemmARegBSmemCRegV2
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr bool kSupportsPartialK = true;
|
||||
static constexpr bool kSupportsPartialN = true;
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
template <bool UsePartialN,
|
||||
typename CBlockTensor,
|
||||
typename ABlockTensorTmp,
|
||||
typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void Impl(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
[[maybe_unused]] const index_t valid_n_iters) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
@@ -134,10 +139,7 @@ struct BlockGemmARegBSmemCRegV2
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_ford<sequence<KIterPerWarp, NIterPerWarp>>{}([&](auto kn) {
|
||||
constexpr auto kIter = number<kn[number<0>{}]>{};
|
||||
constexpr auto nIter = number<kn[number<1>{}]>{};
|
||||
auto run_n_iter = [&](auto kIter, auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
@@ -166,7 +168,44 @@ struct BlockGemmARegBSmemCRegV2
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// hot loop:
|
||||
if constexpr(UsePartialN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
if(static_cast<index_t>(nIter.value) < valid_n_iters)
|
||||
{
|
||||
run_n_iter(kIter, nIter);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { run_n_iter(kIter, nIter); });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B (executing only the first valid_n_iters N sub-iterations)
// Dispatches to Impl<true>, which guards each N sub-iteration with a runtime
// `nIter < valid_n_iters` check; sub-iterations at or beyond valid_n_iters are not executed,
// so the corresponding part of c_block_tensor is left as passed in.
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
const index_t valid_n_iters) const
|
||||
{
|
||||
Impl<true>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
|
||||
}
|
||||
|
||||
// C += A * B over every N sub-iteration. Dispatches to Impl<false>, whose loop has no
// valid_n_iters check, so the 0 passed here is only a placeholder.
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
Impl<false>(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, 0);
|
||||
}
|
||||
|
||||
template <index_t MPerBlock = BlockGemmShape::kM, index_t KPerBlock = BlockGemmShape::kK>
|
||||
@@ -227,7 +266,17 @@ struct BlockGemmARegBSmemCRegV2
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
// C = A * B (executing only the first valid_n_iters N sub-iterations)
// Creates the C tile with MakeCBlockTile(), accumulates into it through the partial-N
// in-place overload, and returns it by value.
// NOTE(review): N sub-iterations at or beyond valid_n_iters are never written by the GEMM;
// whether MakeCBlockTile() zero-initializes the tile is not visible here -- confirm before
// relying on those elements.
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp,
|
||||
const index_t valid_n_iters) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp, valid_n_iters);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
|
||||
Reference in New Issue
Block a user