mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Rename variables used in distribution encoding
This commit is contained in:
@@ -64,15 +64,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeD;
|
||||
|
||||
constexpr index_t K1 = 16 / sizeof(KDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
constexpr index_t KPerThread = 16 / sizeof(KDataType);
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreadPerBlock, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
@@ -100,15 +101,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
|
||||
constexpr index_t N1 = 16 / sizeof(VDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t K2 = get_warp_size() / N0;
|
||||
constexpr index_t K1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t K0 = kKPerBlock / (K2 * K1);
|
||||
constexpr index_t NPerThread = 16 / sizeof(VDataType);
|
||||
constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<NThreadPerBlock, NPerThread>,
|
||||
sequence<KPerThread, NumWarps, KThreadPerWarp>>,
|
||||
tuple<sequence<2>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
@@ -116,21 +118,33 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(VDataType);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
constexpr index_t KPerThread = 16 / sizeof(VDataType);
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
|
||||
sequence<KThreadPerBlock, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
// 4 vals per load
|
||||
// NOTE(review): the body is empty in this revision, so `auto` deduces to void and
// no tile distribution is returned — presumably a stub to be filled in; confirm.
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinInterleaveDramTileDistribution()
|
||||
{
|
||||
}
|
||||
|
||||
// NOTE(review): the body is empty in this revision, so `auto` deduces to void and
// no tile distribution is returned — presumably a stub to be filled in; confirm.
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinContiguousDramTileDistribution()
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user