Separate Traits from Problem when they are used to define the pipeline

This commit is contained in:
Qianfeng Zhang
2025-11-14 16:08:14 +00:00
parent 95c1bb25e3
commit 238b5c4f08
7 changed files with 66 additions and 60 deletions

View File

@@ -41,8 +41,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
static constexpr bool kUseTrLoad = false;
#endif
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
using HstuPipelineProblem = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
@@ -53,8 +52,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting,
HstuTraits>;
HstuAttentionTileSetting>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
{
@@ -82,8 +80,6 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kPadHeadDimV,
occupancy>;
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
using HstuEpilogue =
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
@@ -95,8 +91,10 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
@@ -106,8 +104,11 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;

View File

@@ -41,8 +41,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
static constexpr bool kUseTrLoad = false;
#endif
template <typename HstuTraits>
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
using HstuPipelineProblem = ck_tile::HstuAttentionFwdPipelineProblem<
InOutDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
@@ -53,8 +52,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting,
HstuTraits>;
HstuAttentionTileSetting>;
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
{
@@ -76,8 +74,6 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kPadHeadDimV,
occupancy>;
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
@@ -88,8 +84,10 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
@@ -99,8 +97,10 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
HstuTraits>,
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
HstuTraits>>;
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;

View File

@@ -10,10 +10,13 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Traits_,
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Traits = remove_cvref_t<Traits_>;
using Policy = remove_cvref_t<Policy_>;
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
@@ -46,11 +49,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static constexpr bool kUseTrLoad = false;
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV =
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this.
@@ -59,7 +61,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
@@ -74,8 +76,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
if constexpr(Traits::kBlockPerCu != -1)
return Traits::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)

View File

@@ -10,10 +10,13 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Traits_,
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{
using Problem = remove_cvref_t<Problem_>;
using Traits = remove_cvref_t<Traits_>;
using Policy = remove_cvref_t<Policy_>;
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
@@ -48,11 +51,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV =
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this.
@@ -61,7 +63,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
@@ -76,8 +78,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
if constexpr(Traits::kBlockPerCu != -1)
return Traits::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)

View File

@@ -23,8 +23,7 @@ template <typename InOutDataType_,
bool kHasCausal_,
bool kUseSoftmax_,
bool kUseTrLoad_, // use transposed loading to load V tile from lds to vgprs
typename AttentionTileSetting_,
typename Traits_>
typename AttentionTileSetting_>
struct HstuAttentionFwdPipelineProblem
{
using InOutDataType = remove_cvref_t<InOutDataType_>;
@@ -52,8 +51,6 @@ struct HstuAttentionFwdPipelineProblem
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();

View File

@@ -10,10 +10,13 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Traits_,
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
{
using Problem = remove_cvref_t<Problem_>;
using Traits = remove_cvref_t<Traits_>;
using Policy = remove_cvref_t<Policy_>;
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
@@ -46,11 +49,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static constexpr bool kUseTrLoad = false;
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV =
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this.
@@ -59,7 +61,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
@@ -74,8 +76,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
if constexpr(Traits::kBlockPerCu != -1)
return Traits::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)

View File

@@ -10,10 +10,13 @@
namespace ck_tile {
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
template <typename Problem_,
typename Traits_,
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{
using Problem = remove_cvref_t<Problem_>;
using Traits = remove_cvref_t<Traits_>;
using Policy = remove_cvref_t<Policy_>;
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
@@ -48,11 +51,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV =
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
// Last-dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution. The tensor distribution should be able
// to overwrite this.
@@ -61,7 +63,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
@@ -76,8 +78,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
Policy::template GetKVBlockGemmSingleRepN<Problem>();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::Traits::kBlockPerCu != -1)
return Problem::Traits::kBlockPerCu;
if constexpr(Traits::kBlockPerCu != -1)
return Traits::kBlockPerCu;
else
{
if constexpr(kQKHeaddim == 32)