Change to the Q/K DramTile encoding and renaming in V/VShuffled DramTile

This commit is contained in:
Qianfeng Zhang
2025-12-16 14:31:11 +00:00
parent d7ddc76542
commit 13fdb382b2

View File

@@ -341,15 +341,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
sequence<1, 1>>{});
}
template <typename Problem>
@@ -367,15 +367,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
// for Q-Tile [64, 128], the encoding is [4W * 4T * 4E, 16T * 8E]
// for Q-Tile [64, 128], the encoding is [4W * 4E * 4T, 16T * 8E]
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumWarps, MThreadPerWarp, MPerThread>,
tuple<sequence<NumWarps, MPerThread, MThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
sequence<1, 1>>{});
}
template <typename Problem>
@@ -512,10 +512,10 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
@@ -621,20 +621,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
if constexpr(!Problem::kUseTrLoad)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<NThreads, NPerThread>,
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
@@ -642,20 +641,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
}
else
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<K0, K1, K2>, sequence<N0, N1>>,
tuple<sequence<NumWarps, KThreadPerWarp, KPerThread>,
sequence<NThreads, NPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
@@ -671,20 +669,19 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t NPerThread = GetAlignmentV<Problem>();
constexpr index_t NThreads = kNPerBlock / NPerThread;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t KPerThread = ElemPerThread / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<NThreads, NPerThread>,
sequence<NumWarps, KThreadPerWarp, KPerThread>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,