Add support for loading QK tiles of hdim96 without padding to hdim128

This commit is contained in:
Qianfeng Zhang
2025-12-17 16:39:15 +00:00
parent d281c519f3
commit 384f4708a1
4 changed files with 510 additions and 176 deletions

View File

@@ -59,6 +59,25 @@ struct has_ignore_fast_exp2_flag<
template <typename T>
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
// A helper struct for detecting naive_hdim_load; naive_hdim_load means loading tiles of
// hdim96/hdim160/hdim192 without padding the tensor_view/tile_window to hdim128/hdim256.
// naive_hdim_load is currently supported only by the qr_ks_vs_whole_k_prefetch_pipeline.
// Primary template: T has no usable kIsNaiveHDimLoad flag, so the trait is false.
template <typename T, typename = void>
struct has_naive_hdim_load_flag : std::false_type
{
};
// Specialization selected when T declares a member kIsNaiveHDimLoad convertible to bool
// whose value is true; otherwise SFINAE discards it and the primary (false) applies.
template <typename T>
struct has_naive_hdim_load_flag<
T,
std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
T::kIsNaiveHDimLoad>> : std::true_type
{
};
// Convenience variable template, mirroring ignore_fast_exp2_v above.
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
}; // namespace detail
template <typename FmhaPipeline_, typename EpiloguePipeline_>
@@ -1313,6 +1332,10 @@ struct FmhaFwdKernel
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
batch_offset_o;
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
? FmhaPipeline::kQKHeaddim
: FmhaPipeline::kSubQKHeaddim;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -1325,7 +1348,7 @@ struct FmhaFwdKernel
{
return pad_tensor_view(q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
else
@@ -1350,7 +1373,7 @@ struct FmhaFwdKernel
{
return pad_tensor_view(k_dram_naive,
make_tuple(number<FmhaPipeline::kN0Sub>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
number<kQKHeaddimToUse>{}),
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
}
else
@@ -1371,18 +1394,29 @@ struct FmhaFwdKernel
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
if constexpr(!kUseTrLoad)
{
const auto v_dram_transposed = transform_tensor_view(
v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK_>{});
}
else
{
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
sequence<false, kPadHeadDimV>{});
};
}
else
{
@@ -1406,7 +1440,7 @@ struct FmhaFwdKernel
[&]() {
if constexpr(FmhaPipeline::kQLoadOnce)
return make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kSubQKHeaddim>{});
number<kQKHeaddimToUse>{});
else
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
}(),
@@ -1416,8 +1450,8 @@ struct FmhaFwdKernel
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
{
return make_tile_window(k_dram,
make_tuple(number<FmhaPipeline::kK1>{},
number<FmhaPipeline::kSubQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0Sub>{},
number<kQKHeaddimToUse>{}),
{0, 0});
}
else

View File

