Refactored and fixed formatting of bwd_data instances

This commit is contained in:
apoorva
2026-02-02 11:38:37 +00:00
parent cc395ff4fc
commit 21e9dc2ef2
2 changed files with 1 additions and 14 deletions

View File

@@ -513,19 +513,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
return grid_desc_m_k;
}
template <typename Desc_K0_N_K1>
static auto transform_k0_n_k1_to_n_k(const Desc_K0_N_K1& desc_k0_n_k1)
{
const auto grid_desc_n_k = transform_tensor_descriptor(
desc_k0_n_k1,
make_tuple(make_pass_through_transform(desc_k0_n_k1.GetLength(I1)),
make_merge_transform(
make_tuple(desc_k0_n_k1.GetLength(I0), desc_k0_n_k1.GetLength(I2)))),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_n_k;
}
// Note: the dummy function is used just to create the alias
constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
@@ -537,7 +524,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffleV3
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{}));
using BGridDesc_N_K = decltype(transform_k0_n_k1_to_n_k(BGridDesc_BK0_N_BK1{}));
using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{}));
// Note: here we can call gridwise functions with dummy arguments,
// just to create the alias