mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
Add support of loading QK tiles of hdim96 without padding to hdim128
This commit is contained in:
@@ -59,6 +59,25 @@ struct has_ignore_fast_exp2_flag<
|
||||
template <typename T>
|
||||
static inline constexpr bool ignore_fast_exp2_v = has_ignore_fast_exp2_flag<T>::value;
|
||||
|
||||
// Detection trait for "naive hdim load". Naive hdim load means the tiles of
// hdim96/hdim160/hdim192 are loaded without padding the tensor_view/tile_window up to
// hdim128/hdim256. It is currently supported by the qr_ks_vs_whole_k_prefetch_pipeline.
//
// Primary template: chosen when T has no usable kIsNaiveHDimLoad member, or when that
// member evaluates to false.
template <typename T, typename = void>
struct has_naive_hdim_load_flag : std::false_type
{
};

// Partial specialization: chosen only when T::kIsNaiveHDimLoad exists, its type is
// convertible to bool, and its value is true. If any of these does not hold, the
// enable_if_t substitution fails and the primary template above is used instead.
template <typename T>
struct has_naive_hdim_load_flag<
    T,
    std::enable_if_t<std::is_convertible_v<decltype(T::kIsNaiveHDimLoad), bool> &&
                     T::kIsNaiveHDimLoad>> : std::true_type
{
};

