Update to MakeLSEaccDramTileDistribution trying to assign more threads to MThreadPerWarp so that block_tile_reduce_sync() work on less KThreadPerWarp

This commit is contained in:
Qianfeng Zhang
2026-05-23 07:22:29 +00:00
parent 1dbd127d1b
commit e841981ddd
2 changed files with 17 additions and 36 deletions

View File

@@ -95,44 +95,23 @@ struct HstuAttentionFwdSplitKVCombinePipelinePolicy
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
constexpr index_t KPerThread = KVector;
static_assert(kMPerBlock <= NumWarps * get_warp_size(), "Check failed!");
if constexpr(OtherK < get_warp_size())
{
// try to assign more consecutive threads on dim-K
constexpr index_t KThreads = OtherK;
constexpr index_t MThreadPerWarp = kMPerBlock / NumWarps;
constexpr index_t MPerThread = 1;
constexpr index_t KThreadPerWarp = get_warp_size() / MThreadPerWarp;
constexpr index_t KRepPerThread = OtherK / KThreadPerWarp;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile
// distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{})
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
// all threads in the warp are assigned on dim-K
constexpr index_t KThreads = get_warp_size();
constexpr index_t KRepPerThread = OtherK / KThreads;
constexpr index_t MPerThread = kMPerBlock / NumWarps;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps>,
sequence<KRepPerThread, KThreads, KPerThread>>,
tuple<sequence<1>, sequence<2>>,
tuple<sequence<1>, sequence<1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
};
// 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile
// distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{})
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MThreadPerWarp, NumWarps, MPerThread>,
sequence<KRepPerThread, KThreadPerWarp, KVector>>,
tuple<sequence<1>, sequence<2, 1>>,
tuple<sequence<1>, sequence<1, 0>>,
sequence<1, 2, 2>,
sequence<2, 0, 2>>{});
}
template <typename Problem>

View File

@@ -152,6 +152,8 @@ struct HstuAttentionFwdSplitKVCombinePipelineProblem
static constexpr index_t kBlockSize = CombineTileSetting_::NumWarps * get_warp_size();
static constexpr index_t kMaxSplits = kMaxSplits_;
static_assert((kMaxSplits == 0) || (kM * kMaxSplits >= kBlockSize), "Check failed!");
CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize()
{
constexpr index_t kMPerBlock = kM;