Rename variables used in distribution encoding

This commit is contained in:
PoYen, Chen
2024-07-16 06:27:28 +00:00
parent 879710a495
commit b32fd8d3f4

View File

@@ -64,15 +64,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
constexpr index_t kNPerBlock = Problem::kTileSizeSk;
constexpr index_t kKPerBlock = Problem::kTileSizeD;
constexpr index_t K1 = 16 / sizeof(KDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
constexpr index_t KPerThread = 16 / sizeof(KDataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
@@ -100,15 +101,16 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = 16 / sizeof(VDataType);
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t K2 = get_warp_size() / N0;
constexpr index_t K1 = kBlockSize / get_warp_size();
constexpr index_t K0 = kKPerBlock / (K2 * K1);
constexpr index_t NPerThread = 16 / sizeof(VDataType);
constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<NThreadPerBlock, NPerThread>,
sequence<KPerThread, NumWarps, KThreadPerWarp>>,
tuple<sequence<2>, sequence<1, 2>>,
tuple<sequence<1>, sequence<0, 2>>,
sequence<1, 2>,
@@ -116,21 +118,33 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
}
else
{
constexpr index_t K1 = 16 / sizeof(VDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
constexpr index_t KPerThread = 16 / sizeof(VDataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
// 4 vals per load
// Presumably meant to build the DRAM tile distribution for rotary cos/sin data
// in the interleaved layout (inferred from the name only -- TODO confirm).
// NOTE(review): the body is empty in this view, so `auto` deduces to void and
// nothing is returned; confirm whether the implementation was elided by the
// diff view or is still to be written.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinInterleaveDramTileDistribution()
{
}
// Presumably meant to build the DRAM tile distribution for rotary cos/sin data
// in the contiguous layout (inferred from the name only -- TODO confirm).
// NOTE(review): the body is empty in this view, so `auto` deduces to void and
// nothing is returned; confirm whether the implementation was elided by the
// diff view or is still to be written.
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeRotaryCosSinContiguousDramTileDistribution()
{
}
};
} // namespace ck_tile