mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918)
This commit is contained in:
@@ -60,12 +60,12 @@ struct FmhaBwdDQDKDVKernel
|
||||
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>;
|
||||
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr index_t kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
@@ -100,8 +100,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
#define _TS_ std::to_string
|
||||
auto pn = [&] () {
|
||||
std::string n;
|
||||
if (kPadHeadDimQ) n += "d";
|
||||
if (kPadHeadDimV) n += "dv";
|
||||
if (kPadHeadDimQ) n += "d" + _TS_(kPadHeadDimQ);
|
||||
if (kPadHeadDimV) n += "dv"+ _TS_(kPadHeadDimV);
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
@@ -815,7 +815,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
@@ -826,7 +826,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -838,7 +838,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
}();
|
||||
|
||||
// lse and d should be fine to read unpaded data as they are not on the reduction dimension
|
||||
@@ -857,7 +857,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto do_dram = pad_tensor_view(
|
||||
do_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
@@ -905,7 +905,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
const auto dq_acc_dram = pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
@@ -1089,7 +1089,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dk_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
sequence<false, (kPadHeadDimQ > 0)>{});
|
||||
}();
|
||||
|
||||
auto dv_dram = [&]() {
|
||||
@@ -1103,7 +1103,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
return pad_tensor_view(
|
||||
dv_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
sequence<false, (kPadHeadDimV > 0)>{});
|
||||
}();
|
||||
|
||||
auto dk_dram_window = make_tile_window(
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr";
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "kr_ktr_vr_iglp";
|
||||
|
||||
@@ -14,7 +14,8 @@ namespace ck_tile {
|
||||
template <typename Problem, typename Policy>
|
||||
class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
{
|
||||
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool has_dpad1 =
|
||||
Problem::Traits::kPadHeadDimQ == 1 || Problem::Traits::kPadHeadDimV == 1;
|
||||
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
|
||||
|
||||
public:
|
||||
@@ -24,7 +25,7 @@ class BlockFmhaBwdDQDKDVPipelineSelector
|
||||
std::conditional_t<is_decode,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
|
||||
std::conditional_t<has_dpad,
|
||||
std::conditional_t<has_dpad1,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
|
||||
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
|
||||
using type = std::conditional_t<std::is_same_v<Policy, void>, //
|
||||
|
||||
@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
@@ -51,8 +51,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
|
||||
@@ -62,18 +62,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
|
||||
static constexpr index_t kAlignmentOGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t kAlignmentQGrad = 1;
|
||||
static constexpr index_t kAlignmentKGrad =
|
||||
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
|
||||
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
|
||||
static constexpr index_t kAlignmentVGrad =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
|
||||
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
|
||||
static constexpr index_t kAlignmentBias = 1;
|
||||
|
||||
static constexpr const char* name = "trload_kr_ktr_vr";
|
||||
|
||||
@@ -57,13 +57,11 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
|
||||
static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
|
||||
@@ -37,6 +37,23 @@ struct TileFmhaTraits
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
index_t kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaBwdTraits
|
||||
{
|
||||
static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
|
||||
static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
|
||||
static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
Reference in New Issue
Block a user