[CK_TILE] FMHA BWD Remove Unnecessary Padding (#2550)
* Remove unnecessary pssk
* Add BlockFmhaBwdDQDKDVPipeline wrapper
* Resolve copilot comments & Remove kpad & fix
* Remove spad
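Taken together, the hunks below remove the compile-time seqlen padding flags kPadSeqLenQ / kPadSeqLenK from FmhaBwdDQDKDVKernel: the seqlen entry of each pad_tensor_view mask is hard-coded to false for the q/k/v/do/dq_acc/dk/dv views, hard-coded to true on the seqlen_k side of the bias/dbias/randval views, and the lse/d views drop padding entirely. As a minimal standalone sketch of what such a per-dimension pad flag buys (the padded_length helper is hypothetical, not the ck_tile API):

    #include <cstdio>

    // Round a runtime length up to a tile multiple when the compile-time
    // pad flag is set; otherwise keep the raw length.
    template <bool kPad>
    int padded_length(int len, int tile)
    {
        if constexpr (kPad)
            return ((len + tile - 1) / tile) * tile; // tail tile stays in-bounds via padding
        else
            return len; // unpadded: accesses must be guarded some other way
    }

    int main()
    {
        // seqlen_q = 100 with tile size kM0 = 64: a padded view rounds up to 128.
        std::printf("padded:   %d\n", padded_length<true>(100, 64));
        // Flag hard-coded to false: the view keeps the raw length of 100.
        std::printf("unpadded: %d\n", padded_length<false>(100, 64));
    }

Dropping a pad flag removes the rounding and the bounds handling that goes with it, which is only safe when out-of-range elements can never be touched; the hunks below hard-code false exactly where that holds.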
@@ -52,8 +52,6 @@ struct FmhaBwdDQDKDVKernel
     using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;

     static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
-    static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
-    static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
@@ -85,8 +83,6 @@ struct FmhaBwdDQDKDVKernel
     #define _TS_ std::to_string
     auto pn = [&] () {
         std::string n;
-        if (kPadSeqLenQ) n += "s";
-        if (kPadSeqLenK) n += "sk";
         if (kPadHeadDimQ) n += "d";
         if (kPadHeadDimV) n += "dv";
         return n.empty() ? n : std::string("p") + n; }();
@@ -100,7 +96,7 @@ struct FmhaBwdDQDKDVKernel
         "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
         "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
         "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
-        ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "_npad" : "_" + pn) +
+        ("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "_npad" : "_" + pn) +
         (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
         (kHasBiasGrad ? "_dbias" : "_ndbias") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kHasDropout ? "_dropout" : "_ndropout") +
         (kIsStoreRandval ? "_storerandval" : "") + (kIsDeterministic ? "_deterministic" : "_ndeterministic");
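With the seqlen flags gone, the pad suffix in the kernel name reflects only the head-dim flags. A standalone reproduction of the pn lambda above (a sketch for illustration, not library code) shows the strings it now produces:

    #include <cstdio>
    #include <string>

    // Mirrors the pn lambda after this change: only head-dim pad flags remain.
    std::string pad_suffix(bool pad_head_dim_q, bool pad_head_dim_v)
    {
        std::string n;
        if (pad_head_dim_q) n += "d";
        if (pad_head_dim_v) n += "dv";
        return n.empty() ? n : std::string("p") + n;
    }

    int main()
    {
        std::printf("%s\n", pad_suffix(true, true).c_str());   // "pddv"
        std::printf("%s\n", pad_suffix(true, false).c_str());  // "pd"
        std::printf("%s\n", pad_suffix(false, false).c_str()); // "" -> kernel name gets "_npad"
    }

Before this change the same suffix could also carry "s" and "sk", e.g. "psskddv" with all four pad flags set.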
@@ -1221,7 +1217,7 @@ struct FmhaBwdDQDKDVKernel
         const auto q_dram = pad_tensor_view(
             q_dram_naive,
             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-            sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+            sequence<false, kPadHeadDimQ>{});

         const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             k_ptr,
@@ -1232,7 +1228,7 @@ struct FmhaBwdDQDKDVKernel
         const auto k_dram = pad_tensor_view(
             k_dram_naive,
             make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-            sequence<kPadSeqLenK, kPadHeadDimQ>{});
+            sequence<false, kPadHeadDimQ>{});

         const auto v_dram = [&]() {
             const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1244,22 +1240,15 @@ struct FmhaBwdDQDKDVKernel
             return pad_tensor_view(
                 v_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                sequence<kPadSeqLenK, kPadHeadDimV>{});
+                sequence<false, kPadHeadDimV>{});
         }();

-        const auto lse_dram = [&]() {
-            const auto lse_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
-                lse_ptr, make_tuple(kargs.seqlen_q), number<1>{});
-            return pad_tensor_view(
-                lse_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
-        }();
+        // lse and d should be fine to read unpadded data as they are not on the reduction dimension
+        const auto lse_dram = make_naive_tensor_view_packed<address_space_enum::global>(
+            lse_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});

-        const auto d_dram = [&]() {
-            const auto d_dram_naive = make_naive_tensor_view_packed<address_space_enum::global>(
-                d_ptr, make_tuple(kargs.seqlen_q), number<1>{});
-            return pad_tensor_view(
-                d_dram_naive, make_tuple(number<FmhaPipeline::kM0>{}), sequence<kPadSeqLenQ>{});
-        }();
+        const auto d_dram = make_naive_tensor_view_packed<address_space_enum::global>(
+            d_ptr, make_tuple(kargs.seqlen_q), number<FmhaPipeline::kM0>{});

         const auto do_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             do_ptr,
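The lse and d views lose their padding lambdas entirely: per the new comment, seqlen_q is not a reduction dimension for them, so each tile only ever touches rows that exist (note the trailing view argument also changes from number<1>{} to number<FmhaPipeline::kM0>{}, which the commit message does not comment on). A toy illustration of the bounds argument in plain C++, unrelated to the ck_tile view machinery:

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    int main()
    {
        const int seqlen_q = 100, kM0 = 64; // two tiles cover 128 >= 100

        std::vector<float> lse(seqlen_q, 1.0f);

        float acc = 0.0f;
        for (int tile = 0; tile < (seqlen_q + kM0 - 1) / kM0; ++tile)
        {
            // The row loop is clamped to seqlen_q, so the tail tile never
            // reads past the buffer even though the view is unpadded.
            const int row_end = std::min((tile + 1) * kM0, seqlen_q);
            for (int row = tile * kM0; row < row_end; ++row)
                acc += lse[row];
        }
        std::printf("%.1f\n", acc); // 100.0
    }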
@@ -1270,7 +1259,7 @@ struct FmhaBwdDQDKDVKernel
         const auto do_dram = pad_tensor_view(
             do_dram_naive,
             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-            sequence<kPadSeqLenQ, kPadHeadDimV>{});
+            sequence<false, kPadHeadDimV>{});

         auto q_dram_window = make_tile_window(
             q_dram,
@@ -1313,7 +1302,7 @@ struct FmhaBwdDQDKDVKernel
             return pad_tensor_view(
                 dq_acc_dram_naive,
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }();

         return make_tile_window(
@@ -1341,7 +1330,7 @@ struct FmhaBwdDQDKDVKernel
             return pad_tensor_view(
                 dq_acc_dram_naive,
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }();

         return make_tile_window(
@@ -1376,9 +1365,8 @@ struct FmhaBwdDQDKDVKernel
                 number<FmhaPipeline::kAlignmentBias>{},
                 number<1>{});

-            return pad_tensor_view(bias_dram_naive,
-                                   bias_dram_window_lengths,
-                                   sequence<kPadSeqLenQ, kPadSeqLenK>{});
+            return pad_tensor_view(
+                bias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
         }();

         return make_tile_window(bias_dram, bias_dram_window_lengths, {0, i_n0});
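For the bias view (and the dbias/randval views in the next hunks), the seqlen_k side of the pad mask becomes a hard-coded true instead of the removed kPadSeqLenK: the window advances in whole tiles along seqlen_k ({0, i_n0}), so the tail tile can index past seqlen_k and relies on the padded view for a defined value. A toy sketch of that tail-tile read (the load_padded helper is hypothetical, not how ck_tile implements padding):

    #include <cstdio>
    #include <vector>

    // Padded load: out-of-bounds columns read as 0 instead of faulting.
    float load_padded(const std::vector<float>& row, int col)
    {
        return col < static_cast<int>(row.size()) ? row[col] : 0.0f;
    }

    int main()
    {
        const int seqlen_k = 100, kN0 = 64; // two tiles cover 128 > 100
        std::vector<float> bias_row(seqlen_k, 0.5f);

        float acc = 0.0f;
        for (int col = 0; col < 2 * kN0; ++col) // whole tiles, tail included
            acc += load_padded(bias_row, col);
        std::printf("%.1f\n", acc); // 50.0: out-of-bounds columns contributed 0
    }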
@@ -1406,9 +1394,8 @@ struct FmhaBwdDQDKDVKernel
                 number<FmhaPipeline::kAlignmentBias>{},
                 number<1>{});

-            return pad_tensor_view(dbias_dram_naive,
-                                   bias_dram_window_lengths,
-                                   sequence<kPadSeqLenQ, kPadSeqLenK>{});
+            return pad_tensor_view(
+                dbias_dram_naive, bias_dram_window_lengths, sequence<false, true>{});
         }();

         return make_tile_window(dbias_dram, bias_dram_window_lengths, {0, i_n0});
@@ -1495,9 +1482,8 @@ struct FmhaBwdDQDKDVKernel
                 number<1>{},
                 number<1>{});

-            return pad_tensor_view(randval_dram_naive,
-                                   randval_dram_window_lengths,
-                                   sequence<kPadSeqLenQ, kPadSeqLenK>{});
+            return pad_tensor_view(
+                randval_dram_naive, randval_dram_window_lengths, sequence<false, true>{});
         }();

         return make_tile_window(randval_dram, randval_dram_window_lengths, {0, i_n0});
@@ -1550,7 +1536,7 @@ struct FmhaBwdDQDKDVKernel
             return pad_tensor_view(
                 dk_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenK, kPadHeadDimQ>{});
+                sequence<false, kPadHeadDimQ>{});
         }();

         auto dv_dram = [&]() {
@@ -1564,7 +1550,7 @@ struct FmhaBwdDQDKDVKernel
             return pad_tensor_view(
                 dv_dram_naive,
                 make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                sequence<kPadSeqLenK, kPadHeadDimV>{});
+                sequence<false, kPadHeadDimV>{});
         }();

         auto dk_dram_window = make_tile_window(