@@ -8,6 +8,52 @@
namespace ck_tile {
namespace detail {
// Returns the widest per-thread vector size (in elements) that evenly divides
// ElemPerThread, capped at a 16-byte access (buffer_load_dwordx4): up to 8 elements
// for 16-bit types (half_t/bf16_t), up to 4 elements for float.
// Other data types are rejected at compile time.
template <typename DataType, index_t ElemPerThread>
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
{
    if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
    {
        // ToDo: need support in ck_tile for using buffer_load_dwordx3
        // if constexpr(ElemPerThread % 6 == 0)
        //     return 6;
        if constexpr(ElemPerThread % 8 == 0)
            return 8;
        else if constexpr(ElemPerThread % 4 == 0)
            return 4;
        else if constexpr(ElemPerThread % 2 == 0)
            return 2;
        else
            return 1;
    }
    else if constexpr(std::is_same_v<DataType, float>)
    {
        // ToDo: need support in ck_tile for using buffer_load_dwordx3
        // if constexpr(ElemPerThread % 3 == 0)
        //     return 3;
        if constexpr(ElemPerThread % 4 == 0)
            return 4;
        else if constexpr(ElemPerThread % 2 == 0)
            return 2;
        else
            return 1;
    }
    else
    {
        // A type-dependent false condition is required here: a plain
        // static_assert(false) makes the program ill-formed even when this branch is
        // never instantiated (pre-C++23 semantics; see CWG2518). The dependent form
        // only fires for an actually-unsupported DataType. The return keeps the
        // deduced return type consistent across all branches.
        static_assert(sizeof(DataType) == 0, "The data type is not supported!");
        return 0;
    }
}
// For a kHigherDimSize x kLowerDimSize tile loaded cooperatively by
// kThreadBlockSize threads, compute the widest DRAM access vector size
// (in elements) each thread can use for the given DataType.
template <typename DataType,
          index_t kThreadBlockSize,
          index_t kHigherDimSize,
          index_t kLowerDimSize>
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
{
    // Every thread moves an equal share of the tile's elements.
    constexpr index_t kTileElems     = kHigherDimSize * kLowerDimSize;
    constexpr index_t kElemPerThread = kTileElems / kThreadBlockSize;
    return GetMaxVectorSize<DataType, kElemPerThread>();
}
}; // namespace detail
template <typename QDataType_,
typename KDataType_,
typename VDataType_,
@@ -50,6 +96,9 @@ struct BlockFmhaPipelineProblem
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThrough Lds here ?
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
@@ -62,6 +111,33 @@ struct BlockFmhaPipelineProblem
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr auto QScaleEnum = Traits::QScaleEnum;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
// Max DRAM load vector size for the Q tile (kM0 rows x kQKHeaddim columns).
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
{
    return detail::GetDramTileAccessMaxVectorSize<QDataType,
                                                  kBlockSize,
                                                  BlockFmhaShape::kM0,
                                                  BlockFmhaShape::kQKHeaddim>();
}
// Max DRAM load vector size for the K tile (kN0 rows x kK0 columns).
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
{
    return detail::GetDramTileAccessMaxVectorSize<KDataType,
                                                  kBlockSize,
                                                  BlockFmhaShape::kN0,
                                                  BlockFmhaShape::kK0>();
}
// Max DRAM load vector size for the V tile (kN1 rows x kK1 columns).
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
{
    return detail::GetDramTileAccessMaxVectorSize<VDataType,
                                                  kBlockSize,
                                                  BlockFmhaShape::kN1,
                                                  BlockFmhaShape::kK1>();
}
};
template <typename QDataType_,

View File

@@ -35,8 +35,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static constexpr bool kQLoadOnce = true;
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr bool kUseN0Loop = true;
static constexpr bool kIgnoreFastExp2 = true;
static constexpr bool kUseN0Loop = true;
static constexpr bool kIgnoreFastExp2 = true;
static constexpr bool kIsNaiveHDimLoad = true;
static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -50,14 +51,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
@@ -135,9 +136,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
typename AttentionVariantParams,
typename BlockIndices>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
@@ -170,8 +171,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kSubQKHeaddim ==
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
@@ -225,7 +225,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
@@ -235,7 +235,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
auto k_dram_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
@@ -271,7 +271,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_write_window = make_tile_window(
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
@@ -286,25 +285,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
auto k_lds_window = make_tile_window(
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
using k_lds_write_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
using k_lds_read_window_type = decltype(get_slice_tile(
using k_lds_window_type = decltype(get_slice_tile(
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
k_lds_write_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
k_lds_read_windows[i_buf] =
get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
sequence<i_buf * kN0Sub, 0>{},
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
});
// V tile in LDS
@@ -434,7 +423,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
@@ -460,9 +449,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
};
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
@@ -475,7 +463,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
else // the iteration is also the last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
@@ -492,9 +480,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
};
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
@@ -510,7 +497,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
@@ -525,9 +512,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
move_tile_window(k_dram_window, {kN0Sub, 0});
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
@@ -540,7 +526,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
else // last iteration
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
partition_index);
@@ -551,9 +537,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
};
block_sync_lds();
gemm_0(sacc_tile,
q_tile,
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
@@ -568,7 +553,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
else // only preload one unroll of K for next iteration, used when kM0=128
{
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
tile_elementwise_in(k_element_func, k_tiles[I0]),
partition_index);
@@ -590,7 +575,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
block_sync_lds();
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);

View File

