[CK_TILE] FMHA Reduce register spilling in fwd with dropout (workaround for CI failures with clang-22) (#3221)

* Use vectorized stores for dropout randvals

With no kPadSeqLenK the kernel uses 2 buffer_store_dwordx2 instead of
16 buffer_store_byte. This requires less registers and reduces spilling.

* Calculate dropout randvals for storing and applying only once

Even though it may add a small overhead when storing is not required,
it uses significantly less registers and hence no spilling.
This commit is contained in:
Anton Gorenko
2025-11-19 10:40:12 +05:00
committed by GitHub
parent e91ee8578c
commit d7b3197869
11 changed files with 37 additions and 16 deletions

View File

@@ -333,23 +333,15 @@ struct BlockDropout
return randval;
};
if(is_store_randval)
{
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
const auto randval = generate_randval(i_m0, i_n0);
// save to Global
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
});
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
const auto randval = generate_randval(i_m0, i_n0);
if(is_store_randval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
}
move_tile_window(randval_dram_window, {0, kNPerStep});
// Drop values of P based on the generated probabilities
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
@@ -369,7 +361,9 @@ struct BlockDropout
});
});
});
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
const unsigned long long ph_seed;

View File

@@ -1005,7 +1005,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
rand_val_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_randval, 1),
number<1>{},
number<FmhaPipeline::kAlignmentRandVal>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,

View File

@@ -1450,7 +1450,7 @@ struct FmhaFwdKernel
rand_val_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_randval, 1),
number<1>{},
number<FmhaPipeline::kAlignmentRandVal>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,

View File

@@ -80,6 +80,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -83,6 +83,8 @@ struct BlockFmhaPipelineQRKSVS
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -81,6 +81,8 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;

View File

@@ -90,6 +90,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -69,6 +69,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -74,6 +74,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -78,6 +78,8 @@ struct BlockFmhaPipelineQSKSVS
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
static constexpr index_t kAlignmentRandVal =
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)

View File

@@ -422,6 +422,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentRandVal()
{
using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
using CWarpDstr = typename WG::CWarpDstr;
constexpr auto c_warp_y_lengths = CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::RandValOutputDataType);
return min(MaxVectorSize, c_warp_y_lengths.get(number<CWarpDstr::NDimY - 1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{