// Shorthand for has_naive_hdim_load_flag<T>::value.
template <typename T>
static inline constexpr bool is_naive_hdim_load_v = has_naive_hdim_load_flag<T>::value;
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
@@ -1313,6 +1332,10 @@ struct FmhaFwdKernel
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
constexpr index_t kQKHeaddimToUse = detail::is_naive_hdim_load_v<FmhaPipeline>
|
||||
? FmhaPipeline::kQKHeaddim
|
||||
: FmhaPipeline::kSubQKHeaddim;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -1325,7 +1348,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
@@ -1350,7 +1373,7 @@ struct FmhaFwdKernel
|
||||
{
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
number<kQKHeaddimToUse>{}),
|
||||
sequence<kPadSeqLenK_, kPadHeadDimQ>{});
|
||||
}
|
||||
else
|
||||
@@ -1371,18 +1394,29 @@ struct FmhaFwdKernel
|
||||
number<FmhaPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
if constexpr(!kUseTrLoad)
|
||||
{
|
||||
const auto v_dram_transposed = transform_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen_k)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
|
||||
|
||||
return pad_tensor_view(
|
||||
v_dram_transposed,
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, kPadSeqLenK_>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kK1>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadHeadDimV>{});
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1406,7 +1440,7 @@ struct FmhaFwdKernel
|
||||
[&]() {
|
||||
if constexpr(FmhaPipeline::kQLoadOnce)
|
||||
return make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{});
|
||||
number<kQKHeaddimToUse>{});
|
||||
else
|
||||
return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
|
||||
}(),
|
||||
@@ -1416,8 +1450,8 @@ struct FmhaFwdKernel
|
||||
if constexpr(detail::is_n0loop_pipeline_v<FmhaPipeline>)
|
||||
{
|
||||
return make_tile_window(k_dram,
|
||||
make_tuple(number<FmhaPipeline::kK1>{},
|
||||
number<FmhaPipeline::kSubQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0Sub>{},
|
||||
number<kQKHeaddimToUse>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
|
||||
@@ -8,6 +8,52 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename DataType, index_t ElemPerThread>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetMaxVectorSize()
|
||||
{
|
||||
if constexpr(std::is_same_v<DataType, half_t> || std::is_same_v<DataType, bf16_t>)
|
||||
{
|
||||
// ToDo: need support in ck_tile for using buffer_load_dwordx3
|
||||
// if constexpr(ElemPerThread % 6 == 0)
|
||||
// return 6;
|
||||
if constexpr(ElemPerThread % 8 == 0)
|
||||
return 8;
|
||||
else if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else if constexpr(std::is_same_v<DataType, float>)
|
||||
{
|
||||
// ToDo: need support in ck_tile for using buffer_load_dwordx3
|
||||
// if constexpr(ElemPerThread % 3 == 0)
|
||||
// return 3;
|
||||
if constexpr(ElemPerThread % 4 == 0)
|
||||
return 4;
|
||||
else if constexpr(ElemPerThread % 2 == 0)
|
||||
return 2;
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
static_assert(false, "The data type is not supported!");
|
||||
};
|
||||
|
||||
template <typename DataType,
|
||||
index_t kThreadBlockSize,
|
||||
index_t kHigherDimSize,
|
||||
index_t kLowerDimSize>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t ElemPerThread = (kHigherDimSize * kLowerDimSize) / kThreadBlockSize;
|
||||
|
||||
return GetMaxVectorSize<DataType, ElemPerThread>();
|
||||
}
|
||||
|
||||
}; // namespace detail
|
||||
|
||||
template <typename QDataType_,
|
||||
typename KDataType_,
|
||||
typename VDataType_,
|
||||
@@ -50,6 +96,9 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThroughLds here?
|
||||
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
@@ -62,6 +111,33 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr auto QScaleEnum = Traits::QScaleEnum;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
// Vector length (in elements) chosen by detail::GetDramTileAccessMaxVectorSize for a
// kM0 x kQKHeaddim tile of QDataType shared by kBlockSize threads.
CK_TILE_HOST_DEVICE static constexpr auto GetQDramTileAccessMaxVectorSize()
{
    constexpr index_t kMPerBlock = BlockFmhaShape::kM0;
    constexpr index_t kKPerBlock = BlockFmhaShape::kQKHeaddim;

    return detail::
        GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
}
|
||||
|
||||
// Vector length (in elements) chosen by detail::GetDramTileAccessMaxVectorSize for a
// kN0 x kK0 tile of KDataType shared by kBlockSize threads.
CK_TILE_HOST_DEVICE static constexpr auto GetKDramTileAccessMaxVectorSize()
{
    constexpr index_t kNPerBlock = BlockFmhaShape::kN0;
    constexpr index_t kKPerBlock = BlockFmhaShape::kK0;

    return detail::
        GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}
|
||||
|
||||
// Vector length (in elements) chosen by detail::GetDramTileAccessMaxVectorSize for a
// kN1 x kK1 tile of VDataType shared by kBlockSize threads.
CK_TILE_HOST_DEVICE static constexpr auto GetVDramTileAccessMaxVectorSize()
{
    constexpr index_t kNPerBlock = BlockFmhaShape::kN1;
    constexpr index_t kKPerBlock = BlockFmhaShape::kK1;

    return detail::
        GetDramTileAccessMaxVectorSize<VDataType, kBlockSize, kNPerBlock, kKPerBlock>();
}
|
||||
};
|
||||
|
||||
template <typename QDataType_,
|
||||
|
||||
@@ -35,8 +35,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr bool kUseN0Loop = true;
|
||||
static constexpr bool kIgnoreFastExp2 = true;
|
||||
static constexpr bool kUseN0Loop = true;
|
||||
static constexpr bool kIgnoreFastExp2 = true;
|
||||
static constexpr bool kIsNaiveHDimLoad = true;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -50,14 +51,14 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kStoreLSE = Problem::kStoreLSE;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
|
||||
|
||||
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
|
||||
@@ -135,9 +136,9 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
typename AttentionVariantParams,
|
||||
typename BlockIndices>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
@@ -170,8 +171,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kSubQKHeaddim ==
|
||||
KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kQKHeaddim == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
@@ -225,7 +225,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
|
||||
@@ -235,7 +235,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kN0Sub>{}, number<kSubQKHeaddim>{}),
|
||||
make_tuple(number<kN0Sub>{}, number<kQKHeaddim>{}),
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
@@ -271,7 +271,6 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
auto q_lds_write_window = make_tile_window(
|
||||
q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
@@ -286,25 +285,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_write_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kSubQKHeaddim>{}));
|
||||
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
using k_lds_read_window_type = decltype(get_slice_tile(
|
||||
using k_lds_window_type = decltype(get_slice_tile(
|
||||
k_lds_window, sequence<0, 0>{}, sequence<kN0Sub, kQKHeaddim>{}));
|
||||
|
||||
statically_indexed_array<k_lds_write_window_type, NumKVLdsBuffers> k_lds_write_windows;
|
||||
statically_indexed_array<k_lds_read_window_type, NumKVLdsBuffers> k_lds_read_windows;
|
||||
statically_indexed_array<k_lds_window_type, NumKVLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_write_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kSubQKHeaddim>{});
|
||||
k_lds_read_windows[i_buf] =
|
||||
get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
k_lds_windows[i_buf] = get_slice_tile(k_lds_window,
|
||||
sequence<i_buf * kN0Sub, 0>{},
|
||||
sequence<(i_buf + 1) * kN0Sub, kQKHeaddim>{});
|
||||
});
|
||||
|
||||
// V tile in LDS
|
||||
@@ -434,7 +423,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // not the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
@@ -460,9 +449,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
@@ -475,7 +463,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
else // the iteration is also the last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
@@ -492,9 +480,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
@@ -510,7 +497,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
if(seqlen_k_curr < seqlen_k_end - kN0) // intermediate iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
@@ -525,9 +512,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
move_tile_window(k_dram_window, {kN0Sub, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
@@ -540,7 +526,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
else // last iteration
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[number<i_n0>{}]),
|
||||
partition_index);
|
||||
|
||||
@@ -551,9 +537,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(sacc_tile,
|
||||
q_tile,
|
||||
k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(
|
||||
sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
@@ -568,7 +553,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
else // only preload one unroll of K for next iteration, used when kM0=128
|
||||
{
|
||||
static_for<0, n0_loops, 1>{}([&](auto i_n0) {
|
||||
store_tile(k_lds_write_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
store_tile(k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tiles[I0]),
|
||||
partition_index);
|
||||
|
||||
@@ -590,7 +575,7 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetch
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
gemm_0(sacc_tile, q_tile, k_lds_read_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
gemm_0(sacc_tile, q_tile, k_lds_windows[number<i_n0 % NumKVLdsBuffers>{}]);
|
||||
|
||||
sacc_tile = tile_elementwise_in(s_acc_element_func, sacc_tile);
|
||||
auto tmp_tile = cast_tile<CompDataType>(sacc_tile);
|
||||
|
||||
@@ -114,16 +114,21 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
if constexpr(Problem::kLoadWholeQTileOnceThroughLds)
|
||||
{
|
||||
return Problem::GetQDramTileAccessMaxVectorSize();
|
||||
}
|
||||
else
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<QDataType, kBlockSize, kMPerBlock, kKPerBlock>();
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -142,12 +147,10 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
return min(MaxVectorSize, ElemPerThread);
|
||||
return detail::
|
||||
GetDramTileAccessMaxVectorSize<KDataType, kBlockSize, kNPerBlock, kKPerBlock>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -162,23 +165,31 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
// special consideration when shuffling is required before storing V to LDS
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = Problem::GetVDramTileAccessMaxVectorSize();
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (ElemPerThread / kMinVecLoad);
|
||||
// try to avoid writing sub-dword to LDS due to poor performance
|
||||
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (ElemPerThread / kMinVecLoad);
|
||||
|
||||
return kVecLoad;
|
||||
return kVecLoad;
|
||||
}
|
||||
else
|
||||
{
|
||||
return Problem::GetVDramTileAccessMaxVectorSize();
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -195,11 +206,16 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160
|
||||
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
{
|
||||
return kKPerBlock * kNPerBlock;
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -236,12 +252,23 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::BlockFmhaShape::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKVector>{},
|
||||
number<1>{});
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -324,25 +351,113 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more consideration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
// ToDo: need more considieration for hdim72
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
static_assert((KThreads & (KThreads - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<1, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -350,11 +465,36 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
// for hdim96 and hdim160, use simplest layout
|
||||
if constexpr(kKPerBlock < Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
{
|
||||
constexpr index_t KSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
|
||||
|
||||
static_assert(KSingleSmemElementSpaceSize == GetKSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{}, number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKVector>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
else if constexpr(GetQKWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
@@ -362,9 +502,15 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
|
||||
constexpr index_t DataTypeSize = sizeof(KDataType);
|
||||
|
||||
#ifdef __gfx950__
|
||||
// 256 contiguous bytes mapped to 64 banks with each bank 4 contiguous bytes
|
||||
constexpr auto NLdsLayer =
|
||||
(64 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (64 * 4 / kKPerBlock / DataTypeSize);
|
||||
#else
|
||||
// 128 contiguous bytes mapped to 32 banks with each bank 4 contiguous bytes
|
||||
constexpr auto NLdsLayer =
|
||||
(32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / kKPerBlock / DataTypeSize);
|
||||
#endif
|
||||
|
||||
constexpr auto k_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumKLdsBuffers>{},
|
||||
@@ -455,24 +601,52 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0Sub;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
constexpr index_t OtherK = kKPerBlock / kKVector;
|
||||
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
if constexpr(kKPerBlock == Problem::BlockFmhaShape::kSubQKHeaddim)
|
||||
// for kKPerBlock=32,64,128,256
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
constexpr index_t KPerThread = kKVector;
|
||||
constexpr index_t KThreads = OtherK;
|
||||
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else // for kKPerBlock=96,160
|
||||
{
|
||||
static_assert((OtherK & (OtherK - 1)) != 0, "Check failed!");
|
||||
|
||||
constexpr index_t KRepPerThread = (OtherK % 3 == 0) ? 3 : 5;
|
||||
constexpr index_t KThreads = OtherK / KRepPerThread;
|
||||
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KRepPerThread, KThreads, kKVector>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -483,43 +657,87 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// K2 is the vector size for storing shuffled tile to LDS
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
// K2 is the vector size for storing shuffled tile to LDS
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
|
||||
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
static_assert(kKPack >= K2, "Check failed!");
|
||||
static_assert(kKPack >= K2, "Check failed!");
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
|
||||
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<N1 * kKPerBlock + kKPack>{},
|
||||
number<kKPerBlock>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<N1 * kKPerBlock + kKPack>{},
|
||||
number<kKPerBlock>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
constexpr auto XorGroupSize = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize = kNPerBlock * kKPerBlock;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr auto v_lds_block_desc_naive =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}),
|
||||
make_tuple(number<VSingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock>{},
|
||||
number<XorGroupSize>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_pass_through_transform(number<NumVLdsBuffers>{}),
|
||||
make_xor_transform(make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<kKPerBlock>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -529,24 +747,46 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NThreads, NPerThread>,
|
||||
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, KThreadPerWarp, KPerThread>,
|
||||
sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -556,20 +796,19 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<NThreads, NPerThread>,
|
||||
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
|
||||
Reference in New Issue
Block a user