mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Update to MakeLSEaccDramTileDistribution trying to assign more threads to MThreadPerWarp so that block_tile_reduce_sync() work on less KThreadPerWarp
This commit is contained in:
@@ -95,44 +95,23 @@ struct HstuAttentionFwdSplitKVCombinePipelinePolicy
|
||||
|
||||
static_assert((OtherK & (OtherK - 1)) == 0, "Check failed!");
|
||||
|
||||
constexpr index_t KPerThread = KVector;
|
||||
static_assert(kMPerBlock <= NumWarps * get_warp_size(), "Check failed!");
|
||||
|
||||
if constexpr(OtherK < get_warp_size())
|
||||
{
|
||||
// try to assign more consecutive threads on dim-K
|
||||
constexpr index_t KThreads = OtherK;
|
||||
constexpr index_t MThreadPerWarp = kMPerBlock / NumWarps;
|
||||
constexpr index_t MPerThread = 1;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / MThreadPerWarp;
|
||||
constexpr index_t KRepPerThread = OtherK / KThreadPerWarp;
|
||||
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile
|
||||
// distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{})
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// all threads in the warp are assigned on dim-K
|
||||
constexpr index_t KThreads = get_warp_size();
|
||||
constexpr index_t KRepPerThread = OtherK / KThreads;
|
||||
|
||||
constexpr index_t MPerThread = kMPerBlock / NumWarps;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps>,
|
||||
sequence<KRepPerThread, KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<2>>,
|
||||
tuple<sequence<1>, sequence<1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
};
|
||||
// 32/64 Threads should be in lay-out [kThreads, MThreadPerWarp] since the tile
|
||||
// distribution will be used by block_tile_reduce_sync(..., bool_constant<0>{})
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MThreadPerWarp, NumWarps, MPerThread>,
|
||||
sequence<KRepPerThread, KThreadPerWarp, KVector>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 0>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<2, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -152,6 +152,8 @@ struct HstuAttentionFwdSplitKVCombinePipelineProblem
|
||||
static constexpr index_t kBlockSize = CombineTileSetting_::NumWarps * get_warp_size();
|
||||
static constexpr index_t kMaxSplits = kMaxSplits_;
|
||||
|
||||
static_assert((kMaxSplits == 0) || (kM * kMaxSplits >= kBlockSize), "Check failed!");
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetOaccDramTileAccessMaxVectorSize()
|
||||
{
|
||||
constexpr index_t kMPerBlock = kM;
|
||||
|
||||
Reference in New Issue
Block a user