[CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918)

Yi DING
2025-09-26 16:42:59 +08:00
committed by GitHub
parent 518d24e662
commit 32773fe5cb
12 changed files with 110 additions and 88 deletions
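The diff below changes the head-dim padding constants in FmhaBwdDQDKDVKernel from bool flags to index_t values, so each constant now carries the padding information as a number: boolean uses become explicit (kPadHeadDimQ > 0) checks, and the kernel-name helper appends the value itself. The following stand-alone C++ sketch only illustrates that pattern; ExampleTraits, PadNameSuffix, and the index_t alias are hypothetical stand-ins, not CK Tile code.

#include <iostream>
#include <string>
#include <cstdint>

using index_t = std::int32_t; // assumption: ck_tile::index_t behaves like a 32-bit signed integer

// Hypothetical stand-in for the FmhaPipeline traits referenced in the diff.
struct ExampleTraits
{
    static constexpr index_t kPadHeadDimQ = 6; // non-zero: the Q/K head-dim views get padded
    static constexpr index_t kPadHeadDimV = 0; // zero: the V head-dim views are left as-is
};

// Mirrors the reworked "pn" lambda: the pad value itself becomes part of the name suffix.
template <typename Traits>
std::string PadNameSuffix()
{
    std::string n;
    if(Traits::kPadHeadDimQ) n += "d" + std::to_string(Traits::kPadHeadDimQ);
    if(Traits::kPadHeadDimV) n += "dv" + std::to_string(Traits::kPadHeadDimV);
    return n.empty() ? n : std::string("p") + n;
}

int main()
{
    std::cout << PadNameSuffix<ExampleTraits>() << '\n'; // prints "pd6"

    // Boolean contexts now compare against zero, mirroring
    // sequence<false, (kPadHeadDimQ > 0)>{} in the kernel body.
    constexpr bool pad_q = (ExampleTraits::kPadHeadDimQ > 0);
    static_assert(pad_q, "padding requested for the Q/K head dimension");
    return 0;
}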


@@ -60,12 +60,12 @@ struct FmhaBwdDQDKDVKernel
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>;
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
-static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
-static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
-static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
-static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
-static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
-using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
+static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
+static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
+static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
+static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
+static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
+using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
@@ -100,8 +100,8 @@ struct FmhaBwdDQDKDVKernel
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
@@ -815,7 +815,7 @@ struct FmhaBwdDQDKDVKernel
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<false, kPadHeadDimQ>{});
+sequence<false, (kPadHeadDimQ > 0)>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr,
@@ -826,7 +826,7 @@ struct FmhaBwdDQDKDVKernel
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<false, kPadHeadDimQ>{});
+sequence<false, (kPadHeadDimQ > 0)>{});
const auto v_dram = [&]() {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -838,7 +838,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-sequence<false, kPadHeadDimV>{});
+sequence<false, (kPadHeadDimV > 0)>{});
}();
// lse and d should be fine to read unpadded data as they are not on the reduction dimension
@@ -857,7 +857,7 @@ struct FmhaBwdDQDKDVKernel
const auto do_dram = pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-sequence<false, kPadHeadDimV>{});
+sequence<false, (kPadHeadDimV > 0)>{});
auto q_dram_window = make_tile_window(
q_dram,
@@ -905,7 +905,7 @@ struct FmhaBwdDQDKDVKernel
const auto dq_acc_dram = pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<false, kPadHeadDimQ>{});
+sequence<false, (kPadHeadDimQ > 0)>{});
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
@@ -1089,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dk_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-sequence<false, kPadHeadDimQ>{});
+sequence<false, (kPadHeadDimQ > 0)>{});
}();
auto dv_dram = [&]() {
@@ -1103,7 +1103,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dv_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-sequence<false, kPadHeadDimV>{});
+sequence<false, (kPadHeadDimV > 0)>{});
}();
auto dk_dram_window = make_tile_window(
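The commit title says the head dimension is padded to a multiple of 8. This excerpt does not show where that rounding happens, so the snippet below is only a sketch of the arithmetic, assuming the pad value is the gap between the runtime head dim and the next multiple of 8; round_up_to_multiple_of_8 and pad_amount are illustrative helpers, not CK Tile functions.

#include <cstdint>

using index_t = std::int32_t;

constexpr index_t round_up_to_multiple_of_8(index_t hdim)
{
    return (hdim + 7) / 8 * 8; // smallest multiple of 8 that is >= hdim
}

constexpr index_t pad_amount(index_t hdim)
{
    return round_up_to_multiple_of_8(hdim) - hdim; // 0 when hdim is already aligned
}

static_assert(round_up_to_multiple_of_8(58) == 64, "58 rounds up to 64");
static_assert(pad_amount(58) == 6, "six padding elements for hdim 58");
static_assert(pad_amount(128) == 0, "128 is already a multiple of 8");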