mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] FMHA Reduce register spilling in fwd with dropout (workaround for CI failures with clang-22) (#3221)
* Use vectorized stores for dropout randvals With no kPadSeqLenK the kernel uses 2 buffer_store_dwordx2 instead of 16 buffer_store_byte. This requires less registers and reduces spilling. * Calculate dropout randvals for storing and applying only once Even though it may add a small overhead when storing is not required, it uses significantly less registers and hence no spilling.
This commit is contained in:
@@ -333,23 +333,15 @@ struct BlockDropout
|
||||
return randval;
|
||||
};
|
||||
|
||||
if(is_store_randval)
|
||||
{
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
const auto randval = generate_randval(i_m0, i_n0);
|
||||
// save to Global
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
|
||||
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
|
||||
const auto randval = generate_randval(i_m0, i_n0);
|
||||
if(is_store_randval)
|
||||
{
|
||||
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
|
||||
store_tile(randval_dram_window, randval_store);
|
||||
}
|
||||
move_tile_window(randval_dram_window, {0, kNPerStep});
|
||||
// Drop values of P based on the generated probabilities
|
||||
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
|
||||
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
|
||||
@@ -369,7 +361,9 @@ struct BlockDropout
|
||||
});
|
||||
});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
|
||||
});
|
||||
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
|
||||
}
|
||||
|
||||
const unsigned long long ph_seed;
|
||||
|
||||
@@ -1005,7 +1005,7 @@ struct FmhaBatchPrefillWithPagedKVCacheKernel
|
||||
rand_val_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_randval, 1),
|
||||
number<1>{},
|
||||
number<FmhaPipeline::kAlignmentRandVal>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
|
||||
@@ -1450,7 +1450,7 @@ struct FmhaFwdKernel
|
||||
rand_val_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
|
||||
make_tuple(kargs.stride_randval, 1),
|
||||
number<1>{},
|
||||
number<FmhaPipeline::kAlignmentRandVal>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(randval_dram_naive,
|
||||
|
||||
@@ -80,6 +80,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVS
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -83,6 +83,8 @@ struct BlockFmhaPipelineQRKSVS
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -81,6 +81,8 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
|
||||
|
||||
@@ -90,6 +90,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload
|
||||
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -69,6 +69,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -74,6 +74,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -78,6 +78,8 @@ struct BlockFmhaPipelineQSKSVS
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
static constexpr index_t kAlignmentRandVal =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentRandVal<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::kBlockPerCu != -1)
|
||||
|
||||
@@ -422,6 +422,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentRandVal()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
constexpr auto c_warp_y_lengths = CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::RandValOutputDataType);
|
||||
return min(MaxVectorSize, c_warp_y_lengths.get(number<CWarpDstr::NDimY - 1>{}));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user