Remove exposing kUseTrLoad as template parameter of pipeline problem

This commit is contained in:
Qianfeng Zhang
2026-04-21 15:35:03 +00:00
parent 8f0f7ca436
commit 0b6bbe45d6
15 changed files with 129 additions and 116 deletions

View File

@@ -39,9 +39,9 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -57,7 +57,6 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
@@ -96,7 +95,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -44,9 +44,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -62,7 +62,6 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionFwdTileSetting>;
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
@@ -107,7 +106,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -13,6 +13,7 @@
#include <variant>
#include "hstu_block_masking.hpp"
#include "hstu_attention_util.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
@@ -45,13 +46,14 @@ struct HstuAttentionFwdKernel
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad;
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<HstuAttentionPipeline>;
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
// arg
struct HstuAttentionFwdEmptyKargs

View File

@@ -45,10 +45,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
return WG::WarpGemmAttribute::kKPerThread;
};
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
@@ -58,26 +58,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
{
if constexpr(!Problem::kUseTrLoad)
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
return BlockGemm::template MakeABlockTileDistribution<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
}
else
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto bias_block_dstr_encode = BlockGemm::template MakeCBlockDistributionEncode<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode);
constexpr auto bias_block_dstr_encode =
BlockGemm::template MakeCBlockDistributionEncode<
Problem::HstuAttentionTileSetting::kM0,
Problem::HstuAttentionTileSetting::kN0>();
constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode);
return bias_block_dstr;
};
return bias_block_dstr;
}
template <typename Problem>
@@ -117,20 +105,20 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
return Problem::GetKDramTileAccessMaxVectorSize();
}
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
if constexpr(GetKVWarpGemmKPerThreadSize<Problem, kUseTrLoad>() >= 8)
return 8;
else
return 4;
}
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
// special consideration when shuffling is required before storing V to LDS
if constexpr(!Problem::kUseTrLoad)
if constexpr(!kUseTrLoad)
{
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
@@ -183,13 +171,13 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
};
};
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
if constexpr(!Problem::kUseTrLoad)
if constexpr(!kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
@@ -203,14 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
};
};
template <typename Problem>
template <typename Problem, bool kPipelineUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
{
return max(GetKSingleSmemElementSpaceSize<Problem>(),
GetVSingleSmemElementSpaceSize<Problem>());
GetVSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>());
};
template <typename Problem>
template <typename Problem, bool kPipelineUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
@@ -219,6 +207,9 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
constexpr index_t SingleSmemElementSpaceSize =
GetSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>();
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
{
@@ -226,8 +217,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
@@ -252,8 +241,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t DataTypeSize = sizeof(KDataType);
@@ -322,8 +309,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
number<kKPerBlock / kKVector>{},
@@ -405,7 +390,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
};
}
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
@@ -413,7 +398,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
if constexpr(!Problem::kUseTrLoad)
if constexpr(!kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
@@ -456,14 +441,15 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
}
else
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKPack = GetSmemKPackV<Problem, true>();
constexpr auto XorGroupSize =
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
static_assert(VSingleSmemElementSpaceSize ==
GetVSingleSmemElementSpaceSize<Problem, true>());
constexpr auto v_lds_block_desc_naive =
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
@@ -497,14 +483,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
};
}
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
if constexpr(!Problem::kUseTrLoad)
if constexpr(!kUseTrLoad)
{
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
@@ -526,7 +512,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
}
else
{
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NPerThread = GetAlignmentV<Problem, true>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
@@ -652,7 +638,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{});
};
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{
using GemmProblem = BlockGemmProblem<
@@ -727,7 +713,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
WarpGemm>;
if constexpr(!Problem::kUseTrLoad)
if constexpr(!kUseTrLoad)
{
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
}
@@ -737,22 +723,22 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
};
}
template <typename Problem>
template <typename Problem, bool kUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
{
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
}
template <typename Problem>
template <typename Problem, bool kPipelineUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
{
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>() *
sizeof(typename Problem::QKVDataType);
};
@@ -762,10 +748,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
return 0;
};
template <typename Problem>
template <typename Problem, bool kPipelineUseTrLoad = false>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
return GetSmemSizeKV<Problem, kPipelineUseTrLoad>() + GetSmemSizeDropout<Problem>();
}
};

View File

@@ -13,6 +13,7 @@
#include <variant>
#include "hstu_block_masking.hpp"
#include "hstu_attention_util.hpp"
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
@@ -48,13 +49,14 @@ struct HstuAttentionFwdSplitKVKernel
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad;
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
static constexpr bool kUseTrLoad = detail::is_using_trload_v<HstuAttentionPipeline>;
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce an template
// arg
struct HstuAttentionFwdEmptyKargs

View File

