mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add wei_strides to grouped conv3d wei to keep consistency (#817)
* Add wei_strides to grouped conv3d wei to keep consistency
* Fix strides in client examples
* Unify backward weight api with forward
* Fix for example
* Fixes for examples

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*input_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /*output_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& /*a_g_n_c_wis_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& /*b_g_k_c_xs_strides*/,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& /*e_g_n_k_wos_strides*/,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{wei_element_op},
|
||||
c_element_op_{in_element_op},
|
||||
Conv_G_{G},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
input_spatial_lengths_{input_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
Conv_G_{a_g_n_c_wis_lengths[0]},
|
||||
Conv_N_{a_g_n_c_wis_lengths[1]},
|
||||
Conv_K_{b_g_k_c_xs_lengths[1]},
|
||||
Conv_C_{a_g_n_c_wis_lengths[2]},
|
||||
input_spatial_lengths_{},
|
||||
filter_spatial_lengths_{},
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(a_g_n_c_wis_lengths),
|
||||
begin(input_spatial_lengths_));
|
||||
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(b_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(e_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
N * K *
|
||||
std::accumulate(begin(output_spatial_lengths),
|
||||
end(output_spatial_lengths),
|
||||
Conv_N_ * Conv_K_ *
|
||||
std::accumulate(begin(output_spatial_lengths_),
|
||||
end(output_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
N * C *
|
||||
std::accumulate(begin(input_spatial_lengths),
|
||||
end(input_spatial_lengths),
|
||||
Conv_N_ * Conv_C_ *
|
||||
std::accumulate(begin(input_spatial_lengths_),
|
||||
end(input_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
K * C *
|
||||
std::accumulate(begin(filter_spatial_lengths),
|
||||
end(filter_spatial_lengths),
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
}
|
||||
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
static auto
|
||||
MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
@@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const ck::index_t K,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
|
||||
{
|
||||
if constexpr(is_GNHWK_GKYXC_GNHWC)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
}
|
||||
else if constexpr(is_NHWGK_GKYXC_NHWGC)
|
||||
{
|
||||
const index_t WoStride = output_strides[4];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
|
||||
make_tuple(WoStride, KStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
|
||||
}
|
||||
const index_t WoStride = output_strides[4];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
|
||||
make_tuple(WoStride, KStride));
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
@@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
|
||||
{
|
||||
if constexpr(is_GNHWK_GKYXC_GNHWC)
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t HiStride = input_strides[3];
|
||||
const index_t WiStride = input_strides[4];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_NHWGK_GKYXC_NHWGC)
|
||||
{
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t HiStride = input_strides[3];
|
||||
const index_t WiStride = input_strides[4];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
|
||||
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
|
||||
make_tuple(NStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_wei_grid_desc(const ck::index_t K,
|
||||
const ck::index_t Y,
|
||||
const ck::index_t X,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
|
||||
{
|
||||
const auto CStride = Number<1>{};
|
||||
const auto KStride = weights_strides[1];
|
||||
return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_out_grid_desc(const ck::index_t N,
|
||||
@@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const ck::index_t K,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides)
|
||||
{
|
||||
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
|
||||
}
|
||||
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
|
||||
{
|
||||
const index_t WoStride = output_strides[5];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
|
||||
make_tuple(WoStride, KStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name());
|
||||
}
|
||||
const index_t WoStride = output_strides[5];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
|
||||
make_tuple(WoStride, KStride));
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
@@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides)
|
||||
{
|
||||
if constexpr(is_GNDHWK_GKZYXC_GNDHWC)
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t DiStride = input_strides[3];
|
||||
const index_t HiStride = input_strides[4];
|
||||
const index_t WiStride = input_strides[5];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C));
|
||||
}
|
||||
}
|
||||
else if constexpr(is_NDHWGK_GKZYXC_NDHWGC)
|
||||
{
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t DiStride = input_strides[3];
|
||||
const index_t HiStride = input_strides[4];
|
||||
const index_t WiStride = input_strides[5];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! unsupported layout: " + InLayout::name());
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_wei_grid_desc(const ck::index_t K,
|
||||
const ck::index_t Z,
|
||||
const ck::index_t Y,
|
||||
const ck::index_t X,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides)
|
||||
{
|
||||
const auto CStride = Number<1>{};
|
||||
const auto KStride = weights_strides[1];
|
||||
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
|
||||
make_tuple(KStride, CStride));
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const ck::index_t N,
|
||||
@@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* input_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* weights_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& /* output_strides */,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
@@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
@@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& weights_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
@@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
|
||||
|
||||
// C: weight tensor
|
||||
const auto wei_gemmm_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_gemmm_gemmn_grid_desc);
|
||||
wei_grid_desc);
|
||||
}
|
||||
} // function end
|
||||
|
||||
@@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
@@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
@@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
@@ -1051,15 +1011,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -1084,27 +1041,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{in_element_op},
|
||||
c_element_op_{wei_element_op},
|
||||
Conv_G_{G},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
output_spatial_lengths_{output_spatial_lengths},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
Conv_G_{a_g_n_c_wis_lengths[0]},
|
||||
Conv_N_{a_g_n_c_wis_lengths[1]},
|
||||
Conv_K_{b_g_k_c_xs_lengths[1]},
|
||||
Conv_C_{a_g_n_c_wis_lengths[2]},
|
||||
input_spatial_lengths_{},
|
||||
filter_spatial_lengths_{},
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(a_g_n_c_wis_lengths),
|
||||
begin(input_spatial_lengths_));
|
||||
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(b_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(e_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1119,12 +1089,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
K * C *
|
||||
std::accumulate(begin(filter_spatial_lengths),
|
||||
end(filter_spatial_lengths),
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
|
||||
@@ -1163,8 +1133,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads_;
|
||||
@@ -1339,39 +1310,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k)
|
||||
static auto
|
||||
MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const ck::index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
@@ -1390,15 +1356,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<ck::index_t, NDimSpatial>& input_left_pads,
|
||||
@@ -1411,15 +1374,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_strides,
|
||||
output_strides,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
|
||||
Reference in New Issue
Block a user