mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 03:07:02 +00:00
Change the Q/K DramTile encoding and rename variables in the V/VShuffled DramTile distributions
This commit is contained in:
@@ -341,15 +341,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -367,15 +367,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
|
||||
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
|
||||
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -512,10 +512,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
@@ -621,20 +621,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
if constexpr(!Problem::kUseTrLoad)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<NThreads, NPerThread>,
|
||||
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
@@ -642,20 +641,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<K0, K1, K2>, sequence<N0, N1>>,
|
||||
tuple<sequence<NumWarps, KThreadPerWarp, KPerThread>,
|
||||
sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
@@ -671,20 +669,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t NPerThread = GetAlignmentV<Problem>();
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = ElemPerThread / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<NThreads, NPerThread>,
|
||||
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
|
||||
Reference in New Issue
Block a user