From cb95421311dfc625edf5e0c59aa243aac1b00268 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 6 Aug 2021 22:17:51 +0000 Subject: [PATCH] refactor --- host/driver_online/conv_fwd_driver_online.cpp | 2 +- ...n_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp | 2 +- ..._forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 2 +- ...n_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp | 9 +++++---- host/driver_online/include/online_driver_common.hpp | 6 ++++-- .../include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp | 6 ++++-- host/solver/include/convolution_problem_descriptor.hpp | 6 ++++-- 7 files changed, 20 insertions(+), 13 deletions(-) diff --git a/host/driver_online/conv_fwd_driver_online.cpp b/host/driver_online/conv_fwd_driver_online.cpp index 29609d5474..53e6179aa6 100644 --- a/host/driver_online/conv_fwd_driver_online.cpp +++ b/host/driver_online/conv_fwd_driver_online.cpp @@ -35,7 +35,7 @@ enum ConvForwardAlgo int main(int argc, char* argv[]) { using namespace ck; - using namespace ck_driver; + using namespace ck::driver; using size_t = std::size_t; hipStream_t stream; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp index 06412fba0b..419b8ca95d 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw.hpp @@ -231,7 +231,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy ck::index_t nrepeat) { using namespace ck; - using namespace ck_driver; + using namespace ck::driver; using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; using size_t = std::size_t; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 61ce41fe84..46d065f615 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -227,7 +227,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc ck::index_t nrepeat) { using namespace ck; - using namespace ck_driver; + using namespace ck::driver; using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; using size_t = std::size_t; diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp index 92467a7668..7b88ef02b4 100644 --- a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -31,11 +31,11 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy const Tensor& in_n_c_hi_wi, const Tensor& wei_k_c_y_x, Tensor& out_n_k_ho_wo, - const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param, + const ck::driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param, ck::index_t nrepeat) { using namespace ck; - using namespace ck_driver; + using namespace ck::driver; using size_t = std::size_t; std::cout << __func__ << std::endl; @@ -100,8 +100,9 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy "dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp"; std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw"; - std::string compile_param_string = get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString(); - std::string network_config = compile_param_string; + std::string compile_param_string = + get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString(); + std::string network_config = compile_param_string; std::vector kernel1_times; std::vector kernel2_times; diff --git a/host/driver_online/include/online_driver_common.hpp b/host/driver_online/include/online_driver_common.hpp index d05a156d89..508a3594cd 100644 --- a/host/driver_online/include/online_driver_common.hpp +++ b/host/driver_online/include/online_driver_common.hpp @@ -1,7 +1,8 @@ #ifndef ONLINE_DRIVER_COMMON_HPP #define ONLINE_DRIVER_COMMON_HPP -namespace ck_driver { +namespace ck { +namespace driver { inline auto get_ck_hip_online_compile_common_flag() { @@ -47,5 +48,6 @@ auto gcd(X x, Ys... ys) return gcd(x, gcd(ys...)); } -} // namespace ck_driver +} // namespace driver +} // namespace ck #endif diff --git a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp index b0c4921019..a30c2720ee 100644 --- a/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/solver/include/conv_igemm_fwd_v6r1_dlops_nchw_kcyx_nkhw.hpp @@ -3,7 +3,8 @@ #include -namespace ck_driver { +namespace ck { +namespace driver { struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw { @@ -669,5 +670,6 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw } }; -} // namespace ck_driver +} // namespace driver +} // namespace ck #endif diff --git a/host/solver/include/convolution_problem_descriptor.hpp b/host/solver/include/convolution_problem_descriptor.hpp index df9c110e70..8c0ecbee80 100644 --- a/host/solver/include/convolution_problem_descriptor.hpp +++ b/host/solver/include/convolution_problem_descriptor.hpp @@ -1,7 +1,8 @@ #ifndef CONVOLUTION_PROBLEM_DESCRIPTOR #define CONVOLUTION_PROBLEM_DESCRIPTOR -namespace ck_driver { +namespace ck { +namespace driver { struct ConvolutionProblemDescriptor { @@ -75,5 +76,6 @@ struct ConvolutionProblemDescriptor std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; } }; -} // namespace ck_driver +} // namespace driver +} // namespace ck #endif