mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -35,7 +35,7 @@ enum ConvForwardAlgo
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace ck_driver;
|
||||
using namespace ck::driver;
|
||||
using size_t = std::size_t;
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
@@ -231,7 +231,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace ck_driver;
|
||||
using namespace ck::driver;
|
||||
using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
|
||||
using size_t = std::size_t;
|
||||
|
||||
|
||||
@@ -227,7 +227,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace ck_driver;
|
||||
using namespace ck::driver;
|
||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
||||
using size_t = std::size_t;
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy
|
||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||
const Tensor<TInWei>& wei_k_c_y_x,
|
||||
Tensor<TOut>& out_n_k_ho_wo,
|
||||
const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
|
||||
const ck::driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
using namespace ck_driver;
|
||||
using namespace ck::driver;
|
||||
using size_t = std::size_t;
|
||||
|
||||
std::cout << __func__ << std::endl;
|
||||
@@ -100,8 +100,9 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy
|
||||
"dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp";
|
||||
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw";
|
||||
|
||||
std::string compile_param_string = get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
|
||||
std::string network_config = compile_param_string;
|
||||
std::string compile_param_string =
|
||||
get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
|
||||
std::string network_config = compile_param_string;
|
||||
|
||||
std::vector<float> kernel1_times;
|
||||
std::vector<float> kernel2_times;
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef ONLINE_DRIVER_COMMON_HPP
|
||||
#define ONLINE_DRIVER_COMMON_HPP
|
||||
|
||||
namespace ck_driver {
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
inline auto get_ck_hip_online_compile_common_flag()
|
||||
{
|
||||
@@ -47,5 +48,6 @@ auto gcd(X x, Ys... ys)
|
||||
return gcd(x, gcd(ys...));
|
||||
}
|
||||
|
||||
} // namespace ck_driver
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
|
||||
#include <numeric>
|
||||
|
||||
namespace ck_driver {
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
{
|
||||
@@ -669,5 +670,6 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_driver
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
#define CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||
|
||||
namespace ck_driver {
|
||||
namespace ck {
|
||||
namespace driver {
|
||||
|
||||
struct ConvolutionProblemDescriptor
|
||||
{
|
||||
@@ -75,5 +76,6 @@ struct ConvolutionProblemDescriptor
|
||||
std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
|
||||
};
|
||||
|
||||
} // namespace ck_driver
|
||||
} // namespace driver
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user