From d1a0ccb5423b6d4dd93ad848219466566abeb077 Mon Sep 17 00:00:00 2001 From: Shaojie WANG Date: Wed, 25 May 2022 00:13:00 +0800 Subject: [PATCH] add GetWorkSpaceSize to base arg (#253) * add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight * remove redundant compute * use datatype and split k to check whether a workspace is used * remove unused computation for work space size [ROCm/composable_kernel commit: 0d08cf1893a3aa568249ce1c101556fde9c8f613] --- .../convnd_bwd_weight_xdl.cpp | 56 ++++++++++++++++--- .../gpu/device/device_base.hpp | 2 + ...rd_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 51 +++++++++++++++++ 3 files changed, 100 insertions(+), 9 deletions(-) diff --git a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp index 1f709808b1..0fc976c34a 100644 --- a/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp +++ b/example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp @@ -257,11 +257,11 @@ int main(int argc, char* argv[]) case 0: break; case 1: out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; default: out_n_k_ho_wo.GenerateTensorValue(GeneratorTensor_1{1}); - in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); + in_n_c_hi_wi.GenerateTensorValue(GeneratorTensor_1{1}); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); @@ -296,15 +296,53 @@ int main(int argc, char* argv[]) OutElementOp{}, split_k); - if(!conv->IsSupportedArgument(argument.get())) + // alloc work space + size_t bwd_weight_workspace_size = conv->GetWorkSpaceSize(argument.get()); + float ave_time = 0.f; + if(std::is_same::value && split_k > 1) { - std::cout << "wrong! device_conv with the specified compilation parameters does " "not support this Conv problem" << std::endl; return 1; } + DeviceMem wei_work_space_device_buf(bwd_weight_workspace_size); + wei_work_space_device_buf.SetZero(); + argument = conv->MakeArgumentPointer( + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_work_space_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer()), + params.N_, + params.K_, + params.C_, + params.input_spatial_lengths_, + params.filter_spatial_lengths_, + output_spatial_lengths, + params.conv_filter_strides_, + params.conv_filter_dilations_, + params.input_left_pads_, + params.input_right_pads_, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}, + split_k); - float ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + } + else + { + if(!conv->IsSupportedArgument(argument.get())) + { + std::cout << "wrong! device_conv with the specified compilation parameters does " + "not support this Conv problem" + << std::endl; + return 1; + } + ave_time = invoker->Run(argument.get(), StreamConfig{nullptr, time_kernel}); + } std::size_t flop = ck::utils::conv::get_flops( params.N_, params.C_, params.K_, params.filter_spatial_lengths_, output_spatial_lengths); diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 950cfc1d61..9bc3cb1a02 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -40,6 +40,8 @@ struct BaseOperator virtual bool IsSupportedArgument(const BaseArgument*) { return false; } virtual std::string GetTypeString() const { return ""; } + virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + virtual ~BaseOperator() {} }; diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp index 96a86b39db..dde9e0f873 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_ return str.str(); } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = + arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * sizeof(float); + } + } + return WorkSpaceSize; + } + + template ::type = false> + static size_t GetWorkSpaceSize(const Argument& arg) + { + size_t WorkSpaceSize = 0; + if(arg.k_batch_ > 1) + { + if constexpr(std::is_same::value) + { + WorkSpaceSize = arg.Conv_K_ * arg.Conv_C_ * arg.filter_spatial_lengths_[0] * + arg.filter_spatial_lengths_[1] * arg.filter_spatial_lengths_[2] * + sizeof(float); + } + } + return WorkSpaceSize; + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override final + { + return GetWorkSpaceSize(*dynamic_cast(p_arg)); + } }; } // namespace device