diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
index 3fb047f207..359711e5c4 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -106,89 +106,35 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static constexpr auto I3 = Number<3>{};
static constexpr index_t KPerBlock = K0PerBlock * K1;
- static constexpr auto transform_conv_to_gemm =
- TransformConvBwdDataToGemm_v1{};
+ using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1;
- static auto GetDummyABDsEGridDescriptor()
+ static auto
+ GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
- const std::array dummy_tensor_lengths = {1};
- const std::array dummy_tensor_strides = {1};
- const std::array dummy_spatial_lengths = {1};
-
- const auto a_grid_desc_ak0_m_ak1 =
- transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
-
- const auto b_grid_desc_bk0_n_bk1 =
- transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
-
- const auto ds_grid_desc_m_n = generate_tuple(
- [&](auto i) {
- using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
- return transform_conv_to_gemm.template MakeCDescriptor_M_N(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
- },
- Number<NumDTensor>{});
-
- const auto e_grid_desc_m_n =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
-
+ const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
+ const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
+ const auto ds_grid_desc_m_n =
+ generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); },
+ Number<NumDTensor>{});
+ const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
}
// desc
- using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
+ constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
+ using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
@@ -270,7 +216,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
const std::array& b_g_k_c_xs_lengths,
const std::array& b_g_k_c_xs_strides,
const std::array, NumDTensor>&
- ds_g_n_c_wis_lengths,
+ /*ds_g_n_c_wis_lengths*/,
const std::array, NumDTensor>&
ds_g_n_c_wis_strides,
const std::array& e_g_n_c_wis_lengths,
@@ -291,15 +237,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
- a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
- b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
- ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
- ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
- e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
- e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
- conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
@@ -382,68 +321,47 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
tildes = {i_ztilde, i_ytilde, i_xtilde};
}
+ ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
+ a_g_n_k_wos_strides,
+ b_g_k_c_xs_lengths,
+ b_g_k_c_xs_strides,
+ e_g_n_c_wis_lengths,
+ e_g_n_c_wis_strides,
+ conv_filter_strides,
+ conv_filter_dilations,
+ input_left_pads,
+ input_right_pads,
+ tildes};
+
const auto a_grid_desc_ak0_m_ak1 =
- transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
+ conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
- transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
+ conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-
- ds_grid_desc_m_n(i) =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- ds_g_n_c_wis_lengths[i],
- ds_g_n_c_wis_strides[i],
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
- });
-
- const auto e_grid_desc_m_n =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(
+ static_assert(is_same_v<DLayout, ELayout>);
+ ConvToGemmBwdDataTransform conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
+ ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
- tildes);
+ tildes};
+
+ ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
+ });
+
+ const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// for check validity
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
@@ -522,17 +440,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
- // for checking IsSupportedArgument()
std::array a_g_n_k_wos_lengths_;
- std::array a_g_n_k_wos_strides_;
std::array b_g_k_c_xs_lengths_;
- std::array b_g_k_c_xs_strides_;
- std::array, NumDTensor> ds_g_n_c_wis_lengths_;
- std::array, NumDTensor> ds_g_n_c_wis_strides_;
- std::array e_g_n_c_wis_lengths_;
- std::array e_g_n_c_wis_strides_;
std::array conv_filter_strides_;
- std::array conv_filter_dilations_;
std::array input_left_pads_;
std::array input_right_pads_;
};
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
index b544c925e1..c8c58d5d85 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
@@ -54,15 +54,16 @@ template
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -73,10 +74,9 @@ __global__ void
const ABDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
- const AElementwiseOperation a_element_op,
- const BElementwiseOperation b_element_op,
- const CDEElementwiseOperation cde_element_op,
- const index_t batch_count,
+ const AElementwiseOp a_element_op,
+ const BElementwiseOp b_element_op,
+ const CDEElementwiseOp cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
@@ -84,24 +84,29 @@ __global__ void
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
- const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
+ const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+ const ComputePtrOffsetOfN compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
- const index_t num_blocks_per_batch =
- __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
- const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+ const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
+ const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
- const long_index_t a_batch_offset = amd_wave_read_first_lane(
- static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
- const long_index_t b_batch_offset = amd_wave_read_first_lane(
- static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
- const long_index_t e_batch_offset = amd_wave_read_first_lane(
- static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+ const long_index_t a_batch_offset =
+ amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+ const long_index_t b_batch_offset =
+ amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+ const long_index_t e_batch_offset =
+ amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
+ const long_index_t a_n_offset =
+ amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
+ const long_index_t e_n_offset =
+ amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
+
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
DsPointer p_ds_grid_grp;
@@ -112,10 +117,10 @@ __global__ void
static_for<0, NumDTensor, 1>{}(
[&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
- GridwiseGemm::template Run(p_a_grid + a_batch_offset,
+ GridwiseGemm::template Run(p_a_grid + a_batch_offset + a_n_offset,
p_b_grid + b_batch_offset,
p_ds_grid_grp,
- p_e_grid + e_batch_offset,
+ p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
@@ -130,7 +135,6 @@ __global__ void
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
- ignore = batch_count;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
@@ -139,6 +143,7 @@ __global__ void
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
+ ignore = compute_ptr_offset_of_n;
ignore = block_2_ctile_map;
#endif
}
@@ -233,82 +238,54 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
- static constexpr auto transform_conv_to_gemm =
- TransformConvBwdDataToGemm_v1{};
+ using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1;
- static auto GetDummyABDsEGridDescriptor()
+ static auto
+ GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform)
{
- const std::array dummy_tensor_lengths = {1};
- const std::array dummy_tensor_strides = {1};
- const std::array dummy_spatial_lengths = {1};
+ const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1();
- const auto a_grid_desc_ak0_m_ak1 =
- transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
-
- const auto b_grid_desc_bk0_n_bk1 =
- transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
+ const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1();
const auto ds_grid_desc_m_n = generate_tuple(
[&](auto i) {
- using DLayout = remove_cvref_t>;
-
- return transform_conv_to_gemm.template MakeCDescriptor_M_N(
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
+ using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+ using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+ using ConvToGemmBwdDataTransformD =
+ TransformConvBwdDataToGemm_v1;
+ return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N();
},
Number{});
- const auto e_grid_desc_m_n =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_tensor_lengths,
- dummy_tensor_strides,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths,
- dummy_spatial_lengths);
+ const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N();
return make_tuple(
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
@@ -377,7 +354,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
// desc
- using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
+ constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform;
+ using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform));
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
@@ -431,15 +409,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
- a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
- b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
- ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
- ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
- e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
- e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
conv_filter_strides_{conv_filter_strides},
- conv_filter_dilations_{conv_filter_dilations},
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
@@ -450,11 +421,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
p_ds_grid_(i) = static_cast(p_ds[i]);
});
- // A/B/Ds/E Batch Stride
- compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
- compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
- compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
-
static_for<0, NumDTensor, 1>{}([&](auto i) {
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
});
@@ -526,68 +492,65 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
+ ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths,
+ a_g_n_k_wos_strides,
+ b_g_k_c_xs_lengths,
+ b_g_k_c_xs_strides,
+ e_g_n_c_wis_lengths,
+ e_g_n_c_wis_strides,
+ conv_filter_strides,
+ conv_filter_dilations,
+ input_left_pads,
+ input_right_pads,
+ tildes};
+
+ conv_N_per_block_ = conv_to_gemm_transform_.N_;
+
const auto a_grid_desc_ak0_m_ak1 =
- transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
+ conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1();
const auto b_grid_desc_bk0_n_bk1 =
- transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
+ conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1();
DsGridDesc_M_N ds_grid_desc_m_n;
// populate Ds desc
static_for<0, NumDTensor, 1>{}([&](auto i) {
- using DLayout = remove_cvref_t>;
-
- ds_grid_desc_m_n(i) =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(
- a_g_n_k_wos_lengths,
- a_g_n_k_wos_strides,
- b_g_k_c_xs_lengths,
- b_g_k_c_xs_strides,
- ds_g_n_c_wis_lengths[i],
- ds_g_n_c_wis_strides[i],
- conv_filter_strides,
- conv_filter_dilations,
- input_left_pads,
- input_right_pads,
- tildes);
- });
-
- const auto e_grid_desc_m_n =
- transform_conv_to_gemm.template MakeCDescriptor_M_N(
+ using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+ using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
+ using ConvToGemmBwdDataTransformD =
+ TransformConvBwdDataToGemm_v1;
+ ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{
a_g_n_k_wos_lengths,
a_g_n_k_wos_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
- e_g_n_c_wis_lengths,
- e_g_n_c_wis_strides,
+ ds_g_n_c_wis_lengths[i],
+ ds_g_n_c_wis_strides[i],
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
- tildes);
+ tildes};
+
+ ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N();
+ });
+
+ const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N();
// desc for problem definition
const auto a_grid_desc_m_k =
@@ -628,6 +591,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
}
}
}
+ // A/B/Ds/E Batch Stride
+ compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
+ compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+ compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
+
+ compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_k_wos_strides[1] * conv_N_per_block_;
+ compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_c_wis_strides[1] * conv_N_per_block_;
}
void Print() const
@@ -660,6 +630,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// tensor descriptor for problem definition
index_t num_group_;
+ index_t conv_N_per_block_;
std::vector a_grid_desc_m_k_container_;
std::vector b_grid_desc_n_k_container_;
std::vector ds_grid_desc_m_n_container_;
@@ -678,23 +649,16 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
// for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
+ ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOp a_element_op_;
BElementwiseOp b_element_op_;
CDEElementwiseOp cde_element_op_;
- // for checking IsSupportedArgument()
std::array a_g_n_k_wos_lengths_;
- std::array a_g_n_k_wos_strides_;
std::array b_g_k_c_xs_lengths_;
- std::array b_g_k_c_xs_strides_;
- std::array, NumDTensor> ds_g_n_c_wis_lengths_;
- std::array, NumDTensor> ds_g_n_c_wis_strides_;
- std::array e_g_n_c_wis_lengths_;
- std::array e_g_n_c_wis_strides_;
std::array conv_filter_strides_;
- std::array conv_filter_dilations_;
std::array input_left_pads_;
std::array input_right_pads_;
};
@@ -711,8 +675,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.Print();
}
- float ave_time = 0;
+ const index_t gdy = arg.num_group_;
+ const index_t num_workgroups_per_Conv_N =
+ arg.a_g_n_k_wos_lengths_[I1] / arg.conv_N_per_block_;
+ const index_t gdz = num_workgroups_per_Conv_N;
+ float ave_time = 0;
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i],
@@ -724,9 +692,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
throw std::runtime_error("wrong! device_op has invalid setting");
}
- const index_t grid_size = arg.block_2_etile_map_container_[i].CalculateGridSize(
- arg.e_grid_desc_m_n_container_[i]) *
- arg.num_group_;
+ const index_t gdx = arg.block_2_etile_map_container_[i].CalculateGridSize(
+ arg.e_grid_desc_m_n_container_[i]);
const auto GemmK = arg.a_grid_desc_m_k_container_[i].GetLength(I1);
@@ -747,12 +714,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch,
+ ComputePtrOffsetOfStridedBatch,
has_main_loop>;
return launch_and_time_kernel(
stream_config,
kernel,
- dim3(grid_size),
+ dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_a_grid_,
@@ -762,13 +730,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
- arg.a_g_n_k_wos_lengths_[0], // Group count
arg.a_grid_desc_ak0_m_ak1_container_[i],
arg.b_grid_desc_bk0_n_bk1_container_[i],
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
arg.block_2_etile_map_container_[i],
- arg.compute_ptr_offset_of_batch_);
+ arg.compute_ptr_offset_of_batch_,
+ arg.compute_ptr_offset_of_n_);
};
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
diff --git a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
index 2be0b66812..8df0d885b9 100644
--- a/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+++ b/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -13,150 +13,6 @@
namespace ck {
namespace tensor_operation {
-namespace {
-template <
- index_t NDimSpatial,
- typename ALayout,
- ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
-constexpr auto make_out_grid_desc(const index_t N,
- const index_t Do,
- const index_t Ho,
- const index_t Wo,
- const index_t K,
- const std::array& out_g_n_k_wos_strides)
-{
- const auto KStride = Number<1>{};
-
- if constexpr(is_same_v)
- {
- const index_t NStride = out_g_n_k_wos_strides[1];
- const index_t HiStride = out_g_n_k_wos_strides[3];
- const index_t WiStride = out_g_n_k_wos_strides[4];
- if constexpr(ConvBwdDataSpecialization ==
- ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
- Filter1x1Stride1Pad0)
- {
-
- return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
- make_tuple(WiStride, KStride));
- }
- else
- {
- return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
- make_tuple(NStride, HiStride, WiStride, KStride));
- }
- }
- else if constexpr(is_same_v)
- {
- const index_t NStride = out_g_n_k_wos_strides[1];
- const index_t DoStride = out_g_n_k_wos_strides[3];
- const index_t HoStride = out_g_n_k_wos_strides[4];
- const index_t WoStride = out_g_n_k_wos_strides[5];
- if constexpr(ConvBwdDataSpecialization ==
- ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
- Filter1x1Stride1Pad0)
- {
-
- return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
- make_tuple(WoStride, KStride));
- }
- else
- {
- return make_naive_tensor_descriptor(
- make_tuple(N, Do, Ho, Wo, K),
- make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
- }
- }
- else if constexpr(is_same_v)
- {
- // assume packed
- if constexpr(ConvBwdDataSpecialization ==
- ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
- Filter1x1Stride1Pad0)
- {
- return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
- }
- else
- {
- return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
- }
- }
- else if constexpr(is_same_v)
- {
- // assume packed
- if constexpr(ConvBwdDataSpecialization ==
- ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
- Filter1x1Stride1Pad0)
- {
- return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
- }
- else
- {
- return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
- }
- }
- else
- {
- throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
- }
-}
-
-template
-constexpr auto make_wei_grid_desc(
- const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
-{
-
- if constexpr(is_same_v)
- {
- return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
- }
- else if constexpr(is_same_v)
- {
- return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
- }
- else
- {
- throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
- }
-}
-
-template
-constexpr auto make_in_grid_desc(const index_t N,
- const index_t Di,
- const index_t Hi,
- const index_t Wi,
- const index_t C,
- const std::array& in_g_n_c_wis_strides)
-{
-
- if constexpr(is_same_v ||
- is_same_v ||
- is_same_v)
- {
- return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
- make_tuple(in_g_n_c_wis_strides[1],
- in_g_n_c_wis_strides[3],
- in_g_n_c_wis_strides[4],
- in_g_n_c_wis_strides[2]));
- }
- else if constexpr(is_same_v ||
- is_same_v)
- {
- return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
- make_tuple(in_g_n_c_wis_strides[1],
- in_g_n_c_wis_strides[3],
- in_g_n_c_wis_strides[4],
- in_g_n_c_wis_strides[5],
- in_g_n_c_wis_strides[2]));
- }
- else
- {
- throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
- }
-}
-
-} // namespace
-
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
@@ -166,92 +22,605 @@ template <
index_t GemmNPerBlock,
index_t GemmKPerBlock,
bool DoPadGemmM,
- bool DoPadGemmN>
+ bool DoPadGemmN,
+ typename ALayout,
+ typename BLayout,
+ typename CLayout,
+ bool SplitN = false,
+ typename ADataType = float,
+ typename CDataType = float,
+ index_t NumGroupsToMerge = 1,
+ typename IndexType = index_t>
struct TransformConvBwdDataToGemm_v1
{
+ private:
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
+ static constexpr auto I2 = Number<2>{};
+ static constexpr auto I3 = Number<3>{};
static constexpr auto NonSpatialDimsNum = Number<3>{};
- static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
+ static constexpr auto DIdx = NonSpatialDimsNum;
static constexpr auto HIdx =
- NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + I1>{};
+ NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + I1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + I1>{} : Number<NonSpatialDimsNum + I2>{};
- static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
+ static constexpr auto ZIdx = NonSpatialDimsNum;
static constexpr auto YIdx =
- NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + I1>{};
+ NDimSpatial == 2 ? NonSpatialDimsNum : Number<NonSpatialDimsNum + I1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + I1>{} : Number<NonSpatialDimsNum + I2>{};
- template ||
- is_same_v ||
- is_same_v ||
- is_same_v),
- bool>::type = false>
- static auto MakeADescriptor_AK0_M_AK1(
- const std::array& out_g_n_k_wos_lengths,
- const std::array& out_g_n_k_wos_strides,
- const std::array& wei_g_k_c_xs_lengths,
- const std::array& /* wei_g_k_c_xs_strides */,
- const std::array& in_g_n_c_wis_lengths,
- const std::array& /* in_g_n_c_wis_strides */,
- const std::array& conv_filter_strides,
- const std::array& conv_filter_dilations,
- const std::array& input_left_pads,
- const std::array& /* input_right_pads */,
- const std::array& tildes)
+ template <typename ConvDimsType>
+ static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
+ const ConvDimsType& strides,
+ index_t i)
{
- index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
- index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
- index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
+ long_index_t acc = 1;
+ for(; i < (NDimSpatial + 3); i++)
+ {
+ acc +=
+ static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
+ }
- const index_t N = in_g_n_c_wis_lengths[1];
- const index_t K = wei_g_k_c_xs_lengths[1];
+ return acc;
+ }
- const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
- const index_t Hi = in_g_n_c_wis_lengths[HIdx];
- const index_t Wi = in_g_n_c_wis_lengths[WIdx];
+ template <typename ConvDimsType>
+ static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths,
+ const ConvDimsType& a_g_n_k_wos_strides,
+ const ConvDimsType& c_g_n_c_wis_lengths,
+ const ConvDimsType& c_g_n_c_wis_strides)
+ {
+ const long_index_t a_element_space_size =
+ calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1);
+ const long_index_t c_element_space_size =
+ calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1);
+ const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
+ c_element_space_size * sizeof(CDataType));
+ constexpr long_index_t TwoGB = (long_index_t{1} << 31);
- const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
- const index_t Ho = out_g_n_k_wos_lengths[HIdx];
- const index_t Wo = out_g_n_k_wos_lengths[WIdx];
+ const IndexType N = a_g_n_k_wos_lengths[I1];
- const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
- const index_t Y = wei_g_k_c_xs_lengths[YIdx];
- const index_t X = wei_g_k_c_xs_lengths[XIdx];
+ if(element_space_size > TwoGB)
+ {
+ // Minimum divisor of N to not exceed 2GB
+ const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);
- const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
- const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
- const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
+ if(divisor <= static_cast<long_index_t>(N))
+ {
+ // Find least divisor of N larger than element_space_size / TwoGB
+ // Iterate up to sqrt(N). There are no divisors above this value.
+ for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
+ least_divisor++)
+ {
+ if(N % least_divisor == 0)
+ {
+ return N / least_divisor;
+ }
+ }
+ // Not found, process one Convolution N per block
+ return 1;
+ }
+ else
+ {
+ // Not possible to support even after split N.
+ // Too large tensor.
+ return N;
+ }
+ }
+ else
+ {
+ // Split N is not needed.
+ return N;
+ }
+ }
- const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
- const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
- const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
+ public:
+ __host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {}
- const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
- const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
- const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
+ template <typename TransformConvBwdDataToGemm_v1Base>
+ __host__ __device__ TransformConvBwdDataToGemm_v1(
+ const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base)
+ : N_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.N_)},
+ Di_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Di_)},
+ Hi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Hi_)},
+ Wi_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wi_)},
+ Do_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Do_)},
+ Ho_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Ho_)},
+ Wo_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Wo_)},
+ Z_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Z_)},
+ Y_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.Y_)},
+ X_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.X_)},
+ K_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.K_)},
+ C_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.C_)},
+ DiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DiStride_)},
+ HiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HiStride_)},
+ WiStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WiStride_)},
+ DoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DoStride_)},
+ HoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HoStride_)},
+ WoStride_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WoStride_)},
+ CStrideTensorB_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)},
+ CStrideTensorC_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)},
+ KStrideTensorA_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)},
+ KStrideTensorB_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)},
+ NStrideTensorA_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)},
+ NStrideTensorC_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)},
+ ConvStrideD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)},
+ ConvStrideH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)},
+ ConvStrideW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)},
+ ConvDilationD_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)},
+ ConvDilationH_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)},
+ ConvDilationW_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)},
+ InLeftPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)},
+ InLeftPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)},
+ InLeftPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)},
+ InRightPadD_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadD_)},
+ InRightPadH_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadH_)},
+ InRightPadW_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.InRightPadW_)},
+ IdxZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)},
+ IdxYTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)},
+ IdxXTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)},
+ GcdStrideDilationD_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)},
+ GcdStrideDilationH_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)},
+ GcdStrideDilationW_{
+ static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)},
+ ZTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZTilde_)},
+ YTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YTilde_)},
+ XTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XTilde_)},
+ DTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.DTilde_)},
+ HTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.HTilde_)},
+ WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
+ ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
+ YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
+ XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
+ {
+ }
+ template <typename ConvDimsType, typename ConvSpatialDimsType>
+ __host__ __device__
+ TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths,
+ const ConvDimsType& a_g_n_k_wos_strides,
+ const ConvDimsType& b_g_k_c_xs_lengths,
+ const ConvDimsType& b_g_k_c_xs_strides,
+ const ConvDimsType& c_g_n_c_wis_lengths,
+ const ConvDimsType& c_g_n_c_wis_strides,
+ const ConvSpatialDimsType& conv_filter_strides,
+ const ConvSpatialDimsType& conv_filter_dilations,
+ const ConvSpatialDimsType& input_left_pads,
+ const ConvSpatialDimsType& input_right_pads,
+ const ConvSpatialDimsType& tildes)
+ : Hi_{c_g_n_c_wis_lengths[HIdx]},
+ Wi_{c_g_n_c_wis_lengths[WIdx]},
+ Ho_{a_g_n_k_wos_lengths[HIdx]},
+ Wo_{a_g_n_k_wos_lengths[WIdx]},
+ Y_{b_g_k_c_xs_lengths[YIdx]},
+ X_{b_g_k_c_xs_lengths[XIdx]},
+ K_{a_g_n_k_wos_lengths[I2]},
+ C_{b_g_k_c_xs_lengths[I2]},
+ HiStride_{c_g_n_c_wis_strides[HIdx]},
+ WiStride_{c_g_n_c_wis_strides[WIdx]},
+ HoStride_{a_g_n_k_wos_strides[HIdx]},
+ WoStride_{a_g_n_k_wos_strides[WIdx]},
+ CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+ CStrideTensorC_{c_g_n_c_wis_strides[I2]},
+ KStrideTensorA_{a_g_n_k_wos_strides[I2]},
+ KStrideTensorB_{b_g_k_c_xs_strides[I1]},
+ NStrideTensorA_{a_g_n_k_wos_strides[I1]},
+ NStrideTensorC_{c_g_n_c_wis_strides[I1]},
+ ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]},
+ ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]},
+ ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]},
+ ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]},
+ InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]},
+ InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]},
+ InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
+ InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
+ IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
+ IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
+ {
+ static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
+ is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);
+ static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
+ is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+
+ if constexpr(SplitN)
+ {
+ N_ = GetSplitedNSize(
+ a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides);
+ }
+ else
+ {
+ N_ = c_g_n_c_wis_lengths[I1];
+ }
+ if constexpr(NDimSpatial == 3)
+ {
+ Di_ = c_g_n_c_wis_lengths[DIdx];
+ Do_ = a_g_n_k_wos_lengths[DIdx];
+ Z_ = b_g_k_c_xs_lengths[ZIdx];
+ DiStride_ = c_g_n_c_wis_strides[DIdx];
+ DoStride_ = a_g_n_k_wos_strides[DIdx];
+ ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum];
+ ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum];
+ InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum];
+ InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum];
+ IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum];
+ GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_);
+ ZTilde_ = ConvStrideD_ / GcdStrideDilationD_;
+ DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_);
+ ZDot_ = math::integer_divide_ceil(Z_, ZTilde_);
+ }
+ else
+ {
+ Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1;
+ InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0;
+ }
+
+ GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_);
+ GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_);
+
+ YTilde_ = ConvStrideH_ / GcdStrideDilationH_;
+ XTilde_ = ConvStrideW_ / GcdStrideDilationW_;
+
+ HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_);
+ WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_);
+
+ YDot_ = math::integer_divide_ceil(Y_, YTilde_);
+ XDot_ = math::integer_divide_ceil(X_, XTilde_);
+ }
+
+#if 0 // At now not supported to split tensor
+ __host__ bool AreDescriptorsSmallerThan2GB() const
+ {
+ constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+ const long_index_t in_desc_space_size =
+ I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
+ (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
+ const long_index_t out_desc_space_size =
+ I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
+ (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
+
+ bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
+ bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
+
+ return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
+ }
+
+ __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
+ CDataType* c_grid_ptr_base) const
+ {
+ // Create copies
+ auto conv_to_gemm_transformer_left = *this;
+ auto conv_to_gemm_transformer_right = *this;
+ IndexType a_right_offset = 0;
+ IndexType c_right_offset = 0;
+ // Calculate real filter size
+ const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
+ const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
+ const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
+ // Calculate start position in input for right tensor
+ const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
+ const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
+ const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
+ // Calculate last position in input for left tensor
+ const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
+ const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
+ const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
+ // Allow to split if whole left padding will be in left tensor and right padding in right
+ // tensor
+ const bool is_possible_to_split_d = Do_ != 1 &&
+ di_right_transformer_start_idx > InLeftPadD_ &&
+ di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
+ const bool is_possible_to_split_h = Ho_ != 1 &&
+ hi_right_transformer_start_idx > InLeftPadH_ &&
+ hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
+ const bool is_possible_to_split_w = Wo_ != 1 &&
+ wi_right_transformer_start_idx > InLeftPadW_ &&
+ wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
+
+ if(is_possible_to_split_d)
+ {
+ // Apply new sizes
+ // Split output on half
+ conv_to_gemm_transformer_left.Do_ = Do_ / 2;
+ conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
+ // Assign left padding to left convolution
+ conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
+ conv_to_gemm_transformer_right.InLeftPadD_ = 0;
+ // Assign right padding to right convolution
+ conv_to_gemm_transformer_left.InRightPadD_ = 0;
+ conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
+ // Calculate new input size
+ conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
+ conv_to_gemm_transformer_right.Di_ =
+ math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
+ (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
+ ;
+ // Calculate offsets
+ a_right_offset = (Do_ / 2) * DoStride_;
+ c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
+ }
+ else if(is_possible_to_split_h)
+ {
+ conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
+ conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
+
+ conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
+ conv_to_gemm_transformer_right.InLeftPadH_ = 0;
+
+ conv_to_gemm_transformer_left.InRightPadH_ = 0;
+ conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
+
+ conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
+ conv_to_gemm_transformer_right.Hi_ =
+ math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
+ (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
+ a_right_offset = (Ho_ / 2) * HoStride_;
+ c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
+ }
+ else if(is_possible_to_split_w)
+ {
+ conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
+ conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
+
+ conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
+ conv_to_gemm_transformer_right.InLeftPadW_ = 0;
+
+ conv_to_gemm_transformer_left.InRightPadW_ = 0;
+ conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
+
+ conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
+ conv_to_gemm_transformer_right.Wi_ =
+ math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
+ (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
+
+ a_right_offset = (Wo_ / 2) * WoStride_;
+ c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
+ }
+ // Return left transform, right transformer, right offset to Input and right offset to
+ // Output
+ return ck::make_tuple(conv_to_gemm_transformer_left,
+ conv_to_gemm_transformer_right,
+ a_grid_ptr_base + a_right_offset,
+ c_grid_ptr_base + c_right_offset);
+ }
+
+ __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
+ CDataType* c_grid_ptr_base) const
+ {
+ // Create copies
+ auto conv_to_gemm_transformer_left = *this;
+ auto conv_to_gemm_transformer_right = *this;
+ IndexType a_right_offset = 0;
+ IndexType c_right_offset = 0;
+
+ // Calculate start position in input for right tensor
+ const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
+ const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
+ const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
+ // Calculate last position in input for left tensor
+ const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
+ const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
+ const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
+
+
+ if(Di_!=1)
+ {
+ // Apply new sizes
+ // Split output on half
+ conv_to_gemm_transformer_left.Di_ = Di_ / 2;
+ conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
+ // Assign left padding to left convolution
+ conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
+ conv_to_gemm_transformer_right.InLeftPadD_ = 0;
+ // Assign right padding to right convolution
+ conv_to_gemm_transformer_left.InRightPadD_ = 0;
+ conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
+ // Calculate new input size
+ conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
+ conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
+ ;
+ // Calculate offsets
+ a_right_offset = do_right_transformer_start_idx * DoStride_;
+ c_right_offset = (Di_ / 2) * DiStride_;
+ }
+ else if(Hi_!=1)
+ {
+ // Apply new sizes
+ // Split output on half
+ conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
+ conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
+ // Assign left padding to left convolution
+ conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
+ conv_to_gemm_transformer_right.InLeftPadH_ = 0;
+ // Assign right padding to right convolution
+ conv_to_gemm_transformer_left.InRightPadH_ = 0;
+ conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
+ // Calculate new input size
+ conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ;
+ conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ;
+ ;
+ // Calculate offsets
+ a_right_offset = ho_right_transformer_start_idx * HoStride_;
+ c_right_offset = (Hi_ / 2) * HiStride_;
+ }
+ else if(Wi_!=1)
+ {
+ // Apply new sizes
+ // Split output on half
+ conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
+ conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
+ // Assign left padding to left convolution
+ conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
+ conv_to_gemm_transformer_right.InLeftPadW_ = 0;
+ // Assign right padding to right convolution
+ conv_to_gemm_transformer_left.InRightPadW_ = 0;
+ conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
+ // Calculate new input size
+ conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
+ conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
+ ;
+ // Calculate offsets
+ a_right_offset = wo_right_transformer_start_idx * WoStride_;
+ c_right_offset = (Wi_ / 2) * WiStride_;
+ }
+ // Return left transform, right transformer, right offset to Input and right offset to
+ // Output
+ return ck::make_tuple(conv_to_gemm_transformer_left,
+ conv_to_gemm_transformer_right,
+ a_grid_ptr_base + a_right_offset,
+ c_grid_ptr_base + c_right_offset);
+ }
+#endif
+
+ __host__ __device__ auto MakeOutGridDesc() const
+ {
+ if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
+ {
+ if constexpr(ConvBwdDataSpecialization ==
+ ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+ Filter1x1Stride1Pad0)
+ {
+
+ return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
+ make_tuple(WoStride_, KStrideTensorA_));
+ }
+ else
+ {
+ return make_naive_tensor_descriptor(
+ make_tuple(N_, Ho_, Wo_, K_),
+ make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_));
+ }
+ }
+ else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
+ {
+ if constexpr(ConvBwdDataSpecialization ==
+ ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+ Filter1x1Stride1Pad0)
+ {
+
+ return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
+ make_tuple(WoStride_, KStrideTensorA_));
+ }
+ else
+ {
+ return make_naive_tensor_descriptor(
+ make_tuple(N_, Do_, Ho_, Wo_, K_),
+ make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_));
+ }
+ }
+ else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
+ {
+ // assume packed
+ if constexpr(ConvBwdDataSpecialization ==
+ ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+ Filter1x1Stride1Pad0)
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_));
+ }
+ else
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_));
+ }
+ }
+ else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
+ {
+ // assume packed
+ if constexpr(ConvBwdDataSpecialization ==
+ ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
+ Filter1x1Stride1Pad0)
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_));
+ }
+ else
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_));
+ }
+ }
+ else
+ {
+ throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
+ }
+ }
+
+ __host__ __device__ auto MakeWeiGridDesc() const
+ {
+
+ if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_));
+ }
+ else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
+ {
+ return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_));
+ }
+ else
+ {
+ throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
+ }
+ }
+
+ __host__ __device__ auto MakeInGridDesc() const
+ {
+
+ if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
+ is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
+ is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
+ {
+ return make_naive_tensor_descriptor(
+ make_tuple(N_, Hi_, Wi_, C_),
+ make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_));
+ }
+ else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
+ is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
+ {
+ return make_naive_tensor_descriptor(
+ make_tuple(N_, Di_, Hi_, Wi_, C_),
+ make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_));
+ }
+ else
+ {
+ throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
+ }
+ }
+
+ template <
+ typename ALayout_ = ALayout,
+ typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
+ (is_same_v<ALayout_, tensor_layout::convolution::GNHWK> ||
+ is_same_v<ALayout_, tensor_layout::convolution::GNDHWK> ||
+ is_same_v<ALayout_, tensor_layout::convolution::NHWGK> ||
+ is_same_v<ALayout_, tensor_layout::convolution::NDHWGK>),
+ bool>::type = false>
+ __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const
+ {
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
- const auto out_grid_desc =
- make_out_grid_desc(
- N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
+ const auto out_grid_desc = MakeOutGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
- const index_t AK0 = math::integer_divide_ceil(K, AK1);
+ const index_t AK0 = math::integer_divide_ceil(K_, AK1);
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_grid_desc,
- make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
+ make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
- const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
- const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
- const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
-
- const auto ZTilde = ConvStrideD / GcdStrideDilationD;
- const auto YTilde = ConvStrideH / GcdStrideDilationH;
- const auto XTilde = ConvStrideW / GcdStrideDilationW;
-
- const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
- const auto YDot = math::integer_divide_ceil(Y, YTilde);
- const auto XDot = math::integer_divide_ceil(X, XTilde);
-
- const auto DTilde =
- Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
- const auto HTilde =
- Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
- const auto WTilde =
- Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
-
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
+ math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
+ math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
+ math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
- DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
+ DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
- HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
+ HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
- WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
+ WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
- const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
- const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
- const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+ const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
+ const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
+ const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_pad_transform(Ho, I0, I0),
- make_pad_transform(Wo, I0, I0),
- make_pass_through_transform(K)),
+ make_tuple(make_pass_through_transform(N_),
+ make_pad_transform(Ho_, I0, I0),
+ make_pad_transform(Wo_, I0, I0),
+ make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(
- make_pass_through_transform(N),
- make_embed_transform(make_tuple(YDot, HTilde),
- make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
- make_embed_transform(make_tuple(XDot, WTilde),
- make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
- make_pass_through_transform(K)),
+ make_pass_through_transform(N_),
+ make_embed_transform(make_tuple(YDot_, HTilde_),
+ make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
+ make_embed_transform(make_tuple(XDot_, WTilde_),
+ make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
+ make_pass_through_transform(K_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_slice_transform(YDot, I0, YDotSlice),
- make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
- make_slice_transform(XDot, I0, XDotSlice),
- make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
- make_pass_through_transform(K)),
+ make_tuple(make_pass_through_transform(N_),
+ make_slice_transform(YDot_, I0, YDotSlice),
+ make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
+ make_slice_transform(XDot_, I0, XDotSlice),
+ make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
+ make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
- make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
- make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
+ make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
+ make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_pad_transform(Do, I0, I0),
- make_pad_transform(Ho, I0, I0),
- make_pad_transform(Wo, I0, I0),
- make_pass_through_transform(K)),
+ make_tuple(make_pass_through_transform(N_),
+ make_pad_transform(Do_, I0, I0),
+ make_pad_transform(Ho_, I0, I0),
+ make_pad_transform(Wo_, I0, I0),
+ make_pass_through_transform(K_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
- make_tuple(make_pass_through_transform(N),
+ make_tuple(make_pass_through_transform(N_),
make_embed_transform(
- make_tuple(ZDot, DTilde),
- make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
+ make_tuple(ZDot_, DTilde_),
+ make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)),
make_embed_transform(
- make_tuple(YDot, HTilde),
- make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
+ make_tuple(YDot_, HTilde_),
+ make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)),
make_embed_transform(
- make_tuple(XDot, WTilde),
- make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
- make_pass_through_transform(K)),
+ make_tuple(XDot_, WTilde_),
+ make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)),
+ make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_slice_transform(ZDot, I0, ZDotSlice),
- make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
- make_slice_transform(YDot, I0, YDotSlice),
- make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
- make_slice_transform(XDot, I0, XDotSlice),
- make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
- make_pass_through_transform(K)),
+ make_tuple(
+ make_pass_through_transform(N_),
+ make_slice_transform(ZDot_, I0, ZDotSlice),
+ make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
+ make_slice_transform(YDot_, I0, YDotSlice),
+ make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
+ make_slice_transform(XDot_, I0, XDotSlice),
+ make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
+ make_pass_through_transform(K_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(
- make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
- make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
+ make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
+ make_merge_transform(
+ make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
}
}
- template ||
- is_same_v),
+ (is_same_v ||
+ is_same_v),
bool>::type = false>
- static auto MakeBDescriptor_BK0_N_BK1(
- const std::array& out_g_n_k_wos_lengths,
- const std::array& /* out_g_n_k_wos_strides */,
- const std::array& wei_g_k_c_xs_lengths,
- const std::array& /* wei_g_k_c_xs_strides */,
- const std::array& in_g_n_c_wis_lengths,
- const std::array& /* in_g_n_c_wis_strides */,
- const std::array& conv_filter_strides,
- const std::array& conv_filter_dilations,
- const std::array& /* input_left_pads */,
- const std::array& /* input_right_pads */,
- const std::array& tildes)
+ __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const
{
- index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
- index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
- index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
-
- const index_t N = in_g_n_c_wis_lengths[1];
- const index_t K = wei_g_k_c_xs_lengths[1];
- const index_t C = wei_g_k_c_xs_lengths[2];
-
- const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
- const index_t Ho = out_g_n_k_wos_lengths[HIdx];
- const index_t Wo = out_g_n_k_wos_lengths[WIdx];
-
- const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
- const index_t Y = wei_g_k_c_xs_lengths[YIdx];
- const index_t X = wei_g_k_c_xs_lengths[XIdx];
-
- const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
- const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
- const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
-
- const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
- const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
- const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
-
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
- const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C);
+ const auto wei_grid_desc = MakeWeiGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
- const index_t BK0 = math::integer_divide_ceil(K, BK1);
+ const index_t BK0 = math::integer_divide_ceil(K_, BK1);
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
- transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
+ transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
- make_pass_through_transform(C)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
- make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));
+ make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
- const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
- const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
- const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
-
- const auto ZTilde = ConvStrideD / GcdStrideDilationD;
- const auto YTilde = ConvStrideH / GcdStrideDilationH;
- const auto XTilde = ConvStrideW / GcdStrideDilationW;
-
- const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
- const auto YDot = math::integer_divide_ceil(Y, YTilde);
- const auto XDot = math::integer_divide_ceil(X, XTilde);
-
// GemmK is different for each GEMM
- const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
- const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
- const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
+ const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_);
+ const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_);
+ const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_);
// B weight tensor
if constexpr(NDimSpatial == 2)
@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
- make_pass_through_transform(K),
- make_embed_transform(make_tuple(YDot, YTilde),
- make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
- make_embed_transform(make_tuple(XDot, XTilde),
- make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
- make_pass_through_transform(C)),
+ make_pass_through_transform(K_),
+ make_embed_transform(make_tuple(YDot_, YTilde_),
+ make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
+ make_embed_transform(make_tuple(XDot_, XTilde_),
+ make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
- make_tuple(make_pass_through_transform(K),
- make_slice_transform(YDot, I0, YDotSlice),
- make_slice_transform(XDot, I0, XDotSlice),
- make_freeze_transform(i_ytilde),
- make_freeze_transform(i_xtilde),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(K_),
+ make_slice_transform(YDot_, I0, YDotSlice),
+ make_slice_transform(XDot_, I0, XDotSlice),
+ make_freeze_transform(IdxYTilde_),
+ make_freeze_transform(IdxXTilde_),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc,
- make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
- make_pass_through_transform(C)),
+ make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_grid_desc,
- make_tuple(
- make_pass_through_transform(K),
- make_embed_transform(make_tuple(ZDot, ZTilde),
- make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
- make_embed_transform(make_tuple(YDot, YTilde),
- make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
- make_embed_transform(make_tuple(XDot, XTilde),
- make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(K_),
+ make_embed_transform(
+ make_tuple(ZDot_, ZTilde_),
+ make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)),
+ make_embed_transform(
+ make_tuple(YDot_, YTilde_),
+ make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)),
+ make_embed_transform(
+ make_tuple(XDot_, XTilde_),
+ make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
- make_tuple(make_pass_through_transform(K),
- make_slice_transform(ZDot, I0, ZDotSlice),
- make_slice_transform(YDot, I0, YDotSlice),
- make_slice_transform(XDot, I0, XDotSlice),
- make_freeze_transform(i_ztilde),
- make_freeze_transform(i_ytilde),
- make_freeze_transform(i_xtilde),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(K_),
+ make_slice_transform(ZDot_, I0, ZDotSlice),
+ make_slice_transform(YDot_, I0, YDotSlice),
+ make_slice_transform(XDot_, I0, XDotSlice),
+ make_freeze_transform(IdxZTilde_),
+ make_freeze_transform(IdxYTilde_),
+ make_freeze_transform(IdxXTilde_),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
- make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
- make_pass_through_transform(C)),
+ make_tuple(
+ make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
}
}
- template ||
- is_same_v ||
- is_same_v ||
- is_same_v ||
- is_same_v),
- bool>::type = false>
- static auto
- MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths,
- const std::array& /* out_g_n_k_wos_strides */,
- const std::array& wei_g_k_c_xs_lengths,
- const std::array& /* wei_g_k_c_xs_strides */,
- const std::array& in_g_n_c_wis_lengths,
- const std::array& in_g_n_c_wis_strides,
- const std::array& conv_filter_strides,
- const std::array& conv_filter_dilations,
- const std::array& input_left_pads,
- const std::array& input_right_pads,
- const std::array& tildes)
+ template <
+ typename CLayout_ = CLayout,
+ typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
+ (is_same_v ||
+ is_same_v ||
+ is_same_v ||
+ is_same_v ||
+ is_same_v),
+ bool>::type = false>
+ __host__ __device__ auto MakeCDescriptor_M_N() const
{
- index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
- index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
- index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
-
- const index_t N = in_g_n_c_wis_lengths[1];
- const index_t C = wei_g_k_c_xs_lengths[2];
-
- const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
- const index_t Hi = in_g_n_c_wis_lengths[HIdx];
- const index_t Wi = in_g_n_c_wis_lengths[WIdx];
-
- const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
- const index_t Ho = out_g_n_k_wos_lengths[HIdx];
- const index_t Wo = out_g_n_k_wos_lengths[WIdx];
-
- const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
- const index_t Y = wei_g_k_c_xs_lengths[YIdx];
- const index_t X = wei_g_k_c_xs_lengths[XIdx];
-
- const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
- const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
- const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
-
- const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
- const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
- const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
-
- const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
- const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
- const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
-
- const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
- const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
- const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
-
// assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
- const auto in_grid_desc =
- make_in_grid_desc(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
+ const auto in_grid_desc = MakeInGridDesc();
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
- make_pass_through_transform(N),
- make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
- make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
- make_pass_through_transform(C)),
+ make_pass_through_transform(N_),
+ make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
+ make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
- make_merge_transform(make_tuple(N, Ho, Wo)),
- make_pass_through_transform(C)),
+ make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
- make_pass_through_transform(N),
- make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
- make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
- make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
- make_pass_through_transform(C)),
+ make_pass_through_transform(N_),
+ make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)),
+ make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)),
+ make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)),
+ make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
- make_merge_transform(make_tuple(N, Do, Ho, Wo)),
- make_pass_through_transform(C)),
+ make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
- const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
- const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
- const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
-
- const auto ZTilde = ConvStrideD / GcdStrideDilationD;
- const auto YTilde = ConvStrideH / GcdStrideDilationH;
- const auto XTilde = ConvStrideW / GcdStrideDilationW;
-
- const auto DTilde =
- Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
- const auto HTilde =
- Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
- const auto WTilde =
- Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
-
// only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
+ math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_);
const auto IHTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
+ math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_);
const auto IWTildeSliceBegin = math::integer_divide_floor(
- math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
+ math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_);
const auto IDTildeSliceEnd = math::min(
- DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
+ DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1);
const auto IHTildeSliceEnd = math::min(
- HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
+ HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1);
const auto IWTildeSliceEnd = math::min(
- WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
+ WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_pad_transform(Hi, InLeftPadH, InRightPadH),
- make_pad_transform(Wi, InLeftPadW, InRightPadW),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+ make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_embed_transform(make_tuple(YTilde, HTilde),
- make_tuple(ConvDilationH, ConvStrideH)),
- make_embed_transform(make_tuple(XTilde, WTilde),
- make_tuple(ConvDilationW, ConvStrideW)),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_embed_transform(make_tuple(YTilde_, HTilde_),
+ make_tuple(ConvDilationH_, ConvStrideH_)),
+ make_embed_transform(make_tuple(XTilde_, WTilde_),
+ make_tuple(ConvDilationW_, ConvStrideW_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_freeze_transform(i_ytilde),
- make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
- make_freeze_transform(i_xtilde),
- make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_freeze_transform(IdxYTilde_),
+ make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
+ make_freeze_transform(IdxXTilde_),
+ make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_htildeslice_wtildeslice_c_grid_desc,
- make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
- make_pass_through_transform(C)),
+ make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_pad_transform(Di, InLeftPadD, InRightPadD),
- make_pad_transform(Hi, InLeftPadH, InRightPadH),
- make_pad_transform(Wi, InLeftPadW, InRightPadW),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+ make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+ make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+ make_pass_through_transform(C_)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_embed_transform(make_tuple(ZTilde, DTilde),
- make_tuple(ConvDilationD, ConvStrideD)),
- make_embed_transform(make_tuple(YTilde, HTilde),
- make_tuple(ConvDilationH, ConvStrideH)),
- make_embed_transform(make_tuple(XTilde, WTilde),
- make_tuple(ConvDilationW, ConvStrideW)),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_embed_transform(make_tuple(ZTilde_, DTilde_),
+ make_tuple(ConvDilationD_, ConvStrideD_)),
+ make_embed_transform(make_tuple(YTilde_, HTilde_),
+ make_tuple(ConvDilationH_, ConvStrideH_)),
+ make_embed_transform(make_tuple(XTilde_, WTilde_),
+ make_tuple(ConvDilationW_, ConvStrideW_)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
- make_tuple(make_pass_through_transform(N),
- make_freeze_transform(i_ztilde),
- make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
- make_freeze_transform(i_ytilde),
- make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
- make_freeze_transform(i_xtilde),
- make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
- make_pass_through_transform(C)),
+ make_tuple(make_pass_through_transform(N_),
+ make_freeze_transform(IdxZTilde_),
+ make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice),
+ make_freeze_transform(IdxYTilde_),
+ make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice),
+ make_freeze_transform(IdxXTilde_),
+ make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
- make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
- make_pass_through_transform(C)),
+ make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)),
+ make_pass_through_transform(C_)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
}
// for input bias
- template ||
- is_same_v),
+ (is_same_v ||
+ is_same_v),
bool>::type = false>
- static auto
- MakeCDescriptor_M_N(const std::array& out_g_n_k_wos_lengths,
- const std::array& /* out_g_n_k_wos_strides */,
- const std::array& wei_g_k_c_xs_lengths,
- const std::array& /* wei_g_k_c_xs_strides */,
- const std::array& in_g_n_c_wis_lengths,
- const std::array& /* in_g_n_c_wis_strides */,
- const std::array