mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Separate Traits from Problem while being used for defining the pipeline
This commit is contained in:
@@ -41,8 +41,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
#endif
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
using HstuPipelineProblem = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
InOutDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
@@ -53,8 +52,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
@@ -82,8 +80,6 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue =
|
||||
ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
@@ -95,8 +91,10 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
@@ -106,8 +104,11 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<
|
||||
HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
|
||||
@@ -41,8 +41,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
#endif
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
using HstuPipelineProblem = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
InOutDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
@@ -53,8 +52,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting,
|
||||
HstuTraits>;
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
@@ -76,8 +74,6 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue = ck_tile::NRepetitions2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
@@ -88,8 +84,10 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem>>;
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVS<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
@@ -99,8 +97,10 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem>>;
|
||||
ck_tile::HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
|
||||
HstuTraits>,
|
||||
ck_tile::HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad<HstuPipelineProblem,
|
||||
HstuTraits>>;
|
||||
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
|
||||
@@ -10,10 +10,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Traits_,
|
||||
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
@@ -46,11 +49,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV =
|
||||
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -59,7 +61,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
@@ -74,8 +76,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
return Traits::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
|
||||
@@ -10,10 +10,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Traits_,
|
||||
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
@@ -48,11 +51,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV =
|
||||
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -61,7 +63,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
@@ -76,8 +78,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
return Traits::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
|
||||
@@ -23,8 +23,7 @@ template <typename InOutDataType_,
|
||||
bool kHasCausal_,
|
||||
bool kUseSoftmax_,
|
||||
bool kUseTrLoad_, // use transposed loading to load V tile from lds to vgprs
|
||||
typename AttentionTileSetting_,
|
||||
typename Traits_>
|
||||
typename AttentionTileSetting_>
|
||||
struct HstuAttentionFwdPipelineProblem
|
||||
{
|
||||
using InOutDataType = remove_cvref_t<InOutDataType_>;
|
||||
@@ -52,8 +51,6 @@ struct HstuAttentionFwdPipelineProblem
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileSetting_::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = AttentionTileSetting_::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = AttentionTileSetting_::NumWarps * get_warp_size();
|
||||
|
||||
@@ -10,10 +10,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Traits_,
|
||||
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
@@ -46,11 +49,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV =
|
||||
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -59,7 +61,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
@@ -74,8 +76,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
return Traits::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
|
||||
@@ -10,10 +10,13 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
template <typename Problem_,
|
||||
typename Traits_,
|
||||
typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
@@ -48,11 +51,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV =
|
||||
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? true : Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
// ... together with tensor distribution. tensor dist should able to overwrite this
|
||||
@@ -61,7 +63,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
@@ -76,8 +78,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
Policy::template GetKVBlockGemmSingleRepN<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
if constexpr(Traits::kBlockPerCu != -1)
|
||||
return Traits::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
|
||||
Reference in New Issue
Block a user