mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Refactored and fixed formatting of bwd_data instances
This commit is contained in:
@@ -513,19 +513,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
return grid_desc_m_k;
|
||||
}
|
||||
|
||||
template <typename Desc_K0_N_K1>
|
||||
static auto transform_k0_n_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1)
|
||||
{
|
||||
const auto grid_desc_n_k = transform_tensor_descriptor(
|
||||
desc_k0_n_k1,
|
||||
make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)),
|
||||
make_merge_transform(
|
||||
make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return grid_desc_n_k;
|
||||
}
|
||||
|
||||
// Note: the dummy function is used just to create the alias
|
||||
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
|
||||
@@ -537,7 +524,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
|
||||
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
|
||||
|
||||
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
|
||||
using BGridDesc_N_K = decltype(transform_k0_n_k1_to_n_k(BGridDesc_BK0_N_BK1{}));
|
||||
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
|
||||
|
||||
// Note: here we can call gridwise functions with dummy arguments,
|
||||
// just to create the alias
|
||||
|
||||
Reference in New Issue
Block a user