mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Stop exposing kUseTrLoad as a template parameter of the pipeline problem
This commit is contained in:
@@ -39,9 +39,9 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -57,7 +57,6 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
@@ -96,7 +95,7 @@ struct batched_forward_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -44,9 +44,9 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -62,7 +62,6 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionFwdTileSetting>;
|
||||
|
||||
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
|
||||
@@ -107,7 +106,7 @@ struct batched_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <variant>
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
|
||||
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
|
||||
@@ -45,13 +46,14 @@ struct HstuAttentionFwdKernel
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
|
||||
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
|
||||
static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
|
||||
|
||||
static constexpr bool kUseTrLoad = detail::is_using_trload_v<HstuAttentionPipeline>;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
|
||||
// arg
|
||||
struct HstuAttentionFwdEmptyKargs
|
||||
|
||||
@@ -45,10 +45,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
return WG::WarpGemmAttribute::kKPerThread;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetKVWarpGemmKPerThreadSize()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
@@ -58,26 +58,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::template MakeABlockTileDistribution<
|
||||
Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN0>();
|
||||
}
|
||||
else
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto bias_block_dstr_encode = BlockGemm::template MakeCBlockDistributionEncode<
|
||||
Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN0>();
|
||||
constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode);
|
||||
|
||||
constexpr auto bias_block_dstr_encode =
|
||||
BlockGemm::template MakeCBlockDistributionEncode<
|
||||
Problem::HstuAttentionTileSetting::kM0,
|
||||
Problem::HstuAttentionTileSetting::kN0>();
|
||||
constexpr auto bias_block_dstr = make_static_tile_distribution(bias_block_dstr_encode);
|
||||
|
||||
return bias_block_dstr;
|
||||
};
|
||||
return bias_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -117,20 +105,20 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
return Problem::GetKDramTileAccessMaxVectorSize();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem, kUseTrLoad>() >= 8)
|
||||
return 8;
|
||||
else
|
||||
return 4;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
// special consideration when shuffling is required before storing V to LDS
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
@@ -183,13 +171,13 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
@@ -203,14 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kPipelineUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSingleSmemElementSpaceSize()
|
||||
{
|
||||
return max(GetKSingleSmemElementSpaceSize<Problem>(),
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
GetVSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>());
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kPipelineUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
@@ -219,6 +207,9 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize =
|
||||
GetSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>();
|
||||
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::HstuAttentionTileSetting::kSubQKHeaddim)
|
||||
{
|
||||
@@ -226,8 +217,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
|
||||
@@ -252,8 +241,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
using KDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t DataTypeSize = sizeof(KDataType);
|
||||
@@ -322,8 +309,6 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
@@ -405,7 +390,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
@@ -413,7 +398,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
@@ -456,14 +441,15 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem, true>();
|
||||
|
||||
constexpr auto XorGroupSize =
|
||||
Problem::HstuAttentionTileSetting::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
static_assert(VSingleSmemElementSpaceSize ==
|
||||
GetVSingleSmemElementSpaceSize<Problem, true>());
|
||||
|
||||
constexpr auto v_lds_block_desc_naive =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
|
||||
@@ -497,14 +483,14 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
@@ -526,7 +512,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem, true>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
@@ -652,7 +638,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
Problem::HstuAttentionTileSetting::Gemm1BlockWarps::at(number<1>{});
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
|
||||
{
|
||||
using GemmProblem = BlockGemmProblem<
|
||||
@@ -727,7 +713,7 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
typename Problem::HstuAttentionTileSetting::Gemm1BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
return BlockGemmARegBSmemCRegV2Hack_1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
@@ -737,22 +723,22 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem, kUseTrLoad>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kPipelineUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKV()
|
||||
{
|
||||
constexpr index_t num_kv_lds_buffers = GetNumKVLdsBuffers<Problem>();
|
||||
|
||||
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem>() *
|
||||
return num_kv_lds_buffers * GetSingleSmemElementSpaceSize<Problem, kPipelineUseTrLoad>() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
};
|
||||
|
||||
@@ -762,10 +748,10 @@ struct HstuAttentionFwdPipelineQRKSVSPolicy
|
||||
return 0;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
template <typename Problem, bool kPipelineUseTrLoad = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
|
||||
return GetSmemSizeKV<Problem, kPipelineUseTrLoad>() + GetSmemSizeDropout<Problem>();
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <variant>
|
||||
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_util.hpp"
|
||||
|
||||
#ifndef HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM
|
||||
#define HSTU_SCHED_BATCH_AS_FIRST_GRID_DIM 1
|
||||
@@ -48,13 +49,14 @@ struct HstuAttentionFwdSplitKVKernel
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::Problem::kHasDropout;
|
||||
static constexpr bool kHasCausalMask = HstuAttentionPipeline::Problem::kHasCausal;
|
||||
static constexpr bool kUseTrLoad = HstuAttentionPipeline::Problem::kUseTrLoad;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
|
||||
|
||||
static constexpr bool kUseTrLoad = detail::is_using_trload_v<HstuAttentionPipeline>;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid the duplicated base class problem, introduce a template
|
||||
// arg
|
||||
struct HstuAttentionFwdEmptyKargs
|
||||
|
||||
@@ -39,9 +39,9 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -57,7 +57,6 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionGroupFwdParams& param, hipStream_t stream)
|
||||
@@ -89,7 +88,7 @@ struct group_forward_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -45,9 +45,9 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
typename HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -63,7 +63,6 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionFwdTileSetting>;
|
||||
|
||||
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
|
||||
@@ -101,7 +100,7 @@ struct group_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -39,9 +39,9 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
HstuAttentionNoSoftmaxFwdTileSetting<MaxK, MTile>>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -57,7 +57,6 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionTileSetting>;
|
||||
|
||||
static void Run(HstuAttentionNoGroupFwdParams& param, hipStream_t stream)
|
||||
@@ -89,7 +88,7 @@ struct jagged_forward_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -44,9 +44,9 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
using HstuAttentionCombineTileSetting = HstuAttentionFwdSplitKVCombineTileSetting<MaxK>::Type;
|
||||
|
||||
#ifdef BUILD_HSTU_FOR_GFX95_ONLY
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
static constexpr bool use_trload_pipeline = true;
|
||||
#else
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
static constexpr bool use_trload_pipeline = false;
|
||||
#endif
|
||||
|
||||
template <bool kIsCrossAttention>
|
||||
@@ -62,7 +62,6 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
kHasDropout,
|
||||
kUseCausal,
|
||||
kUseSoftmax,
|
||||
kUseTrLoad,
|
||||
HstuAttentionFwdTileSetting>;
|
||||
|
||||
using OaccDataType = HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType;
|
||||
@@ -100,7 +99,7 @@ struct jagged_forward_splitkv_causal_softmax_bias_dropout_dispatch
|
||||
BOOL_SWITCH(param.is_cross_attention, kIsCrossAttention, [&] {
|
||||
using HstuPipelineProblem = HstuFwdPipelineProblemTemp<kIsCrossAttention>;
|
||||
|
||||
if constexpr(!kUseTrLoad)
|
||||
if constexpr(!use_trload_pipeline)
|
||||
{
|
||||
using HstuPipeline = std::conditional_t<
|
||||
kUseSoftmax,
|
||||
|
||||
@@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == false, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
@@ -46,8 +46,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
@@ -62,10 +60,10 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem, true /*kUseTrLoad */>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -110,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
return Policy::template GetSmemSize<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
@@ -166,7 +164,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
@@ -220,9 +218,13 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
k_lds_ptr,
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
k_lds,
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>()
|
||||
.get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
@@ -238,9 +240,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
v_lds,
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
@@ -252,11 +256,11 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem, true /*kUseTrLoad*/>());
|
||||
|
||||
// reduction function for softmax
|
||||
const auto f_silu = [&](CompDataType& x) {
|
||||
@@ -397,7 +401,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
@@ -57,8 +57,6 @@ CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
|
||||
// but it also contains other information needed by the pipeline, which includes
|
||||
// TileShape -- which determines how block-layer calculation is done in tiles and
|
||||
// how warps are allocated on dimensions
|
||||
// Traits -- other information required for running the kernel and pipeline
|
||||
|
||||
template <typename InOutDataType_,
|
||||
typename GemmAccDataType_,
|
||||
typename CompDataType_, // data type for SiLU and other non-linear calculation
|
||||
@@ -70,7 +68,6 @@ template <typename InOutDataType_,
|
||||
bool kHasDropout_,
|
||||
bool kHasCausal_,
|
||||
bool kUseSoftmax_,
|
||||
bool kUseTrLoad_, // use transposed loading to load V tile from lds to vgprs
|
||||
typename AttentionTileSetting_>
|
||||
struct HstuAttentionFwdPipelineProblem
|
||||
{
|
||||
@@ -94,7 +91,6 @@ struct HstuAttentionFwdPipelineProblem
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kHasCausal = kHasCausal_;
|
||||
static constexpr bool kUseSoftmax = kUseSoftmax_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
static_assert(!kUseGroup || (kUseGroup && kIsJagged),
|
||||
"Group HSTU is only used with jagged mode!");
|
||||
|
||||
@@ -30,3 +30,29 @@ static inline int get_number_of_cu()
|
||||
|
||||
return props.multiProcessorCount;
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
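// A helper trait that is true only when T exposes a compile-time member kUseTrLoad that is
// convertible to bool and evaluates to true; otherwise (member missing or false) it is false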
|
||||
// T is the pipeline class used by the kernel instance
|
||||
template <typename T, typename = void>
|
||||
struct has_use_trload_flag : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_use_trload_flag<
|
||||
T,
|
||||
std::enable_if_t<std::is_convertible_v<decltype(T::kUseTrLoad), bool> && T::kUseTrLoad>>
|
||||
: std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static inline constexpr bool is_using_trload_v = has_use_trload_flag<T>::value;
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == false, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = false;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
@@ -46,8 +46,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasCausal = Problem::kHasCausal;
|
||||
|
||||
static_assert(Problem::kUseTrLoad == true, "Check failed!");
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
@@ -62,10 +60,10 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem, true /*kUseTrLoad*/>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -110,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
return Policy::template GetSmemSize<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
@@ -169,7 +167,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 =
|
||||
Policy::template GetKVBlockGemm<Problem, true /*kUseTrLoad*/>();
|
||||
|
||||
// SaccBlockTile size is [kM0, kN0Sub]
|
||||
// PcompBlockTile size is [kM0, kN0]
|
||||
@@ -230,9 +229,13 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
k_lds_ptr,
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
k_lds,
|
||||
Policy::template MakeKLdsBlockDescriptor<Problem, true /*kPipelineUseTrLoad*/>()
|
||||
.get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
@@ -248,9 +251,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(smem_ptr),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
v_lds,
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem, true /*kUseTrLoad*/>().get_lengths(),
|
||||
{0, 0});
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kK1, kN1>{}));
|
||||
@@ -262,11 +267,11 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
v_lds_window, sequence<i_buf * kK1, 0>{}, sequence<(i_buf + 1) * kK1, kN1>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeVDramTileDistribution<Problem, true /*kUseTrLoad*/>());
|
||||
|
||||
const auto f_exp = [&](CompDataType x) {
|
||||
if constexpr(std::is_same_v<CompDataType, float>)
|
||||
@@ -508,7 +513,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
|
||||
reinterpret_cast<char*>(smem_ptr) +
|
||||
Policy::template GetSmemSizeKV<Problem, true /*kPipelineUseTrLoad*/>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_curr, pcomp_tile, null_randval_window);
|
||||
|
||||
Reference in New Issue
Block a user