mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Remove the use of MakeKRegBlockDescriptor and MakeVRegBlockDescriptor
This commit is contained in:
@@ -189,7 +189,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
@@ -214,9 +214,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
@@ -259,7 +256,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
block_sync_lds();
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
|
||||
@@ -189,7 +189,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
|
||||
Policy::template MakeKRegBlockDescriptor<Problem>());
|
||||
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// V, HBM ->LDS ->Reg
|
||||
@@ -214,9 +214,6 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_write_window.get_window_origin(),
|
||||
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
|
||||
Policy::template MakeVRegBlockDescriptor<Problem>());
|
||||
|
||||
//------------------------------------------------------------------
|
||||
// KT, Reg ->LDS ->Reg
|
||||
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
|
||||
@@ -259,7 +256,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
v_reg_tensor = load_tile(v_lds_read_window);
|
||||
auto v_reg_tensor = load_tile(v_lds_read_window);
|
||||
//---------------------------- Loop Load in ----------------------------//
|
||||
// Q: HBM ->Reg ->LDS
|
||||
auto q_dram_window =
|
||||
|
||||
@@ -849,38 +849,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
// Builds and returns a static tile distribution for the K operand, derived from
// the warp layout of the block GEMM returned by GetQKBlockGemm<Problem>().
// The distribution is formed by embedding WarpGemm::BWarpDstrEncoding into a
// block-level ("outer") encoding computed from Problem::BlockFmhaShape.
// NOTE(review): the pipeline hunks on this page show call sites using
// MakeKRegSliceBlockDescriptor in place of this function — confirm no other callers remain.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
// Element 0 of the policy's config is used as the warp-level GEMM type.
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
// Warp counts are read from Gemm0BlockWarps: index 0 -> MWarp, index 1 -> NWarp.
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
// Block extents: N comes from kN0, K comes from kQKHeaddim.
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
// Per-warp iteration counts. Integer division: presumably the block extents are
// exact multiples of the warp-GEMM extents — TODO confirm.
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Block-level encoding: MWarp appears only in the first template argument,
// while NIterPerWarp/NWarp and KIterPerWarp form the two tuple sequences.
// NOTE(review): the first argument looks like the replication dimension —
// confirm against the tile_distribution_encoding documentation.
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Combine the block-level encoding with the warp GEMM's BWarpDstrEncoding.
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
|
||||
{
|
||||
@@ -924,38 +892,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
// Builds and returns a static tile distribution for the V operand, derived from
// the warp layout of the block GEMM returned by GetOGradVBlockGemm<Problem>().
// The distribution is formed by embedding WarpGemm::BWarpDstrEncoding into a
// block-level ("outer") encoding computed from Problem::BlockFmhaShape.
// NOTE(review): the pipeline hunks on this page show v_reg_tensor being
// initialised directly from load_tile(...) instead of a tensor built from this
// distribution — confirm no other callers remain.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
// Element 0 of the policy's config is used as the warp-level GEMM type.
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
// Warp counts are read from Gemm2BlockWarps: index 0 -> MWarp, index 1 -> NWarp.
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
// Block extents: N comes from kN0, K comes from kVHeaddimForGemmN.
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
// Per-warp iteration counts. Integer division: presumably the block extents are
// exact multiples of the warp-GEMM extents — TODO confirm.
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Block-level encoding: MWarp appears only in the first template argument,
// while NIterPerWarp/NWarp and KIterPerWarp form the two tuple sequences.
// NOTE(review): the first argument looks like the replication dimension —
// confirm against the tile_distribution_encoding documentation.
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
// Combine the block-level encoding with the warp GEMM's BWarpDstrEncoding.
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user