From cff08cbc72fe038829574734bb9344b66900d469 Mon Sep 17 00:00:00 2001 From: rocking Date: Mon, 24 Apr 2023 12:40:00 +0800 Subject: [PATCH] Revise layout of group convolution (#675) * [What] Remove pure conv int8 instance [Why] We will never use pure int8 conv in AI, use int8 quantization instead * Change layout * Share the kernel parameter * Support more type of NHWGC for group conv * Revise client example of conv 2d, use NHWGC layout * Add instance to cmake * Revise layout of group conv quantization instance * Revise layout of external api of group conv quantization * Revise layout of group conv quantization client example * Fix clang format * Add comment to describe meaning of each parameter [ROCm/composable_kernel commit: 3eecbfb6ec231cd8012faceb8b6fbc87199db60d] --- .../grouped_conv2d_fwd.cpp | 74 +++----- ..._fwd_bias_relu_perchannel_quantization.cpp | 42 +++-- ...2d_fwd_bias_relu_perlayer_quantization.cpp | 26 +-- ..._fwd_bias_tanh_perchannel_quantization.cpp | 28 +-- ...2d_fwd_bias_tanh_perlayer_quantization.cpp | 26 +-- .../conv2d_fwd_perchannel_quantization.cpp | 40 +++-- .../conv2d_fwd_perlayer_quantization.cpp | 24 +-- ...d_bias_perchannel_quantization_example.inc | 6 +- ...fwd_bias_perlayer_quantization_example.inc | 6 +- ...2d_fwd_perchannel_quantization_example.inc | 6 +- ...nv2d_fwd_perlayer_quantization_example.inc | 6 +- .../gpu/grouped_convolution_forward.hpp | 58 +++---- ...n_bias_forward_perchannel_quantization.hpp | 34 ++-- ...ion_bias_forward_perlayer_quantization.hpp | 34 ++-- ...lution_forward_perchannel_quantization.hpp | 22 +-- ...volution_forward_perlayer_quantization.hpp | 22 +-- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 6 +- .../device_grouped_conv2d_fwd_common.hpp | 53 ++++++ ..._fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 112 ++++-------- ..._fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 116 ++++--------- ...fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp | 104 ----------- .../device_grouped_conv2d_fwd_dl_instance.hpp | 50 ++++++ ...wd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp | 162 ++++-------------- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp | 160 ++++------------- ...fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp | 132 ++++---------- ...wd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp | 125 -------------- ...device_grouped_conv2d_fwd_xdl_instance.hpp | 105 ++++++++++++ ...wd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp | 66 +++++++ ...fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp | 160 ++++------------- ...fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp | 66 +++++++ .../conv2d_fwd/conv2d_quantization_common.hpp | 4 +- ..._perchannel_quantization_int8_instance.cpp | 57 ++++-- ...as_perlayer_quantization_int8_instance.cpp | 57 ++++-- .../device_conv2d_dl_int8_instance.hpp | 7 +- ..._perchannel_quantization_int8_instance.cpp | 38 ++-- ...dl_perlayer_quantization_int8_instance.cpp | 38 ++-- ..._perchannel_quantization_int8_instance.cpp | 57 ++++-- ...as_perlayer_quantization_int8_instance.cpp | 57 ++++-- .../device_conv2d_xdl_int8_instance.hpp | 39 +++-- ..._perchannel_quantization_int8_instance.cpp | 38 ++-- ...dl_perlayer_quantization_int8_instance.cpp | 38 ++-- 41 files changed, 1079 insertions(+), 1222 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp diff --git a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp index ece6e30c56..0a798be270 100644 --- a/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp +++ b/client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp @@ -17,22 +17,22 @@ using InDataType = ck::half_t; using WeiDataType = ck::half_t; using OutDataType = ck::half_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; static constexpr ck::index_t NumDimSpatial = 2; static constexpr ck::index_t G = 32; -static constexpr ck::index_t N = 256; -static constexpr ck::index_t K = 192; -static constexpr ck::index_t C = 192; -static constexpr ck::index_t Y = 3; -static constexpr ck::index_t X = 3; -static constexpr ck::index_t Hi = 28; -static constexpr ck::index_t Wi = 28; -static constexpr ck::index_t Ho = 28; -static constexpr ck::index_t Wo = 28; +static constexpr ck::index_t N = 256; // batch size +static constexpr ck::index_t K = 64; // output channel +static constexpr ck::index_t C = 32; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 28; // input H +static constexpr ck::index_t Wi = 28; // input W +static constexpr ck::index_t Ho = 28; // output H +static constexpr ck::index_t Wo = 28; // output W struct SimpleDeviceMem { @@ -52,50 +52,24 @@ struct SimpleDeviceMem int main() { - std::array in_lengths{G, N, Hi, Wi, C}; - std::array in_strides{0, 0, 0, 0, 1}; - - std::array wei_lengths{G, K, Y, X, C}; - std::array wei_strides{0, 0, 0, 0, 1}; - - std::array out_lengths{G, N, Ho, Wo, K}; - std::array out_strides{0, 0, 0, 0, 1}; - - std::partial_sum(rbegin(in_lengths), - std::prev(rend(in_lengths)), - std::next(rbegin(in_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(wei_lengths), - std::prev(rend(wei_lengths)), - std::next(rbegin(wei_strides)), - std::multiplies<>{}); - std::partial_sum(rbegin(out_lengths), - std::prev(rend(out_lengths)), - std::next(rbegin(out_strides)), - std::multiplies<>{}); - - // transpose GNHWC/GKYXC/GNHWK to GNCHW/GKCYX/GNCHW - std::rotate( - rbegin(in_lengths), std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 3)); - std::rotate( - rbegin(in_strides), std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 3)); - std::rotate( - rbegin(wei_lengths), std::next(rbegin(wei_lengths)), std::next(rbegin(wei_lengths), 3)); - std::rotate( - rbegin(wei_strides), std::next(rbegin(wei_strides)), std::next(rbegin(wei_strides), 3)); - std::rotate( - rbegin(out_lengths), std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 3)); - std::rotate( - rbegin(out_strides), std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 3)); + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride + std::array in_lengths{G, N, C, Hi, Wi}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; + std::array wei_lengths{G, K, C, Y, X}; + std::array wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; + std::array out_lengths{G, N, K, Ho, Wo}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; std::array filter_strides{1, 1}; std::array filter_dilations{1, 1}; std::array input_left_pads{1, 1}; std::array input_right_pads{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleDRun(argument_ptr.get(), StreamConfig{nullptr, true}); std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; - std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + + std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C + sizeof(WeiDataType) * G * K * Y * X * C + - sizeof(OutDataType) * G * N * Ho * Wo * K; + sizeof(OutDataType) * N * Ho * Wo * G * K; float tflops = static_cast(flop) / 1.E9 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time; diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp index a10dd3e006..43a4779f5f 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perchannel_quantization.cpp @@ -17,26 +17,26 @@ using BiasDataType = int32_t; using RequantScaleDataType = float; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul2_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { SimpleDeviceMem() = delete; @@ -55,8 +55,11 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; @@ -64,17 +67,18 @@ int main(int argc, char* argv[]) std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< NumDimSpatial, diff --git a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp index b8e6a493ef..2ff91fe966 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_relu_perlayer_quantization.cpp @@ -16,19 +16,19 @@ using WeiDataType = int8_t; using BiasDataType = int32_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::Relu; using OutElementOp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -55,23 +55,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -58,8 +58,11 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; @@ -67,17 +70,18 @@ int main(int argc, char* argv[]) std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD< NumDimSpatial, diff --git a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp index 7637f5c785..33407c9a1c 100644 --- a/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp +++ b/client_example/09_quantization/conv2d_fwd_bias_tanh_perlayer_quantization.cpp @@ -16,19 +16,19 @@ using WeiDataType = int8_t; using BiasDataType = int32_t; using OutDataType = int8_t; -using InLayout = ck::tensor_layout::convolution::GNHWC; +using InLayout = ck::tensor_layout::convolution::NHWGC; using WeiLayout = ck::tensor_layout::convolution::GKYXC; using BiasLayout = ck::tensor_layout::convolution::G_K; -using OutLayout = ck::tensor_layout::convolution::GNHWK; +using OutLayout = ck::tensor_layout::convolution::NHWGK; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using ActivationOp = ck::tensor_operation::element_wise::TanH; using OutElementOp = ck::tensor_operation::element_wise::Add_Mul_Activation_Mul_Clamp; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -56,23 +56,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array bias_lengths{G, N, K, Ho, Wo}; std::array bias_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem bias(sizeof(BiasDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem bias(sizeof(BiasDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; -static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel -static constexpr ck::index_t Y = 3; // filter H -static constexpr ck::index_t X = 3; // filter W -static constexpr ck::index_t Hi = 71; // input H -static constexpr ck::index_t Wi = 71; // input W -static constexpr ck::index_t Ho = 36; // output H -static constexpr ck::index_t Wo = 36; // output W +static constexpr ck::index_t G = 4; +static constexpr ck::index_t N = 4; // batch size +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) +static constexpr ck::index_t Y = 3; // filter H +static constexpr ck::index_t X = 3; // filter W +static constexpr ck::index_t Hi = 71; // input H +static constexpr ck::index_t Wi = 71; // input W +static constexpr ck::index_t Ho = 36; // output H +static constexpr ck::index_t Wo = 36; // output W struct SimpleDeviceMem { @@ -54,23 +54,27 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array requant_scale_lengths{G, N, K, Ho, Wo}; std::array requant_scale_strides{K, 0, 1, 0, 0}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * K); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem requant_scale(sizeof(RequantScaleDataType) * G * K); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD; static constexpr ck::index_t NumDimSpatial = 2; -static constexpr ck::index_t G = 1; +static constexpr ck::index_t G = 4; static constexpr ck::index_t N = 4; // batch size -static constexpr ck::index_t K = 64; // output channel -static constexpr ck::index_t C = 192; // input channel +static constexpr ck::index_t K = 32; // output channel +static constexpr ck::index_t C = 64; // input channel (per group) static constexpr ck::index_t Y = 3; // filter H static constexpr ck::index_t X = 3; // filter W static constexpr ck::index_t Hi = 71; // input H @@ -53,20 +53,24 @@ struct SimpleDeviceMem int main(int argc, char* argv[]) { + // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space + // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW + // Hence, we need to adjust the order of stride std::array in_lengths{G, N, C, Hi, Wi}; - std::array in_strides{N * Hi * Wi * C, Hi * Wi * C, 1, Wi * C, C}; + std::array in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C}; std::array weight_lengths{G, K, C, Y, X}; std::array weight_strides{K * Y * X * C, Y * X * C, 1, X * C, C}; std::array out_lengths{G, N, K, Ho, Wo}; - std::array out_strides{N * Ho * Wo * K, Ho * Wo * K, 1, Wo * K, K}; + std::array out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C}; + std::array in_left_pad{1, 1}; std::array in_right_pad{1, 1}; std::array conv_strides{2, 2}; std::array conv_dilations{1, 1}; - SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * C); - SimpleDeviceMem wei(sizeof(WeiDataType) * K * Y * X * C); - SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * K); + SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C); + SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); + SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K); using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc index 455e0804d4..9fd19c1c4e 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc @@ -178,10 +178,10 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using BiasLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc index 8e75c27746..cacedfdadd 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc @@ -180,10 +180,10 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; using RequantScaleLayout = ck::tensor_layout::convolution::G_K; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc index 926c033c58..77332cb6d9 100644 --- a/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc +++ b/example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc @@ -162,9 +162,9 @@ int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element const auto in_element_op = InElementOp{}; const auto wei_element_op = WeiElementOp{}; - using InLayout = ck::tensor_layout::convolution::GNHWC; - using WeiLayout = ck::tensor_layout::convolution::GKYXC; - using OutLayout = ck::tensor_layout::convolution::GNHWK; + using InLayout = ck::tensor_layout::convolution::NHWGC; + using WeiLayout = ck::tensor_layout::convolution::KYXGC; + using OutLayout = ck::tensor_layout::convolution::NHWGK; const auto in_g_n_c_wis_desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index a8df7f0d5b..175932e632 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -117,20 +117,6 @@ void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( PassThrough, PassThrough>>>& instances); -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances); - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances); -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( +// grouped conv2d forward, NHWGC/GKYXC/NHWGK +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( std::vector>>& instances); -// grouped conv2d forward, NHWGC/GKYXC/NHWGK + void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances); +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances); + // grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK void add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances( std::vector && is_same_v && - is_same_v) - { - add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); - add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances(op_ptrs); - } } else if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v) @@ -398,7 +393,7 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { - // no instance + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs); } else if constexpr(is_same_v && is_same_v && is_same_v) @@ -409,12 +404,7 @@ struct DeviceOperationInstanceFactory && is_same_v) { - // no instance - } - else if constexpr(is_same_v && is_same_v && - is_same_v) - { - // no instance + add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs); } } else if constexpr(NumDimSpatial == 3 && is_same_v && diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp index 793dc8d04a..daec48050c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perchannel_quantization.hpp @@ -17,14 +17,14 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( std::vector< std::unique_ptr> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -220,9 +220,9 @@ struct DeviceOperationInstanceFactory> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp index c570f76750..b7d81021e2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_bias_forward_perlayer_quantization.hpp @@ -17,14 +17,14 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_bias_perlayer_quantization_int8_instances( std::vector< std::unique_ptr> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -218,9 +218,9 @@ struct DeviceOperationInstanceFactory> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp index 089343fe6f..2d54879ea6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perchannel_quantization.hpp @@ -17,13 +17,13 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_perchannel_quantization_int8_instances( std::vector> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && + if constexpr(NumDimSpatial == 2 && is_same_v && is_same_v && is_same_v && - is_same_v) + is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp index e570027eb5..f278cfa224 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/grouped_convolution_forward_perlayer_quantization.hpp @@ -17,13 +17,13 @@ namespace tensor_operation { namespace device { namespace instance { -// grouped conv2d forward, GNHWC/GKYXC/GNHWK +// grouped conv2d forward, NHWGC/GKYXC/NHWGK void add_device_conv2d_dl_perlayer_quantization_int8_instances( std::vector> op_ptrs; - if constexpr(NumDimSpatial == 2 && is_same_v && - is_same_v && is_same_v) + if constexpr(NumDimSpatial == 2 && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt index 5ef1b6866e..a36e1b47ca 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt @@ -3,11 +3,11 @@ add_instance_library(device_grouped_conv2d_fwd_instance device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp # NHWGC, GKYXC, NHWGK + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp - #dl + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp + #dl device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp new file mode 100644 index 0000000000..b4de825fb6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_common.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using NHWGC = ck::tensor_layout::convolution::NHWGC; +using GNHWC = ck::tensor_layout::convolution::GNHWC; + +using GKYXC = ck::tensor_layout::convolution::GKYXC; + +using NHWGK = ck::tensor_layout::convolution::NHWGK; +using GNHWK = ck::tensor_layout::convolution::GNHWK; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +static constexpr auto ConvFwdOddC = + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index fc18b3c735..f7e575df2b 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,100 +1,54 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "device_grouped_conv2d_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using InDataType = ck::half_t; -using WeiDataType = ck::half_t; -using AccDataType = float; -using OutDataType = ck::half_t; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instances{}); + device_grouped_conv2d_fwd_dl_f16_instances{}); - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_Filter1x1Stride1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 648b39637e..85300b4e44 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,104 +1,54 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "device_grouped_conv2d_fwd_dl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { -using InDataType = float; -using WeiDataType = float; -using AccDataType = float; -using OutDataType = float; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances = std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances) + F32, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_instances{}); + device_grouped_conv2d_fwd_dl_f32_instances{}); - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f32_Filter1x1Stride1Pad0_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_dl_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp deleted file mode 100644 index 1cb5d06999..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instance.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using InDataType = int8_t; -using WeiDataType = int8_t; -using AccDataType = int32_t; -using OutDataType = int8_t; - -using Empty_Tuple = ck::Tuple<>; -template -using S = ck::Sequence; - -using InElementOp = ck::tensor_operation::element_wise::PassThrough; -using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; -using OutElementOp = ck::tensor_operation::element_wise::PassThrough; - -using InLayout = ck::tensor_layout::convolution::GNHWC; -using WeiLayout = ck::tensor_layout::convolution::GKYXC; -using OutLayout = ck::tensor_layout::convolution::GNHWK; - -static constexpr auto ConvSpec = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -static constexpr auto Filter1x1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; -static constexpr auto Filter1x1Stride1Pad0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmPadingSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, ConvSpec, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances = std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -using device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances = - std::tuple< - // clang-format off - // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| - // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| - // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | - // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, InDataType, WeiDataType, Empty_Tuple, OutDataType, AccDataType, InLayout, WeiLayout, Empty_Tuple, OutLayout, InElementOp, WeiElementOp, OutElementOp, Filter1x1Stride1Pad0, GemmPadingSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, 4> - // clang-format on - >; - -void add_device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_instances{}); - - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Pad0_instances{}); - - add_device_operation_instances( - instances, - device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_int8_Filter1x1Stride1Pad0_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp new file mode 100644 index 0000000000..bcda22006b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_dl_instance.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp" +#include "device_grouped_conv2d_fwd_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_grouped_conv2d_fwd_dl_f16_instances = std::tuple< + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F16, F16, DsDatatype, F16, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_dl_f32_instances = std::tuple< + // clang-format off + // clang-format off + // ########################################| NDim| InData| WeiData| MultpleD| OutData| AccData| InLayout| WeiLayout| MultipleD| OutLayout| In| Wei| Out| Convolution| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer| + // ########################################| Spatial| Type| Type| Type| Type| Type| | | Layout| | Elementwise| Elementwise| Elementwise| Forward| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector| + // ########################################| | | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | | + // ########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< 2, F32, F32, DsDatatype, F32, F32, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp index 29f3310314..40593a0efb 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances( std::vector>>& instances) { - add_device_operation_instances( - instances, device_grouped_conv1d_fwd_xdl_gnhwc_gkyxc_gnhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp index 6a4a3d2a45..7088028bf5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi ,wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f16_instances{}); + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp index 1fec35fd9f..919274c503 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp @@ -1,109 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_f32_instances{}); + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp deleted file mode 100644 index 59b0121340..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instance.cpp +++ /dev/null @@ -1,125 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using GNHWC = ck::tensor_layout::convolution::GNHWC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -// Compilation parameters for in[g, n, hi, wi, c] * wei[g, k, y, x, c] = out[g, n, ho, wo, k] -using device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances = std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, GNHWC, GKYXC, Empty_Tuple, GNHWK, int8_t, int8_t, int32_t, int8_t, Empty_Tuple, int8_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, device_grouped_conv2d_fwd_xdl_gnhwc_gkyxc_gnhwk_int8_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp new file mode 100644 index 0000000000..2858671ee9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_instance.hpp @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" +#include "device_grouped_conv2d_fwd_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_grouped_conv2d_fwd_xdl_f16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDatatype, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_xdl_bf16_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDatatype, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +template +using device_grouped_conv2d_fwd_xdl_f32_instances = + std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 256, 16, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 64, 128, 16, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 64, 16, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 128, 64, 16, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 256, 64, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 128, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 128, 32, 128, 16, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 64, 32, 16, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsDatatype, F32, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 1, 64, 32, 64, 16, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 0000000000..25caf61df1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_bf16_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index 8aca730433..b997cfb672 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -1,137 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" namespace ck { namespace tensor_operation { namespace device { namespace instance { - -using F16 = ck::half_t; -using F32 = float; - -using Empty_Tuple = ck::Tuple<>; - -template -using S = ck::Sequence; - -using NHWGC = ck::tensor_layout::convolution::NHWGC; -using GKYXC = ck::tensor_layout::convolution::GKYXC; -using NHWGK = ck::tensor_layout::convolution::NHWGK; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -static constexpr auto ConvFwdDefault = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; - -static constexpr auto ConvFwd1x1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; - -static constexpr auto ConvFwd1x1S1P0 = - ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; - -static constexpr auto ConvFwdOddC = - ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; - -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - // Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances = - std::tuple< - // clang-format off - // Default - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdDefault, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // Filter1x1Stride1Pad0 - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - - // OddC - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 256, 256, 64, 32, 8, 8, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle< 2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvFwdOddC, GemmMNKPadding, 1, 128, 64, 64, 32, 8, 8, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances{}); + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f16_instances{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp new file mode 100644 index 0000000000..3256a2a826 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "device_grouped_conv2d_fwd_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); + + add_device_operation_instances(instances, + device_grouped_conv2d_fwd_xdl_f32_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp index b231f8c956..672cdba65d 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/conv2d_quantization_common.hpp @@ -19,9 +19,9 @@ using Empty_Tuple = ck::Tuple<>; template using S = ck::Sequence; -using GNHWC = ck::tensor_layout::convolution::GNHWC; +using NHWGC = ck::tensor_layout::convolution::NHWGC; using GKYXC = ck::tensor_layout::convolution::GKYXC; -using GNHWK = ck::tensor_layout::convolution::GNHWK; +using NHWGK = ck::tensor_layout::convolution::NHWGK; using GK = ck::tensor_layout::convolution::G_K; using PassThrough = ck::tensor_operation::element_wise::PassThrough; using Relu = ck::tensor_operation::element_wise::Relu; diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp index ae5c1d7c32..d4b5484d8b 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_bias_perchannel_quantization_int8_instance.cpp @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_dl_bias_perchannel_quantization_int8_instances( std::vector{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector> + DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK< NDimSpatial, int8_t, int8_t, DsDatatype, int8_t, int32_t, InLayout, WeiLayout, DsLayout, OutLayout, PassThrough, PassThrough, OutElementOp, ConvSpec, GemmSpec, 256, 128, 128, 16, 4, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<8, 1, 1, 4>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<0, 1, 2, 3, 4, 5>, 5, DstScalarPerVector> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp index d45c1c1ee7..c8f5f7042c 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_dl_perchannel_quantization_int8_instance.cpp @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_dl_perchannel_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_dl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances using device_grouped_conv2d_xdl_int8_instances = std::tuple < - //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, - DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector> + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 64, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector>, + DeviceGroupedConvFwdMultipleD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, DstScalarPerVector> >; // clang-format on diff --git a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp index b0f86b5c2e..9d69377085 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/quantization/conv2d_fwd/device_conv2d_xdl_perchannel_quantization_int8_instance.cpp @@ -9,10 +9,10 @@ namespace device { namespace instance { void add_device_conv2d_xdl_perchannel_quantization_int8_instances( std::vector>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances>>& instances) { add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances{}); add_device_operation_instances(instances, - device_grouped_conv2d_xdl_int8_instances