Switch to v2 pipeline for grouped conv bwd data (#2181)

* Change to old pipeline for grouped conv bwd data

* fix

* fix

* fix

* fix

* fix

* fix

* Fix
This commit is contained in:
Bartłomiej Kocot
2025-05-13 10:14:30 +02:00
committed by GitHub
parent 2920604786
commit c53b7bd22e
12 changed files with 256 additions and 1111 deletions

View File

@@ -39,7 +39,6 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
@@ -330,7 +329,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
[[maybe_unused]] const Block2ETileMap&)
[[maybe_unused]] const Block2ETileMap&,
index_t k_batch = 1)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
@@ -367,7 +367,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
// check gridwise gemm pipeline
const auto num_k_loop = AK / KPerBlock;
const auto num_k_loop = AK / (KPerBlock * k_batch);
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
@@ -393,9 +393,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K,
index_t k_batch = 1)
{
const index_t num_loop = K / KPerBlock;
const index_t num_loop = K / (KPerBlock * k_batch);
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
@@ -500,6 +501,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -519,7 +521,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
const Block2ETileMap& block_2_etile_map,
const index_t k_batch = 1,
const index_t k_idx = 0)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -550,6 +554,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return;
}
const index_t num_k_per_block =
__builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
@@ -591,7 +598,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
true,
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
true,
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0),
b_element_op,
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0),
@@ -688,7 +695,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
(KPerBlock * k_batch));
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
@@ -943,6 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
GemmSpecialization GemmSpec,
typename ALayout,
typename BLayout,
@@ -1010,22 +1018,24 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
typename AGridDesc_MK,
typename BGridDesc_NK,
typename DsGridDesc_MN,
@@ -1067,19 +1077,20 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation>(
p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
}
};