diff --git a/client_example/11_grouped_conv_bwd_weight/common.hpp b/client_example/11_grouped_conv_bwd_weight/common.hpp index f63e5f2157..4292cded20 100644 --- a/client_example/11_grouped_conv_bwd_weight/common.hpp +++ b/client_example/11_grouped_conv_bwd_weight/common.hpp @@ -32,63 +32,49 @@ struct SimpleDeviceMem }; template -std::size_t GetFlops(ck::index_t G, - ck::index_t N, - ck::index_t K, - ck::index_t C, - const std::array& output_spatial_lengths, - const std::array& filter_spatial_lengths) +std::size_t GetFlops(const std::array& output_lengths, + const std::array& filter_lengths) { + constexpr ck::index_t spatial_offset = 3; + const auto C = filter_lengths[2]; // 2 * G * N * K * C * * - return static_cast(2) * G * N * K * C * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return static_cast(2) * C * + std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies<>()) * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + std::accumulate(std::begin(filter_lengths) + spatial_offset, + std::end(filter_lengths), static_cast(1), std::multiplies<>()); } template -std::size_t GetInputByte(ck::index_t G, - ck::index_t N, - ck::index_t C, - const std::array& input_spatial_lengths) +std::size_t GetInputByte(const std::array& input_lengths) { // sizeof(InDataType) * (G * N * C * ) + - return sizeof(InDataType) * (G * N * C * - std::accumulate(std::begin(input_spatial_lengths), - std::end(input_spatial_lengths), + return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths), + std::end(input_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetWeightByte(ck::index_t G, - ck::index_t K, - ck::index_t C, - const std::array& filter_spatial_lengths) +std::size_t GetWeightByte(const std::array& filter_lengths) { // sizeof(WeiDataType) * (G * K * C * ) + - return sizeof(WeiDataType) * (G * K * C * - std::accumulate(std::begin(filter_spatial_lengths), - std::end(filter_spatial_lengths), + return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths), + std::end(filter_lengths), static_cast(1), std::multiplies<>())); } template -std::size_t GetOutputByte(ck::index_t G, - ck::index_t N, - ck::index_t K, - const std::array& output_spatial_lengths) +std::size_t GetOutputByte(const std::array& output_lengths) { // sizeof(OutDataType) * (G * N * K * ); - return sizeof(OutDataType) * (G * N * K * - std::accumulate(std::begin(output_spatial_lengths), - std::end(output_spatial_lengths), + return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths), + std::end(output_lengths), static_cast(1), std::multiplies())); } @@ -101,14 +87,11 @@ template bool run_grouped_conv_bwd_weight( - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, + const std::array& input_lengths, const std::array& input_strides, + const std::array& filter_lengths, + const std::array& weights_strides, + const std::array& output_lengths, const std::array& output_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, @@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight( { ck::index_t split_k = 2; - SimpleDeviceMem in(GetInputByte(G, N, C, input_spatial_lengths)); - SimpleDeviceMem wei(GetWeightByte(G, K, C, filter_spatial_lengths)); - SimpleDeviceMem out(GetOutputByte(G, N, K, output_spatial_lengths)); + SimpleDeviceMem in(GetInputByte(input_lengths)); + SimpleDeviceMem wei(GetWeightByte(filter_lengths)); + SimpleDeviceMem out(GetOutputByte(output_lengths)); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight a_g_n_c_wis_lengths{}; + std::array a_g_n_c_wis_strides{}; + std::array b_g_k_c_xs_lengths{}; + // profile device operation instances std::cout << "Run all instances and do timing" << std::endl; @@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, @@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight( { float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); - std::size_t flop = - GetFlops(G, N, K, C, output_spatial_lengths, filter_spatial_lengths); - std::size_t num_bytes = - GetInputByte(G, N, C, input_spatial_lengths) + - GetWeightByte(G, K, C, filter_spatial_lengths) + - GetOutputByte(G, N, K, output_spatial_lengths); + std::size_t flop = GetFlops(output_lengths, filter_lengths); + std::size_t num_bytes = GetInputByte(input_lengths) + + GetWeightByte(filter_lengths) + + GetOutputByte(output_lengths); float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; @@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight( auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(), wei.GetDeviceBuffer(), out.GetDeviceBuffer(), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp index 1c6f485da2..e6d427faf4 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv1d_bwd_weight_fp16.cpp @@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192; static constexpr ck::index_t X = 3; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Wo = 28; -static constexpr std::array input_spatial_lengths{Wi}; -static constexpr std::array filter_spatial_lengths{X}; -static constexpr std::array output_spatial_lengths{Wo}; -static constexpr std::array input_strides{N * Wi * C, Wi* C, C, 1}; -static constexpr std::array output_strides{N * Wo * K, Wo* K, K, 1}; +static constexpr std::array input_lengths{G, N, C, Wi}; +static constexpr std::array filter_lengths{G, K, C, X}; +static constexpr std::array output_lengths{G, N, K, Wo}; +static constexpr std::array input_strides{N * Wi * C, Wi* C, 1, C}; +static constexpr std::array weights_strides{K * X * C, X* C, 1, C}; +static constexpr std::array output_strides{N * Wo * K, Wo* K, 1, K}; static constexpr std::array conv_filter_strides{1}; static constexpr std::array conv_filter_dilations{1}; static constexpr std::array input_left_pads{1}; @@ -40,14 +41,11 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + OutLayout>(input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp index 25e82f3896..4201ea61b4 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv2d_bwd_weight_fp16.cpp @@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28; static constexpr ck::index_t Wi = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 28; -static constexpr std::array input_spatial_lengths{Hi, Wi}; -static constexpr std::array filter_spatial_lengths{Y, X}; -static constexpr std::array output_spatial_lengths{Ho, Wo}; +static constexpr std::array input_lengths{G, N, C, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Y, X}; +static constexpr std::array output_lengths{G, N, K, Ho, Wo}; static constexpr std::array input_strides{ - N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1}; + N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Y * X * C, Y* X* C, 1, X* C, C}; static constexpr std::array output_strides{ - N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1}; + N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K}; static constexpr std::array conv_filter_strides{1, 1}; static constexpr std::array conv_filter_dilations{1, 1}; static constexpr std::array input_left_pads{1, 1}; @@ -45,14 +47,11 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + OutLayout>(input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp index a5f5e628ff..3ae46bcd55 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp16.cpp @@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; -static constexpr std::array input_spatial_lengths{Di, Hi, Wi}; -static constexpr std::array filter_spatial_lengths{Z, Y, X}; -static constexpr std::array output_spatial_lengths{Do, Ho, Wo}; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; static constexpr std::array input_strides{ - N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; static constexpr std::array output_strides{ - N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; static constexpr std::array conv_filter_strides{1, 1, 1}; static constexpr std::array conv_filter_dilations{1, 1, 1}; static constexpr std::array input_left_pads{1, 1, 1}; @@ -48,14 +50,11 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>(G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + OutLayout>(input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp index d95e8a205e..2eb869f392 100644 --- a/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp +++ b/client_example/11_grouped_conv_bwd_weight/grouped_conv3d_bwd_weight_fp32.cpp @@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3; static constexpr ck::index_t Do = 28; static constexpr ck::index_t Ho = 28; static constexpr ck::index_t Wo = 3; -static constexpr std::array input_spatial_lengths{Di, Hi, Wi}; -static constexpr std::array filter_spatial_lengths{Z, Y, X}; -static constexpr std::array output_spatial_lengths{Do, Ho, Wo}; +static constexpr std::array input_lengths{G, N, C, Di, Hi, Wi}; +static constexpr std::array filter_lengths{G, K, C, Z, Y, X}; +static constexpr std::array output_lengths{G, N, K, Do, Ho, Wo}; static constexpr std::array input_strides{ - N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1}; + N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C}; +static constexpr std::array weights_strides{ + K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C}; static constexpr std::array output_strides{ - N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1}; + N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K}; static constexpr std::array conv_filter_strides{1, 1, 1}; static constexpr std::array conv_filter_dilations{1, 1, 1}; static constexpr std::array input_left_pads{1, 1, 1}; @@ -48,20 +50,16 @@ int main() OutDataType, InLayout, WeiLayout, - OutLayout>( - G, - N, - K, - C, - {Di, Hi, Wi}, - {Z, Y, X}, - {Do, Ho, Wo}, - {N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1}, - {N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}) + OutLayout>(input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc index 49bd9fc7f0..29ce0324ab 100644 --- a/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc +++ b/example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc @@ -72,10 +72,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, // init to 0 wei_device_buf.SetZero(); - std::array input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; std::array input_strides{}; + std::array filter_lengths{}; + std::array weights_strides{}; + std::array output_lengths{}; std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; @@ -84,10 +85,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); @@ -100,14 +102,11 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config, auto argument = conv.MakeArgument(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index a103bdff35..ab9e6adb41 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator MakeArgumentPointer(const void* p_in, void* p_wei, const void* p_out, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp index 003a508f07..198751cdf3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp @@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& /*input_strides*/, - const std::array& /*output_strides*/, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& /*a_g_n_c_wis_strides*/, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& /*b_g_k_c_xs_strides*/, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& /*e_g_n_k_wos_strides*/, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - input_spatial_lengths_{input_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, - output_spatial_lengths_{output_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl // A/B/C Batch Stride compute_ptr_offset_of_batch_.BatchStrideA_ = - N * K * - std::accumulate(begin(output_spatial_lengths), - end(output_spatial_lengths), + Conv_N_ * Conv_K_ * + std::accumulate(begin(output_spatial_lengths_), + end(output_spatial_lengths_), index_t{1}, std::multiplies<>{}); compute_ptr_offset_of_batch_.BatchStrideB_ = - N * C * - std::accumulate(begin(input_spatial_lengths), - end(input_spatial_lengths), + Conv_N_ * Conv_C_ * + std::accumulate(begin(input_spatial_lengths_), + end(input_spatial_lengths_), index_t{1}, std::multiplies<>{}); compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); } @@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl const index_t Conv_K_; const index_t Conv_C_; - const std::array& input_spatial_lengths_; - const std::array& filter_spatial_lengths_; - const std::array& output_spatial_lengths_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; const std::array& conv_filter_strides_; const std::array& conv_filter_dilations_; const std::array& input_left_pads_; @@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - input_strides, - output_strides, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - input_strides, - output_strides, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index a55da9add2..5a8bb6188b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -245,21 +245,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const ck::index_t K, const std::array& output_strides) { - if constexpr(is_GNHWK_GKYXC_GNHWC) - { - return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); - } - else if constexpr(is_NHWGK_GKYXC_NHWGC) - { - const index_t WoStride = output_strides[4]; - const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), - make_tuple(WoStride, KStride)); - } - else - { - throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name()); - } + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WoStride, KStride)); } template ::type = false> @@ -270,42 +259,36 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const ck::index_t C, const std::array& input_strides) { - if constexpr(is_GNHWK_GKYXC_GNHWC) + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor_packed(make_tuple(N * Hi * Wi, C)); - } - else - { - return make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); - } - } - else if constexpr(is_NHWGK_GKYXC_NHWGC) - { - const index_t NStride = input_strides[1]; - const index_t HiStride = input_strides[3]; - const index_t WiStride = input_strides[4]; - const auto CStride = input_strides[2]; - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), - make_tuple(WiStride, CStride)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride)); - } + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), + make_tuple(WiStride, CStride)); } else { - throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(NStride, HiStride, WiStride, CStride)); } } + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); + } + template ::type = false> constexpr static auto make_out_grid_desc(const ck::index_t N, @@ -315,21 +298,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const ck::index_t K, const std::array& output_strides) { - if constexpr(is_GNDHWK_GKZYXC_GNDHWC) - { - return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)); - } - else if constexpr(is_NDHWGK_GKZYXC_NDHWGC) - { - const index_t WoStride = output_strides[5]; - const auto KStride = Number<1>{}; - return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), - make_tuple(WoStride, KStride)); - } - else - { - throw std::runtime_error("wrong! unsupported layout: " + OutLayout::name()); - } + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); } template ::type = false> @@ -341,44 +313,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const ck::index_t C, const std::array& input_strides) { - if constexpr(is_GNDHWK_GKZYXC_GNDHWC) + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) { - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor_packed(make_tuple(N * Di * Hi * Wi, C)); - } - else - { - return make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); - } - } - else if constexpr(is_NDHWGK_GKZYXC_NDHWGC) - { - const index_t NStride = input_strides[1]; - const index_t DiStride = input_strides[3]; - const index_t HiStride = input_strides[4]; - const index_t WiStride = input_strides[5]; - const auto CStride = input_strides[2]; - if constexpr(ConvBackwardWeightSpecialization == - ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) - { - return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), - make_tuple(WiStride, CStride)); - } - else - { - return make_naive_tensor_descriptor( - make_tuple(N, Di, Hi, Wi, C), - make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); - } + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); } else { - throw std::runtime_error("wrong! unsupported layout: " + InLayout::name()); + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); } } + template ::type = false> + constexpr static auto + make_wei_grid_desc(const ck::index_t K, + const ck::index_t Z, + const ck::index_t Y, + const ck::index_t X, + const ck::index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + template ::type = false> static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( const ck::index_t N, @@ -388,6 +356,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& filter_spatial_lengths, const std::array& output_spatial_lengths, const std::array& /* input_strides */, + const std::array& /* weights_strides */, const std::array& /* output_strides */, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, @@ -542,6 +511,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& filter_spatial_lengths, const std::array& output_spatial_lengths, const std::array& input_strides, + const std::array& weights_strides, const std::array& output_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, @@ -584,6 +554,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -618,13 +589,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { @@ -684,13 +651,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } } @@ -703,6 +666,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const std::array& filter_spatial_lengths, const std::array& output_spatial_lengths, const std::array& input_strides, + const std::array& weights_strides, const std::array& output_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, @@ -752,6 +716,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); if constexpr(ConvBackwardWeightSpecialization == ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) @@ -786,13 +751,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } else { @@ -861,13 +822,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - // C: weight tensor - const auto wei_gemmm_gemmn_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)); - return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc); + wei_grid_desc); } } // function end @@ -887,6 +844,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle lengths, strides, strides, + strides, params, params, params, @@ -910,6 +868,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle lengths, strides, strides, + strides, params, params, params, @@ -933,6 +892,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle lengths, strides, strides, + strides, params, params, params, @@ -1051,15 +1011,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle Argument(const InDataType* p_in_grid, WeiDataType* p_wei_grid, const OutDataType* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1084,27 +1041,40 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle a_element_op_{out_element_op}, b_element_op_{in_element_op}, c_element_op_{wei_element_op}, - Conv_G_{G}, - Conv_N_{N}, - Conv_K_{K}, - Conv_C_{C}, - output_spatial_lengths_{output_spatial_lengths}, - filter_spatial_lengths_{filter_spatial_lengths}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - input_strides, - output_strides, + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1119,12 +1089,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); // A/B/C Batch Stride - compute_ptr_offset_of_batch_.BatchStrideA_ = output_strides[0]; - compute_ptr_offset_of_batch_.BatchStrideB_ = input_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; compute_ptr_offset_of_batch_.BatchStrideC_ = - K * C * - std::accumulate(begin(filter_spatial_lengths), - end(filter_spatial_lengths), + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), index_t{1}, std::multiplies<>{}); @@ -1163,8 +1133,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle const index_t Conv_N_; const index_t Conv_K_; const index_t Conv_C_; - const std::array& output_spatial_lengths_; - const std::array& filter_spatial_lengths_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; @@ -1339,39 +1310,34 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return IsSupportedArgument(*dynamic_cast(p_arg)); } - static auto MakeArgument(const InDataType* p_in_grid, - WeiDataType* p_wei_grid, - const OutDataType* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, - const std::array& conv_filter_strides, - const std::array& conv_filter_dilations, - const std::array& input_left_pads, - const std::array& input_right_pads, - InElementwiseOperation in_element_op, - WeiElementwiseOperation wei_element_op, - OutElementwiseOperation out_element_op, - const ck::index_t split_k) + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) { return Argument{p_in_grid, p_wei_grid, p_out_grid, - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - input_strides, - output_strides, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, @@ -1390,15 +1356,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle MakeArgumentPointer(const void* p_in_grid, void* p_wei_grid, const void* p_out_grid, - const ck::index_t G, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const std::array& input_spatial_lengths, - const std::array& filter_spatial_lengths, - const std::array& output_spatial_lengths, - const std::array& input_strides, - const std::array& output_strides, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array& conv_filter_dilations, const std::array& input_left_pads, @@ -1411,15 +1374,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle return std::make_unique(static_cast(p_in_grid), static_cast(p_wei_grid), static_cast(p_out_grid), - G, - N, - K, - C, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, - input_strides, - output_strides, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, conv_filter_strides, conv_filter_dilations, input_left_pads, diff --git a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp index 3a4295eeb3..48bf639a70 100644 --- a/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_bwd_weight_impl.hpp @@ -136,10 +136,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, // profile device Conv instances bool all_pass = true; - std::array input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; std::array input_strides{}; + std::array weights_strides{}; std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; @@ -148,10 +149,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); @@ -164,14 +166,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification, op_ptr->MakeArgumentPointer(static_cast(in_device_buf.GetDeviceBuffer()), static_cast(wei_device_buf.GetDeviceBuffer()), static_cast(out_device_buf.GetDeviceBuffer()), - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations, diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp index 57d4e4186c..cfbf13f00e 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface.cpp @@ -70,10 +70,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( conv_param); - std::array input_spatial_lengths{}; - std::array filter_spatial_lengths{}; - std::array output_spatial_lengths{}; + std::array input_lengths{}; + std::array filter_lengths{}; + std::array output_lengths{}; std::array input_strides{}; + std::array weights_strides{}; std::array output_strides{}; std::array conv_filter_strides{}; std::array conv_filter_dilations{}; @@ -82,10 +83,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); }; - range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths)); - range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths)); - range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths)); + range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths)); range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides)); + range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths)); + range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides)); + range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths)); range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides)); range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides)); range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations)); @@ -97,14 +99,11 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto argument = conv.MakeArgument(nullptr, nullptr, nullptr, - conv_param.G_, - conv_param.N_, - conv_param.K_, - conv_param.C_, - input_spatial_lengths, - filter_spatial_lengths, - output_spatial_lengths, + input_lengths, input_strides, + filter_lengths, + weights_strides, + output_lengths, output_strides, conv_filter_strides, conv_filter_dilations,