mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Switch into universal gemms for conv bwds (#2981)
* switch into universal gemms for conv bwds * some fixes and support universal gemm in conv fwd * add reviewer comments
This commit is contained in:
@@ -44,13 +44,13 @@ struct GroupedConvBwdDataKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -145,15 +145,15 @@ struct GroupedConvBwdDataKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -161,13 +161,13 @@ struct GroupedConvBwdDataKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -262,17 +262,17 @@ struct GroupedConvBwdDataKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -281,17 +281,17 @@ struct GroupedConvBwdDataKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -387,8 +387,8 @@ struct GroupedConvBwdDataKernelArgs
|
||||
|
||||
static constexpr index_t MaxGroupedGemmGroupsNum = 128;
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
|
||||
using ABCGridDescs = remove_cvref_t<decltype(
|
||||
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
@@ -471,10 +471,6 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardDataKernel
|
||||
{
|
||||
// Todo: Enable Vector Load Size > 1
|
||||
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
|
||||
GroupedConvTraitsType_::VectorSizeB == 1);
|
||||
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
@@ -517,12 +513,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
// TODO: Change to and enable vector load
|
||||
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Not supported A GEMM layout!");
|
||||
// static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>,
|
||||
// "Not supported B GEMM layout!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
|
||||
"Not supported C GEMM layout!");
|
||||
|
||||
@@ -742,8 +733,8 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
@@ -788,9 +779,9 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
|
||||
@@ -40,13 +40,13 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -80,8 +80,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NWGK
|
||||
@@ -92,9 +92,9 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
@@ -109,15 +109,15 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -125,13 +125,13 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -156,8 +156,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NHWGK
|
||||
@@ -168,9 +168,9 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
@@ -185,17 +185,17 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -204,17 +204,17 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -239,8 +239,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
|
||||
GroupedConvTraitsType_::NDimSpatial>();
|
||||
|
||||
a_grid_desc_m_k = grid_descs.at(number<0>{});
|
||||
b_grid_desc_n_k = grid_descs.at(number<1>{});
|
||||
a_grid_desc_k_m = grid_descs.at(number<0>{});
|
||||
b_grid_desc_k_n = grid_descs.at(number<1>{});
|
||||
c_grid_desc_m_n = grid_descs.at(number<2>{});
|
||||
|
||||
group_stride_a = args.K_; // A: Out NDHWGK
|
||||
@@ -251,17 +251,17 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
1,
|
||||
std::multiplies<index_t>());
|
||||
|
||||
GemmM = a_grid_desc_m_k.get_length(number<0>{});
|
||||
GemmN = b_grid_desc_n_k.get_length(number<0>{});
|
||||
GemmK = a_grid_desc_m_k.get_length(number<1>{});
|
||||
GemmM = a_grid_desc_k_m.get_length(number<1>{});
|
||||
GemmN = b_grid_desc_k_n.get_length(number<1>{});
|
||||
GemmK = a_grid_desc_k_m.get_length(number<0>{});
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<
|
||||
decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
using ABCGridDescs = remove_cvref_t<decltype(
|
||||
ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
using AGridDescKM = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescKN = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
@@ -285,8 +285,8 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
std::array<const void*, NumDTensor> ds_ptr;
|
||||
void* wei_ptr;
|
||||
|
||||
AGridDescMK a_grid_desc_m_k;
|
||||
BGridDescNK b_grid_desc_n_k;
|
||||
AGridDescKM a_grid_desc_k_m;
|
||||
BGridDescKN b_grid_desc_k_n;
|
||||
CGridDescMN c_grid_desc_m_n;
|
||||
|
||||
long_index_t group_stride_a;
|
||||
@@ -338,10 +338,6 @@ template <typename GroupedConvTraitsType_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionBackwardWeightKernel
|
||||
{
|
||||
// Todo: Enable Vector Load Size > 1
|
||||
static_assert(GroupedConvTraitsType_::VectorSizeA == 1 &&
|
||||
GroupedConvTraitsType_::VectorSizeB == 1);
|
||||
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
@@ -380,12 +376,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
|
||||
"Not supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
// TODO: Change to and enable vector load
|
||||
// static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not
|
||||
// supported!"); static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not
|
||||
// supported!");
|
||||
static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
@@ -592,12 +584,12 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr,
|
||||
kargs.a_grid_desc_m_k); // A: out
|
||||
kargs.a_grid_desc_k_m); // A: out
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr,
|
||||
kargs.b_grid_desc_n_k); // B: in
|
||||
kargs.b_grid_desc_k_n); // B: in
|
||||
}();
|
||||
|
||||
const auto& c_tensor_view = [&]() {
|
||||
@@ -628,16 +620,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
const auto& a_pad_view = [&]() {
|
||||
const auto& a_tensor_view = views.at(I0);
|
||||
return pad_tensor_view(a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{} * k_batch),
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
const auto& b_pad_view = [&]() {
|
||||
const auto& b_tensor_view = views.at(I1);
|
||||
return pad_tensor_view(b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{} * k_batch),
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{} * k_batch,
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
@@ -675,16 +667,16 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
const auto& a_block_window = [&]() {
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, i_k});
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::MPerBlock>{}),
|
||||
{i_k, i_m});
|
||||
}();
|
||||
|
||||
const auto& b_block_window = [&]() {
|
||||
return make_tile_window(b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, i_k});
|
||||
make_tuple(number<TilePartitioner::KPerBlock>{},
|
||||
number<TilePartitioner::NPerBlock>{}),
|
||||
{i_k, i_n});
|
||||
}();
|
||||
|
||||
const auto ds_block_window = generate_tuple(
|
||||
|
||||
@@ -41,13 +41,13 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -124,15 +124,15 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -140,13 +140,13 @@ struct GroupedConvFwdKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -216,17 +216,17 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -235,17 +235,17 @@ struct GroupedConvFwdKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -306,15 +306,15 @@ struct GroupedConvFwdKernelArgs
|
||||
args.output_spatial_lengths_[2];
|
||||
}
|
||||
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
|
||||
using AGridDescMK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>())>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
|
||||
@@ -82,20 +82,14 @@ struct GroupedConvTraits
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
using GroupedConvImplicitGemmTraitsBwdWeight =
|
||||
TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// TODO: Change to and enable vector load
|
||||
// ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
// ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr ck_tile::index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr ck_tile::index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
@@ -502,7 +502,7 @@ struct TransformConvBwdDataToGemm
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
@@ -512,7 +512,7 @@ struct TransformConvBwdDataToGemm
|
||||
// GKYXC
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
@@ -547,7 +547,7 @@ struct TransformConvBwdDataToGemm
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N_, Di_, Hi_, Wi_, C_),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
number<VectorSizeC>{},
|
||||
I1);
|
||||
}
|
||||
|
||||
@@ -558,7 +558,7 @@ struct TransformConvBwdDataToGemm
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(K_, Z_, Y_, X_, C_),
|
||||
make_tuple(C_ * X_ * Y_ * Z_, C_ * X_ * Y_, C_ * X_, C_, I1),
|
||||
number<VectorSizeC>{},
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
}
|
||||
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
|
||||
@@ -642,7 +642,7 @@ struct TransformConvBwdDataToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// c: input
|
||||
const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
@@ -797,7 +797,7 @@ struct TransformConvBwdDataToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<1, 2, 0>{}, sequence<3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// c: input
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
@@ -999,7 +999,7 @@ struct TransformConvBwdDataToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
// c: input
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
|
||||
@@ -421,7 +421,6 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
number<VectorSizeA>{},
|
||||
@@ -463,9 +462,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), // K_M
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
@@ -480,7 +478,7 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto CStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_),
|
||||
return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), // K_N
|
||||
make_tuple(NStride, HiStride, WiStride, CStride),
|
||||
number<VectorSizeB>{},
|
||||
I1);
|
||||
@@ -506,9 +504,8 @@ struct TransformConvBwdWeightToGemm
|
||||
constexpr auto KStride = I1;
|
||||
|
||||
// TODO Add support for NumGroupsToMerge > 1
|
||||
|
||||
return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_),
|
||||
make_tuple(KStride, NDoHoWoStride),
|
||||
return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
|
||||
make_tuple(NDoHoWoStride, KStride),
|
||||
number<VectorSizeA>{},
|
||||
I1);
|
||||
}
|
||||
@@ -577,7 +574,7 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Wo_))),
|
||||
make_tuple(sequence<1, 3>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
@@ -614,7 +611,7 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
@@ -657,7 +654,7 @@ struct TransformConvBwdWeightToGemm
|
||||
make_tuple(make_merge_transform(make_tuple(Z_, Y_, X_, C_)),
|
||||
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_))),
|
||||
make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return make_tuple(out_grid_desc, in_gemmn_gemmktotal_grid_desc, wei_grid_desc);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user