diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp index 2f048097a1..d951bc4e4b 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp @@ -291,8 +291,8 @@ int main(int argc, char* argv[]) float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString() - << std::endl; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << conv->GetTypeString() << std::endl; if(do_verification) { diff --git a/include/ck/config.hpp b/include/ck/config.hpp index 293e27ad97..a4d2ef7c55 100644 --- a/include/ck/config.hpp +++ b/include/ck/config.hpp @@ -163,10 +163,6 @@ // tuning parameter #define CK_WORKAROUND_SWDEV_325164 1 -// workaround for verification failure ConvNd forward -// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135 -#define CK_WORKAROUND_GITHUB_135 1 - namespace ck { enum struct InMemoryDataOperationEnum diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 7cd6cc34c9..368da4d207 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -1,5 +1,4 @@ -#ifndef CHECK_ERR_HPP -#define CHECK_ERR_HPP +#pragma once #include #include @@ -169,17 +168,34 @@ check_err(const std::vector& out, return false; } + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); for(std::size_t i = 0; i < ref.size(); ++i) { - if(out[i] != ref[i]) + int64_t o = out[i]; + int64_t r = ref[i]; + err = std::abs(o - r); + + if(err > 0) { - std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) - << " != " << static_cast(ref[i]) << std::endl - << msg << std::endl; - return false; + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast(out[i]) + << " != " << static_cast(ref[i]) << std::endl + << msg << std::endl; + } + res = false; } } - return true; + if(!res) + { + std::cout << "max err: " << max_err << std::endl; + } + return res; } } // namespace utils @@ -191,5 +207,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector& v) std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); return os; } - -#endif diff --git a/library/include/ck/library/utility/conv_util.hpp b/library/include/ck/library/utility/conv_util.hpp index c881b89705..409fa5aff2 100644 --- a/library/include/ck/library/utility/conv_util.hpp +++ b/library/include/ck/library/utility/conv_util.hpp @@ -402,8 +402,8 @@ template , - typename WeightsInitFun = FillUniform> + typename InputInitFun = FillUniformDistribution, + typename WeightsInitFun = FillUniformDistribution> class ConvFwdOpInstance : public ck::utils::OpInstance { using DeviceConvFwdOp = tensor_operation::device:: @@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance output_spatial_lengths_; const bool do_init_; - const InputInitFun& input_init_f_; - const WeightsInitFun& weights_init_f_; + InputInitFun input_init_f_; + WeightsInitFun weights_init_f_; }; } // namespace conv diff --git a/library/include/ck/library/utility/fill.hpp b/library/include/ck/library/utility/fill.hpp index f44aec969d..8c31e56beb 100644 --- a/library/include/ck/library/utility/fill.hpp +++ b/library/include/ck/library/utility/fill.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include "data_type.hpp" @@ -8,46 +9,56 @@ namespace ck { namespace utils { -// template -// struct FillUniform; - -// TODO: what's wrong with this specialization??? -// err: segmentation fault in mt19937 - infinite loop like. -// template -// struct FillUniform::value && -// !std::is_same::value>::type> -// { -// int a_{0}; -// int b_{5}; -// // T a_ = T{0}; -// // T b_ = T{5}; - -// template -// void operator()(ForwardIter first, ForwardIter last) const -// { -// std::mt19937 gen{11939}; -// std::uniform_int_distribution dis(a_, b_); -// std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); -// } -// }; - -// struct FillUniform::value || -// std::is_same::value>::type> template -struct FillUniform +struct FillUniformDistribution { - float a_{0}; - float b_{5}; + float a_{-5.f}; + float b_{5.f}; template void operator()(ForwardIter first, ForwardIter last) const { - std::mt19937 gen{11939}; - std::uniform_real_distribution<> dis(a_, b_); + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); } }; +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However this produces segfaults in std::mt19937 which look like inifite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; + +// Workaround for uniform_int_distribution not working as expected. See note above.< +template +struct FillUniformDistributionIntegerValue +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } +}; + template struct FillMonotonicSeq { diff --git a/library/include/ck/library/utility/op_instance_engine.hpp b/library/include/ck/library/utility/op_instance_engine.hpp index 5429f66d3e..1d11b62a4a 100644 --- a/library/include/ck/library/utility/op_instance_engine.hpp +++ b/library/include/ck/library/utility/op_instance_engine.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -78,7 +79,8 @@ class OpInstanceRunEngine template > OpInstanceRunEngine(const OpInstanceT& op_instance, - const ReferenceOp& reference_op = ReferenceOp{}) + const ReferenceOp& reference_op = ReferenceOp{}, + bool do_verification = true) : op_instance_{op_instance} { in_tensors_ = op_instance_.GetInputTensors(); @@ -88,8 +90,11 @@ class OpInstanceRunEngine const Tensor&..., Tensor&>) { - ref_output_ = op_instance_.GetOutputTensor(); - CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + if(do_verification) + { + ref_output_ = op_instance_.GetOutputTensor(); + CallRefOpUnpackArgs(reference_op, std::make_index_sequence{}); + } } AllocateDeviceInputTensors(std::make_index_sequence{}); out_device_buffer_ = @@ -110,6 +115,7 @@ class OpInstanceRunEngine op_ptr.get(), in_device_buffers_, out_device_buffer_); if(op_ptr->IsSupportedArgument(argument.get())) { + std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl; invoker->Run(argument.get()); out_device_buffer_->FromDevice(out_tensor_->mData.data()); if(!ref_output_) @@ -119,9 +125,16 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && check_err(out_tensor_->mData, ref_output_->mData); + bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData); + std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl; + res = res && inst_res; out_device_buffer_->SetZero(); } + else + { + std::cout << "Given conv problem is not supported by instance: \n\t>>>>" + << op_ptr->GetTypeString() << std::endl; + } } return res; } @@ -132,7 +145,6 @@ class OpInstanceRunEngine bool do_verification = false, bool do_log = false) { - bool res{true}; ProfileBestConfig best_config; for(auto& op_ptr : op_ptrs) @@ -153,7 +165,7 @@ class OpInstanceRunEngine std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl; - if(tflops < best_config.best_tflops) + if(avg_time < best_config.best_avg_time) { best_config.best_op_name = op_name; best_config.best_tflops = tflops; @@ -171,7 +183,7 @@ class OpInstanceRunEngine " You have to provide reference function."); } // TODO: enable flexible use of custom check_error functions - res = res && CheckErr(out_tensor_->mData, ref_output_->mData); + CheckErr(out_tensor_->mData, ref_output_->mData); if(do_log) {} } @@ -223,7 +235,7 @@ class OpInstanceRunEngine template bool CheckErr(const std::vector& dev_out, const std::vector& ref_out) const { - return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_); + return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_); } }; diff --git a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp index 9288e40e56..a133300f73 100644 --- a/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv1d_fwd/device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instance.cpp @@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt index 857e36d6f5..1ef4a9b07e 100644 --- a/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/CMakeLists.txt @@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; ) +set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp; + device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp; +) + add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE}) + set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) clang_tidy_check(device_conv2d_fwd_instance) +clang_tidy_check(device_convnd_2d_fwd_instance) diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp new file mode 100644 index 0000000000..de98151ef8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp @@ -0,0 +1,113 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp new file mode 100644 index 0000000000..4b4a0fc25a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F16 = ck::half_t; +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp new file mode 100644 index 0000000000..5603fc5d06 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp @@ -0,0 +1,111 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp new file mode 100644 index 0000000000..b4447bcb82 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/conv2d_fwd/device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp @@ -0,0 +1,112 @@ +#include +#include "config.hpp" +#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_conv2d_fwd_instance { + +using F32 = float; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvFwdDefault = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto ConvFwd1x1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + +static constexpr auto ConvFwd1x1S1P0 = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + +// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = + std::tuple< + // clang-format off + //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| + //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| + //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| + //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>, + DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1> + // clang-format on + >; + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{}); + add_device_operation_instances(instances, + device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{}); +} + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp index 745d26904a..bff51affd1 100644 --- a/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/conv3d_fwd/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instance.cpp @@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 = // Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k] using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple< -// clang-format off + // clang-format off //################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| //################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | -#if !CK_WORKAROUND_GITHUB_135 - // FIXME: this instance causes numerical errors. DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, -#endif DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>, diff --git a/profiler/src/profile_convnd_fwd.cpp b/profiler/src/profile_convnd_fwd.cpp index 87778a04a5..cb92587897 100644 --- a/profiler/src/profile_convnd_fwd.cpp +++ b/profiler/src/profile_convnd_fwd.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -150,9 +151,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::utils::FillUniform, - ck::utils::FillUniform>>( - params, true, ck::utils::FillUniform{}, ck::utils::FillUniform{}); + ck::utils::FillUniformDistributionIntegerValue, + ck::utils::FillUniformDistributionIntegerValue>>( + params, + true, + ck::utils::FillUniformDistributionIntegerValue{}, + ck::utils::FillUniformDistributionIntegerValue{}); break; case 2: conv_instance = std::make_unique< @@ -165,12 +169,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough, - ck::utils::FillUniform, - ck::utils::FillUniform>>( + ck::utils::FillUniformDistribution, + ck::utils::FillUniformDistribution>>( params, true, - ck::utils::FillUniform{}, - ck::utils::FillUniform{}); + ck::utils::FillUniformDistribution{}, + ck::utils::FillUniformDistribution{}); break; default: throw std::runtime_error("Unsupported init method!"); } @@ -181,8 +185,10 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params, _1, _2, _3); - OpInstanceRunEngine run_engine(*conv_instance, - reference_conv_fwd_fun); + + OpInstanceRunEngine run_engine( + *conv_instance, reference_conv_fwd_fun, do_verification); + auto best_conf = run_engine.Profile( conv::ConvolutionFwdInstances::template Get(), time_kernel, diff --git a/script/profile_conv.sh b/script/profile_conv.sh index 42736dd37f..c3ba39c926 100755 --- a/script/profile_conv.sh +++ b/script/profile_conv.sh @@ -47,7 +47,7 @@ REPEAT=$9 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 -#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 # Resnet50 fusion diff --git a/test/convnd_fwd/CMakeLists.txt b/test/convnd_fwd/CMakeLists.txt index 34e698681b..444ec6c8aa 100644 --- a/test/convnd_fwd/CMakeLists.txt +++ b/test/convnd_fwd/CMakeLists.txt @@ -5,7 +5,7 @@ target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_inst add_dependencies(test_convnd_fwd test_conv1d_fwd) add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp) -target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_util) +target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance device_convnd_2d_fwd_instance conv_util) add_dependencies(test_convnd_fwd test_conv2d_fwd) add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp) diff --git a/test/convnd_fwd/conv1d_fwd.cpp b/test/convnd_fwd/conv1d_fwd.cpp index b6b6a89b2c..9b4708e94b 100644 --- a/test/convnd_fwd/conv1d_fwd.cpp +++ b/test/convnd_fwd/conv1d_fwd.cpp @@ -1,5 +1,4 @@ #include -#include #include #include #include "gtest/gtest.h" @@ -11,83 +10,180 @@ namespace { -template -bool test_conv1d_nwc_instances(const std::vector& conv_ptrs) +class Conv1dFwdNWCInstances : public ::testing::Test +{ + public: + template + bool test_conv1d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv1d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<1>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv1DFwdNWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = float; - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{71}; - params.conv_filter_strides_ = std::vector{2}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{1}; - params.input_right_pads_ = std::vector{1}; + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv1DFwdNWC, TestConv1D) -{ - using namespace std::placeholders; - using namespace ck::utils; - namespace ctl = ck::tensor_layout::convolution; - - ck::utils::conv::ConvParams params; - params.num_dim_spatial_ = 1; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.filter_spatial_lengths_ = std::vector{3}; - params.input_spatial_lengths_ = std::vector{16}; - params.conv_filter_strides_ = std::vector{1}; - params.conv_filter_dilations_ = std::vector{1}; - params.input_left_pads_ = std::vector{1}; - params.input_right_pads_ = std::vector{1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<1, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(1e-5); run_engine.SetRtol(1e-4); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv1DFwdNWC, Bf16Iinstances) +TEST(Conv1DFwdNWC, FloatingPointValues) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(0.1); + run_engine.SetRtol(1e-2); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv1DFwdNWC, F16Instances) +TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv1DFwdNWC, F32Instances) +TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv1DFwdNWC, Int8Instances) +TEST_F(Conv1dFwdNWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv1d_nwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<1>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} + +TEST_F(Conv1dFwdNWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } diff --git a/test/convnd_fwd/conv2d_fwd.cpp b/test/convnd_fwd/conv2d_fwd.cpp index 05e46147be..4e0238cc4f 100644 --- a/test/convnd_fwd/conv2d_fwd.cpp +++ b/test/convnd_fwd/conv2d_fwd.cpp @@ -1,91 +1,265 @@ -#include -#include #include #include #include "gtest/gtest.h" +#include "ck/library/utility/conv_util.hpp" +#include "config.hpp" +#include "conv_util.hpp" #include "data_type.hpp" #include "element_wise_operation.hpp" -#include "ck/library/utility/conv_util.hpp" -#include "conv_util.hpp" +#include "fill.hpp" namespace { -template -bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs) +class Conv2dFwdNHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv2d_nhwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_default_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_default_); + } + } + + template + bool test_filter1x1_stride1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), + params_filter1x1_stride1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_stride1_pad0_); + } + } + + template + bool test_filter1x1_pad0(bool use_convnd = false) + { + if(use_convnd) + { + return test_conv2d_nhwc_instances( + test::conv::ConvolutionNDFwdInstances::Get(2), params_filter1x1_pad0_); + } + else + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), + params_filter1x1_pad0_); + } + } + + template + bool test_oddC() + { + return test_conv2d_nhwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<2>(), params_oddC_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}}; + static inline ck::utils::conv::ConvParams params_oddC_{ + 2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv2DFwdNHWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; + using T = float; - conv::ConvParams params; - params.num_dim_spatial_ = 2; - params.filter_spatial_lengths_ = std::vector{3, 3}; - params.input_spatial_lengths_ = std::vector{71, 71}; - params.conv_filter_strides_ = std::vector{2, 2}; - params.conv_filter_dilations_ = std::vector{1, 1}; - params.input_left_pads_ = std::vector{1, 1}; - params.input_right_pads_ = std::vector{1, 1}; + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); -} - -} // anonymous namespace - -TEST(Conv2DFwdNHWC, TestConv2D) -{ - using namespace std::placeholders; - using namespace ck::utils; - - ck::utils::conv::ConvParams params; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.input_spatial_lengths_ = std::vector{16, 16}; - params.conv_filter_strides_ = std::vector{1, 1}; - - std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance(params); - - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); run_engine.SetAtol(1e-5); run_engine.SetRtol(1e-4); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv2DFwdNHWC, Bf16Instances) +TEST(Conv2DFwdNHWC, FloatingPointValues) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + using namespace std::placeholders; + using namespace ck::utils; + using T = ck::half_t; + + ck::utils::conv::ConvParams params{ + 2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}}; + + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); + + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(2e-4); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -TEST(Conv2DFwdNHWC, F16Instances) +TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F16_oddC) { EXPECT_TRUE(this->test_oddC()); } +TEST_F(Conv2dFwdNHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv2DFwdNHWC, BF32Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_default(true)); } - -TEST(Conv2DFwdNHWC, F32Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); } - -TEST(Conv2DFwdNHWC, Int8Instances) +TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0) { - EXPECT_TRUE(test_conv2d_nhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<2>())); + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_default) +{ + EXPECT_TRUE(this->test_default(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_default) { EXPECT_TRUE(this->test_default(true)); } +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_stride1_pad0(true)); +} +TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0(true)); } diff --git a/test/convnd_fwd/conv3d_fwd.cpp b/test/convnd_fwd/conv3d_fwd.cpp index c6f0e7ec07..2470727fd7 100644 --- a/test/convnd_fwd/conv3d_fwd.cpp +++ b/test/convnd_fwd/conv3d_fwd.cpp @@ -12,61 +12,143 @@ namespace { -template -bool test_conv3d_ndhwc_instances(const std::vector& conv_ptrs) +class Conv3dFwdNDHWCInstances : public ::testing::Test +{ + public: + template + bool test_conv3d_nwc_instances(const std::vector& conv_ptrs, + const ck::utils::conv::ConvParams& params) + { + using namespace std::placeholders; + using namespace ck::utils; + namespace ctl = ck::tensor_layout::convolution; + + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(atol_); + run_engine.SetRtol(rtol_); + return run_engine.Test(conv_ptrs); + } + + template + bool test_default() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), params_default_); + } + + template + bool test_filter1x1_stride1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_stride1_pad0_); + } + + template + bool test_filter1x1_pad0() + { + return test_conv3d_nwc_instances( + ck::utils::conv::ConvolutionFwdInstances::template Get<3>(), + params_filter1x1_pad0_); + } + + static inline ck::utils::conv::ConvParams params_default_{ + 3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; + static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{ + 3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}}; + + private: + double atol_{1e-5}; + double rtol_{1e-4}; +}; + +} // anonymous namespace + +TEST(Conv3DFwdNDHWC, IntegerValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = float; - conv::ConvParams params; - params.N_ = 64; - params.num_dim_spatial_ = 3; - params.filter_spatial_lengths_ = std::vector{3, 3, 2}; - params.input_spatial_lengths_ = std::vector{32, 32, 2}; - params.conv_filter_strides_ = std::vector{2, 2, 2}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - params.input_right_pads_ = std::vector{1, 1, 1}; + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; - conv::ConvFwdOpInstance conv_instance(params); + std::vector conv_ptrs; + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistributionIntegerValue> + conv_instance(params, + true, + FillUniformDistributionIntegerValue{}, + FillUniformDistributionIntegerValue{}); auto reference_conv_fwd_fun = std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - return run_engine.Test(conv_ptrs); + run_engine.SetAtol(1e-5); + run_engine.SetRtol(1e-3); + EXPECT_TRUE(run_engine.Test(conv_ptrs)); } -} // anonymous namespace - -TEST(Conv3DFwdNDHWC, TestConv3D) +TEST(Conv3DFwdNDHWC, FloatingPointValues) { using namespace std::placeholders; using namespace ck::utils; namespace ctl = ck::tensor_layout::convolution; + using T = ck::half_t; - conv::ConvParams params; - params.num_dim_spatial_ = 3; - params.N_ = 2; - params.K_ = 16; - params.C_ = 4; - params.filter_spatial_lengths_ = std::vector{3, 3, 3}; - params.input_spatial_lengths_ = std::vector{16, 16, 16}; - params.conv_filter_strides_ = std::vector{1, 1, 1}; - params.conv_filter_dilations_ = std::vector{1, 1, 1}; - params.input_left_pads_ = std::vector{1, 1, 1}; - params.input_right_pads_ = std::vector{1, 1, 1}; + ck::utils::conv::ConvParams params{ + 3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - conv::ConvFwdOpInstance conv_instance( - params); + test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs); + conv::ConvFwdOpInstance, + FillUniformDistribution> + conv_instance(params, true, FillUniformDistribution{}, FillUniformDistribution{}); - auto reference_conv_fwd_fun = std::bind( - conv::run_reference_convolution_forward<3, float, float, float>, params, _1, _2, _3); - OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); - run_engine.SetAtol(1e-5); - run_engine.SetRtol(1e-4); + auto reference_conv_fwd_fun = + std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3); + OpInstanceRunEngine run_engine(conv_instance, reference_conv_fwd_fun); + run_engine.SetAtol(1e-3); + run_engine.SetRtol(1e-3); EXPECT_TRUE(run_engine.Test(conv_ptrs)); } @@ -74,6 +156,7 @@ TEST(Conv3DFwdNDHWC, InputOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Input conv::ConvParams params; @@ -89,8 +172,7 @@ TEST(Conv3DFwdNDHWC, InputOver2GB) params.input_right_pads_ = std::vector{1, 1, 1}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, @@ -114,6 +196,7 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Filters conv::ConvParams params; @@ -129,8 +212,7 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB) params.input_right_pads_ = std::vector{1, 1, 1}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); - + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, @@ -154,6 +236,7 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB) { using PassThrough = ck::tensor_operation::element_wise::PassThrough; using namespace ck::utils; + using T = float; // >2GB Output conv::ConvParams params; @@ -169,7 +252,7 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB) params.input_right_pads_ = std::vector{2, 2, 2}; std::vector conv_ptrs; - test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs); + test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs); auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr, nullptr, nullptr, @@ -189,26 +272,42 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB) EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get())); } -TEST(Conv3DFwdNDHWC, Bf16Instances) +TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, F16Instances) +TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, F32Instances) +TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } -TEST(Conv3DFwdNDHWC, Int8Instances) +TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default()); } +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0) { - EXPECT_TRUE(test_conv3d_ndhwc_instances( - ck::utils::conv::ConvolutionFwdInstances::Get<3>())); + EXPECT_TRUE(this->test_filter1x1_stride1_pad0()); +} +TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0) +{ + EXPECT_TRUE(this->test_filter1x1_pad0()); } diff --git a/test/convnd_fwd/conv_util.hpp b/test/convnd_fwd/conv_util.hpp index 09f641b415..1ec83bd118 100644 --- a/test/convnd_fwd/conv_util.hpp +++ b/test/convnd_fwd/conv_util.hpp @@ -1,14 +1,33 @@ -#ifndef TEST_CONV_UTIL_HPP -#define TEST_CONV_UTIL_HPP +#pragma once #include #include "config.hpp" +#include "data_type.hpp" #include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp" #include "element_wise_operation.hpp" #include "host_tensor.hpp" #include "sequence.hpp" +namespace ck { +namespace tensor_operation { +namespace device { + +using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr; +namespace device_conv2d_fwd_instance { + +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector&); +void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector&); + +} // namespace device_conv2d_fwd_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + namespace test { namespace conv { @@ -25,57 +44,128 @@ using DeviceConvFwdNoOpPtr = static constexpr auto ConvFwdDefault = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; -template +template using DeviceConvNDFwdInstance = ck::tensor_operation::device:: DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< // clang-format off InDataType, // WeiDataType, // OutDataType, // - InDataType, // + AccDataType, // Accumulator data type. InElementOp, // Input Elementwise Operation WeiElementOp, // Weights Elementwise Operation OutElementOp, // Output Elementwise Operation ConvFwdDefault, // ConvForwardSpecialization SpatialDims, // SptialDims - 64, // BlockSize - 16, // MPerBlock - 16, // NPerBlock + 256, // BlockSize + 128, // MPerBlock + 256, // NPerBlock 4, // K0PerBlock - 1, // K1 - 16, // MPerXDL - 16, // NPerXDL - 1, // MXdlPerWave - 1, // NXdlPerWave - S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + 8, // K1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 4, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector - 1, // ABlockTransferDstScalarPerVector_K1 + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 true, // ABlockLdsAddExtraM - S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 1, // BBlockTransferDstScalarPerVector_K1 - true, // BBlockTransferAddExtraN + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN 7, // CThreadTransferSrcDstVectorDim - 1>; // CThreadTransferDstScalarPerVector + 1>; // CThreadTransferDstScalarPerVector // clang-format on template + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType> void get_test_convolution_fwd_instance(std::vector& instances) { - using ConvInstanceT = DeviceConvNDFwdInstance; + using ConvInstanceT = + DeviceConvNDFwdInstance; instances.emplace_back(std::make_unique()); } +// TODO (aosewski) +// Temporary solution to get all DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K +// instances. When switched over to DeviceConvNDFwdXdl for 2D remove ConvolutionNDFwdInstances +// structures. +template +struct ConvolutionNDFwdInstances; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + +template <> +struct ConvolutionNDFwdInstances +{ + static std::vector Get(std::size_t num_dim_spatial) + { + std::vector conv_ptrs; + if(num_dim_spatial == 2) + { + ck::tensor_operation::device::device_conv2d_fwd_instance:: + add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); + } + return conv_ptrs; + } +}; + } // namespace conv } // namespace test - -#endif