@@ -39,9 +39,9 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -57,7 +57,6 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
@@ -89,7 +88,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -45,9 +45,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
typename HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -63,7 +63,6 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionFwdTileSetting>;
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
@@ -101,7 +100,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -39,9 +39,9 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -57,7 +57,6 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionTileSetting>;
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
@@ -89,7 +88,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -44,9 +44,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
static constexpr bool kUseTrLoad = true;
static constexpr bool use_trload_pipeline = true;
#else
static constexpr bool kUseTrLoad = false;
static constexpr bool use_trload_pipeline = false;
#endif
template <bool kIsCrossAttention>
@@ -62,7 +62,6 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
kHasDropout,
kUseCausal,
kUseSoftmax,
kUseTrLoad,
HstuAttentionFwdTileSetting>;
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
@@ -100,7 +99,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
if constexpr(!kUseTrLoad)
if constexpr(!use_trload_pipeline)
{
using HstuPipeline = std::conditional_t<
kUseSoftmax,

View File

@@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == false, "Check failed!");
static constexpr bool kUseTrLoad = false;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;

View File

@@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == true, "Check failed!");
static constexpr bool kUseTrLoad = true;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -62,10 +60,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem, true /*kUseTrLoad*/>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem, true /*kUseTrLoad */>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -110,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
return Policy::template GetSmemSize<Problem, true /*kPipelineUseTrLoad*/>();
}
template <typename QDramBlockWindowTmp,
@@ -166,7 +164,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem, true /*kUseTrLoad*/>();
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
@@ -220,9 +218,13 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
k_lds_ptr,
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>());
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
k_lds,
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>()
.get_lengths(),
{0, 0});
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
@@ -238,9 +240,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<QKVDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
v_lds,
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>().get_lengths(),
{0, 0});
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
@@ -252,11 +256,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem>());
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem, true /*kUseTrLoad*/>());
// reduction function for softmax
const auto f_silu = [&](CompDataType& x) {
@@ -397,7 +401,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);

View File

@@ -57,8 +57,6 @@ CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
// but it also contains other information needed by the pipeline, which includes
// TileShape -- which determines how block-layer calculation is done in tiles and
// how warps are allocated on dimensions
// Traits -- other information required for running the kernel and pipeline
template <typename InOutDataType_,
typename GemmAccDataType_,
typename CompDataType_, // data type for SiLU and other non-linear calculation
@@ -70,7 +68,6 @@ template <typename InOutDataType_,
bool kHasDropout_,
bool kHasCausal_,
bool kUseSoftmax_,
bool kUseTrLoad_, // use transposed loading to load V tile from lds to vgprs
typename AttentionTileSetting_>
struct HstuAttentionFwdPipelineProblem
{
@@ -94,7 +91,6 @@ struct HstuAttentionFwdPipelineProblem
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kHasCausal = kHasCausal_;
static constexpr bool kUseSoftmax = kUseSoftmax_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
static_assert(!kUseGroup || (kUseGroup && kIsJagged),
"Group HSTU is only used with jagged mode!");

View File

@@ -30,3 +30,29 @@ static inline int get_number_of_cu()
return props.multiProcessorCount;
}
namespace ck_tile {
namespace detail {
// A helper struct for detecting kUseTrLoad
// T is the pipeline class used by the kernel instance
template <typename T, typename = void>
struct has_use_trload_flag : std::false_type
{
};
template <typename T>
struct has_use_trload_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
: std::true_type
{
};
template <typename T>
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
} // namespace detail
} // namespace ck_tile

View File

@@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == false, "Check failed!");
static constexpr bool kUseTrLoad = false;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;

View File

@@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasCausal = Problem::kHasCausal;
static_assert(Problem::kUseTrLoad == true, "Check failed!");
static constexpr bool kUseTrLoad = true;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -62,10 +60,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem, true /*kUseTrLoad*/>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem, true /*kUseTrLoad*/>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -110,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
return Policy::template GetSmemSize<Problem, true /*kPipelineUseTrLoad*/>();
}
template <typename QDramBlockWindowTmp,
@@ -169,7 +167,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
constexpr auto gemm_1 =
Policy::template GetKVBlockGemm<Problem, true /*kPipelineUseTrLoad*/>();
// SaccBlockTile size is [kM0, kN0Sub]
// PcompBlockTile size is [kM0, kN0]
@@ -230,9 +229,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
k_lds_ptr,
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>());
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
k_lds,
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>()
.get_lengths(),
{0, 0});
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
@@ -248,9 +251,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<QKVDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
v_lds,
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>().get_lengths(),
{0, 0});
using v_lds_window_type =
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
@@ -262,11 +267,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
});
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem>());
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0},
Policy::template MakeVDramTileDistribution<Problem, true /*kUseTrLoad*/>());
const auto f_exp = [&](CompDataType x) {
if constexpr(std::is_same_v<CompDataType, float>)
@@ -508,7 +513,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
if constexpr(kHasDropout)
{
auto randval_lds_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
reinterpret_cast<char*>(smem_ptr) +
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);