Update the two trload pipelines to load the whole Q-tile once through LDS on MI350

This commit is contained in:
Qianfeng Zhang
2025-11-12 15:59:38 +00:00
parent 8f876f094e
commit 881ddc5741
4 changed files with 68 additions and 75 deletions

View File

@@ -119,7 +119,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::HstuAttentionTileSetting::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t MaxVectorSize = 16 / sizeof(QDataType);
@@ -246,7 +248,9 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kMPerBlock = Problem::kLoadWholeQTileOnceThroughLds
? Problem::HstuAttentionTileSetting::kM0
: GetQKBlockGemmSingleRepM<Problem>();
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKVector = GetAlignmentQ<Problem>();
@@ -361,6 +365,38 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
sequence<2, 1>>{});
}
/// Builds the static tile distribution for a Q tile of shape
/// [HstuAttentionTileSetting::kM0, HstuAttentionTileSetting::kSubQKHeaddim], i.e. the whole
/// Q tile rather than a single GEMM repetition along M.
///
/// The layout is derived from the block size, the warp size and the element size:
///  - each thread handles `KPerThread` contiguous elements along K, capped at 16 bytes worth
///    of elements (`MaxVectorSize`);
///  - `KThreads` threads of a warp span the K dimension, the remaining `MThreadPerWarp`
///    threads of the warp span M;
///  - the `NumWarps` warps of the block are stacked along M, and each thread repeats
///    `MPerThread` times along M to cover the rest of the tile.
///
/// All of the divisions below must be exact, otherwise the resulting distribution would
/// silently cover fewer elements than the tile holds; this is enforced at compile time.
///
/// @tparam Problem  pipeline problem providing `QKVDataType`, `kBlockSize` and
///                  `HstuAttentionTileSetting::{kM0, kSubQKHeaddim}`.
/// @return the object produced by `make_static_tile_distribution` for the encoding below.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
    using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;

    constexpr index_t kBlockSize = Problem::kBlockSize;
    constexpr index_t kMPerBlock = Problem::HstuAttentionTileSetting::kM0;
    constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kSubQKHeaddim;

    // upper bound on the per-thread vector width: 16 bytes worth of elements
    constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);

    constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
    static_assert(0 < ElemPerThread);
    static_assert((kMPerBlock * kKPerBlock) % kBlockSize == 0,
                  "Q tile size must be a multiple of the block size");

    constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
    constexpr index_t KPerThread  = kMaxVecLoad;
    static_assert(kKPerBlock % KPerThread == 0,
                  "K extent of the Q tile must be a multiple of the per-thread vector length");

    constexpr index_t KThreads = kKPerBlock / KPerThread;
    static_assert(get_warp_size() % KThreads == 0,
                  "threads spanning K must evenly divide a warp");

    constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
    static_assert(kBlockSize % get_warp_size() == 0,
                  "block size must be a multiple of the warp size");

    constexpr index_t NumWarps = kBlockSize / get_warp_size();
    static_assert(kMPerBlock % (MThreadPerWarp * NumWarps) == 0,
                  "M extent of the Q tile must be a multiple of the rows covered per pass");

    constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);

    // for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
    // (W = NumWarps, T = threads within a warp, E = elements per thread; the numbers
    //  correspond to a 256-thread block, a warp size of 64 and 2-byte elements)
    return make_static_tile_distribution(
        tile_distribution_encoding<sequence<1>,
                                   tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
                                         sequence<KThreads, KPerThread>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<0>, sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
{

View File

@@ -46,6 +46,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kUseTrLoad = true;
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
@@ -184,9 +186,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
Policy::template MakeQDramTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -198,13 +200,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
auto q_dram_tile = load_tile(q_dram_window);
using k_tile_type = decltype(load_tile(k_dram_window));
@@ -226,9 +222,9 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
Policy::template MakeQRegTileDistribution<Problem>());
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
@@ -327,39 +323,20 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
store_tile(q_lds_write_window, q_dram_tile);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
clear_tile(o_acc);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_sched_barrier(0x00000001);
__builtin_amdgcn_s_waitcnt(0xc07f);
block_sync_lds();
// the following codes will not generate actual instructions by the compiler
set_slice_tile(q_tile,
q_reg_tiles[i_rep],
sequence<i_rep * kGemmSingleRepM, 0>{},
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
clear_tile(o_acc);
};
q_tile = load_tile(q_lds_read_window);
q_tile = tile_elementwise_in(q_element_func, q_tile);

View File

@@ -47,6 +47,9 @@ struct HstuAttentionFwdPipelineProblem
static constexpr bool kUseSoftmax = kUseSoftmax_;
static constexpr bool kUseTrLoad = kUseTrLoad_;
// ToDo: should we define kUseTrLoad and kLoadWholeQTileOnceThroughLds here?
static constexpr bool kLoadWholeQTileOnceThroughLds = kUseTrLoad ? true : false;
using HstuAttentionTileSetting = remove_cvref_t<AttentionTileSetting_>;
using Traits = remove_cvref_t<Traits_>;

View File

@@ -46,6 +46,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
static constexpr bool kUseTrLoad = true;
static_assert(Problem::kLoadWholeQTileOnceThroughLds == true, "Check failed!");
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
@@ -195,9 +197,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramSingleRepMTileDistribution<Problem>());
Policy::template MakeQDramTileDistribution<Problem>());
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] =
@@ -209,13 +211,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
{seqlen_k_start, 0},
Policy::template MakeKDramTileDistribution<Problem>());
using q_dram_tile_type = decltype(load_tile(q_dram_window));
statically_indexed_array<q_dram_tile_type, kGemmNumRepM> q_dram_tiles;
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
q_dram_tiles[i_rep] = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kGemmSingleRepM, 0});
});
auto q_dram_tile = load_tile(q_dram_window);
using k_tile_type = decltype(load_tile(k_dram_window));
@@ -242,9 +238,9 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
// when kSubQKHeaddim > kQKHeaddim, read window is actually smaller than write window
auto q_lds_read_window =
make_tile_window(q_lds,
make_tuple(number<kGemmSingleRepM>{}, number<kQKHeaddim>{}),
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
{0, 0},
Policy::template MakeQRegSingleRepMTileDistribution<Problem>());
Policy::template MakeQRegTileDistribution<Problem>());
// K tile in LDS
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
@@ -340,45 +336,26 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
}();
using q_reg_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegSingleRepMTileDistribution<Problem>()));
statically_indexed_array<q_reg_tile_type, kGemmNumRepM> q_reg_tiles;
using q_tile_type = decltype(make_static_distributed_tensor<QKVDataType>(
Policy::template MakeQRegTileDistribution<Problem>()));
q_tile_type q_tile;
{
static_for<0, kGemmNumRepM, 1>{}([&](auto i_rep) {
store_tile(q_lds_write_window, q_dram_tiles[i_rep]);
store_tile(q_lds_write_window, q_dram_tile);
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice written
// by each wavefront is read by itself
__builtin_amdgcn_s_waitcnt(0xc07f);
clear_tile(o_acc);
q_reg_tiles[i_rep] = load_tile(q_lds_read_window);
__builtin_amdgcn_sched_barrier(0x00000001);
__builtin_amdgcn_s_waitcnt(0xc07f);
block_sync_lds();
// the following codes will not generate actual instructions by the compiler
set_slice_tile(q_tile,
q_reg_tiles[i_rep],
sequence<i_rep * kGemmSingleRepM, 0>{},
sequence<(i_rep + 1) * kGemmSingleRepM, kQKHeaddim>{});
// no need to call __builtin_amdgcn_s_barrier() since the tile-slice read
// by each wavefront is over-written by itself
});
clear_tile(o_acc);
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
};
q_tile = load_tile(q_lds_read_window);
q_tile = tile_elementwise_in(q_element_func, q_tile);
set_tile(m, -numeric<CompDataType>::infinity());
clear_tile(l);
auto seqlen_k_curr = seqlen_k_start;
__builtin_amdgcn_sched_barrier(0x00000001);