[CK_TILE] FMHA BWD Pad HDim to a Multiple of 8 (#2918)

This commit is contained in:
Yi DING
2025-09-26 16:42:59 +08:00
committed by GitHub
parent 518d24e662
commit 32773fe5cb
12 changed files with 110 additions and 88 deletions

View File

@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr";

View File

@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "kr_ktr_vr_iglp";

View File

@@ -14,7 +14,8 @@ namespace ck_tile {
template <typename Problem, typename Policy>
class BlockFmhaBwdDQDKDVPipelineSelector
{
static constexpr bool has_dpad = Problem::Traits::kPadHeadDimQ || Problem::Traits::kPadHeadDimV;
static constexpr bool has_dpad1 =
Problem::Traits::kPadHeadDimQ == 1 || Problem::Traits::kPadHeadDimV == 1;
static constexpr bool is_decode = Problem::BlockFmhaShape::kMaxSeqLenQ > 0;
public:
@@ -24,7 +25,7 @@ class BlockFmhaBwdDQDKDVPipelineSelector
std::conditional_t<is_decode,
BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR<TS...>,
BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR<TS...>>,
std::conditional_t<has_dpad,
std::conditional_t<has_dpad1,
BlockFmhaBwdDQDKDVPipelineKRKTRVR<TS...>,
BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<TS...>>>;
using type = std::conditional_t<std::is_same_v<Policy, void>, //

View File

@@ -49,8 +49,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
@@ -60,18 +60,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "trload_kr_ktr_vr";

View File

@@ -51,8 +51,8 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr index_t kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr index_t kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
@@ -62,18 +62,18 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = 1;
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
kPadHeadDimQ ? kPadHeadDimQ : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
kPadHeadDimV ? kPadHeadDimV : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias = 1;
static constexpr const char* name = "trload_kr_ktr_vr";

View File

@@ -57,13 +57,11 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kUseTrLoad = kUseTrLoad_;
// attributes from traits
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static_assert(!Traits::kPadSeqLenQ, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
static_assert(!Traits::kPadSeqLenK, "BlockFmhaBwdPipelineProblem does not need kPadSeqLenQ");
static constexpr index_t kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename ODataType_,

View File

@@ -37,6 +37,23 @@ struct TileFmhaTraits
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <index_t kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaBwdTraits
{
static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,