mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Update to the two trload pipeline to load whole Q-tile once through LDS on mi350
This commit is contained in:
@@ -119,7 +119,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::HstuAttentionTileSetting::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
|
||||
@@ -246,7 +248,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
|
||||
? Problem::HstuAttentionTileSetting::kM0
|
||||
: GetQKBlockGemmSingleRepM<Problem>();
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentQ<Problem>();
|
||||
@@ -361,6 +365,38 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
|
||||
@@ -46,6 +46,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
@@ -184,9 +186,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
@@ -198,13 +200,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
auto q_dram_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
@@ -226,9 +222,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
@@ -327,39 +323,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
store_tile(q_lds_write_window, q_dram_tile);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
clear_tile(o_acc);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
block_sync_lds();
|
||||
|
||||
// the following codes will not generate actual instructions by the compiler
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
};
|
||||
q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
|
||||
@@ -47,6 +47,9 @@ struct HstuAttentionFwdPipelineProblem
|
||||
static constexpr bool kUseSoftmax = kUseSoftmax_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThrough Lds here ?
|
||||
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
|
||||
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
@@ -46,6 +46,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
static constexpr bool kUseTrLoad = true;
|
||||
|
||||
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
@@ -195,9 +197,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
|
||||
auto q_dram_window =
|
||||
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
@@ -209,13 +211,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
{seqlen_k_start, 0},
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
using q_dram_tile_type = decltype(load_tile(q_dram_window));
|
||||
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
|
||||
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
q_dram_tiles[i_rep] = load_tile(q_dram_window);
|
||||
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
|
||||
});
|
||||
auto q_dram_tile = load_tile(q_dram_window);
|
||||
|
||||
using k_tile_type = decltype(load_tile(k_dram_window));
|
||||
|
||||
@@ -242,9 +238,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds,
|
||||
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
|
||||
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
@@ -340,45 +336,26 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
|
||||
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
|
||||
|
||||
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeQRegTileDistribution<Problem>()));
|
||||
|
||||
q_tile_type q_tile;
|
||||
|
||||
{
|
||||
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
|
||||
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
|
||||
store_tile(q_lds_write_window, q_dram_tile);
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
|
||||
// by each wavefront is read by itself
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
clear_tile(o_acc);
|
||||
|
||||
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
block_sync_lds();
|
||||
|
||||
// the following codes will not generate actual instructions by the compiler
|
||||
set_slice_tile(q_tile,
|
||||
q_reg_tiles[i_rep],
|
||||
sequence<i_rep * kGemmSingleRepM, 0>{},
|
||||
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
|
||||
|
||||
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
|
||||
// by each wavefront is over-written by itself
|
||||
});
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
};
|
||||
q_tile = load_tile(q_lds_read_window);
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
set_tile(m, -numeric<CompDataType>::infinity());
|
||||
clear_tile(l);
|
||||
|
||||
auto seqlen_k_curr = seqlen_k_start;
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x00000001);
|
||||
|
||||
Reference in New Issue
Block a user