@@ -114,16 +114,21 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
{
return Problem::GetQDramTileAccessMaxVectorSize();
}
else
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
return detail::
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
};
}
template <typename Problem>
@@ -142,12 +147,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
return min(MaxVectorSize, ElemPerThread);
return detail::
GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}
template <typename Problem>
@@ -162,23 +165,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
// special consideration when shuffling is required before storing V to LDS
if constexpr(!Problem::kUseTrLoad)
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize();
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
// try to avoid writing sub-dword to LDS due to poor performance
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
return kVecLoad;
return kVecLoad;
}
else
{
return Problem::GetVDramTileAccessMaxVectorSize();
};
}
template <typename Problem>
@@ -195,11 +206,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
{
return kKPerBlock * kNPerBlock;
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -236,12 +252,23 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::BlockFmhaShape::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
{
return make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<kKPerBlock>{}, number<1>{}),
number<kKVector>{},
number<1>{});
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -324,25 +351,113 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentQ<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
// ToDo: need more consideration for hdim72
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<1, 0, 2>>{});
};
}
// Build the thread-to-data mapping for loading the whole Q tile (kM0 x kQKHeaddim)
// directly from DRAM, including head dims (96/160) that are not padded up to the
// next power-of-two kSubQKHeaddim.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
    // kKVector: contiguous elements each thread loads along K per access;
    // OtherK: number of such vectors needed to cover the K extent.
    constexpr index_t kKVector = GetAlignmentQ<Problem>();
    constexpr index_t OtherK   = kKPerBlock / kKVector;
    if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
    // for kKPerBlock=32,64,128,256
    {
        // Power-of-two vector count: spread K purely across threads.
        static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
        constexpr index_t KPerThread     = kKVector;
        constexpr index_t KThreads       = OtherK;
        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);
        // for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
                                             sequence<KThreads, KPerThread>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 0>>,
                                       sequence<1, 2>,
                                       sequence<1, 1>>{});
    }
    else // for kKPerBlock=96,160
    {
        // Non-power-of-two vector count (e.g. 96/8 = 12, 160/8 = 20): peel off the
        // odd factor (3 or 5) as per-thread repeats so the remaining thread count
        // along K stays a power of two.
        static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
        // ToDo: need more consideration for hdim72
        constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
        constexpr index_t KThreads      = OtherK / KRepPerThread;
        static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
        constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
        constexpr index_t NumWarps       = kBlockSize / get_warp_size();
        constexpr index_t MPerThread     = kMPerBlock / (MThreadPerWarp * NumWarps);
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                                       tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
                                             sequence<KRepPerThread, KThreads, kKVector>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<2, 1>>,
                                       sequence<1, 2, 2>,
                                       sequence<1, 0, 2>>{});
    };
}
template <typename Problem>
@@ -350,11 +465,36 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
{
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKVector = GetAlignmentK<Problem>();
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
// for hdim96 and hdim160, use simplest layout
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
{
constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
number<kKVector>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
k_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return k_lds_block_desc;
}
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
@@ -362,9 +502,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t DataTypeSize = sizeof(KDataType);
#ifdef __gfx950__
// 256 contiguous bytes mapped to 64 banks with each bank 4 contiguous bytes
constexpr auto NLdsLayer =
(64 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (64 * 4 / kKPerBlock / DataTypeSize);
#else
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
constexpr auto NLdsLayer =
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
#endif
constexpr auto k_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
@@ -455,24 +601,52 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKVector = GetAlignmentK<Problem>();
constexpr index_t OtherK = kKPerBlock / kKVector;
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
// for kKPerBlock=32,64,128,256
{
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
constexpr index_t KPerThread = kKVector;
constexpr index_t KThreads = OtherK;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else // for kKPerBlock=96,160
{
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
constexpr index_t KThreads = OtherK / KRepPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KRepPerThread, KThreads, kKVector>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
};
}
template <typename Problem>
@@ -483,43 +657,87 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack >= K2, "Check failed!");
static_assert(kKPack >= K2, "Check failed!");
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
return v_lds_block_desc;
}
else
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr auto v_lds_block_desc_naive =
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{},
number<XorGroupSize>{}),
make_tuple(number<VSingleSmemElementSpaceSize>{},
number<kNPerBlock>{},
number<XorGroupSize>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
v_lds_block_desc_naive,
make_tuple(make_pass_through_transform(number<NumVLdsBuffers>{}),
make_xor_transform(make_tuple(number<kKPerBlock>{},
number<kNPerBlock / XorGroupSize>{})),
make_pass_through_transform(number<XorGroupSize>{})),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
return transform_tensor_descriptor(
v_lds_block_desc_permuted,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kKPerBlock>{})),
make_merge_transform_v3_division_mod(make_tuple(
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
};
}
template <typename Problem>
@@ -529,24 +747,46 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NThreads, NPerThread>,
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
else
{
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, KThreadPerWarp, KPerThread>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
};
}
template <typename Problem>
@@ -556,20 +796,19 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<NThreads, NPerThread>,
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,