mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -35,7 +35,7 @@ enum ConvForwardAlgo
|
|||||||
int main(int argc, char* argv[])
|
int main(int argc, char* argv[])
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
using namespace ck_driver;
|
using namespace ck::driver;
|
||||||
using size_t = std::size_t;
|
using size_t = std::size_t;
|
||||||
|
|
||||||
hipStream_t stream;
|
hipStream_t stream;
|
||||||
|
|||||||
@@ -231,7 +231,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcy
|
|||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
using namespace ck_driver;
|
using namespace ck::driver;
|
||||||
using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
|
using namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw;
|
||||||
using size_t = std::size_t;
|
using size_t = std::size_t;
|
||||||
|
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kc
|
|||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
using namespace ck_driver;
|
using namespace ck::driver;
|
||||||
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
using namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw;
|
||||||
using size_t = std::size_t;
|
using size_t = std::size_t;
|
||||||
|
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy
|
|||||||
const Tensor<TInWei>& in_n_c_hi_wi,
|
const Tensor<TInWei>& in_n_c_hi_wi,
|
||||||
const Tensor<TInWei>& wei_k_c_y_x,
|
const Tensor<TInWei>& wei_k_c_y_x,
|
||||||
Tensor<TOut>& out_n_k_ho_wo,
|
Tensor<TOut>& out_n_k_ho_wo,
|
||||||
const ck_driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
|
const ck::driver::CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw& compile_param,
|
||||||
ck::index_t nrepeat)
|
ck::index_t nrepeat)
|
||||||
{
|
{
|
||||||
using namespace ck;
|
using namespace ck;
|
||||||
using namespace ck_driver;
|
using namespace ck::driver;
|
||||||
using size_t = std::size_t;
|
using size_t = std::size_t;
|
||||||
|
|
||||||
std::cout << __func__ << std::endl;
|
std::cout << __func__ << std::endl;
|
||||||
@@ -100,8 +100,9 @@ void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcy
|
|||||||
"dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp";
|
"dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp";
|
||||||
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw";
|
std::string algo_name = "implicit_gemm_conv_fwd_v6r1_dlops_nchw";
|
||||||
|
|
||||||
std::string compile_param_string = get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
|
std::string compile_param_string =
|
||||||
std::string network_config = compile_param_string;
|
get_ck_hip_online_compile_common_flag() + compile_param.GetCompileParameterString();
|
||||||
|
std::string network_config = compile_param_string;
|
||||||
|
|
||||||
std::vector<float> kernel1_times;
|
std::vector<float> kernel1_times;
|
||||||
std::vector<float> kernel2_times;
|
std::vector<float> kernel2_times;
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
#ifndef ONLINE_DRIVER_COMMON_HPP
|
#ifndef ONLINE_DRIVER_COMMON_HPP
|
||||||
#define ONLINE_DRIVER_COMMON_HPP
|
#define ONLINE_DRIVER_COMMON_HPP
|
||||||
|
|
||||||
namespace ck_driver {
|
namespace ck {
|
||||||
|
namespace driver {
|
||||||
|
|
||||||
inline auto get_ck_hip_online_compile_common_flag()
|
inline auto get_ck_hip_online_compile_common_flag()
|
||||||
{
|
{
|
||||||
@@ -47,5 +48,6 @@ auto gcd(X x, Ys... ys)
|
|||||||
return gcd(x, gcd(ys...));
|
return gcd(x, gcd(ys...));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace ck_driver
|
} // namespace driver
|
||||||
|
} // namespace ck
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -3,7 +3,8 @@
|
|||||||
|
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
|
|
||||||
namespace ck_driver {
|
namespace ck {
|
||||||
|
namespace driver {
|
||||||
|
|
||||||
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
struct CompileParameterConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
||||||
{
|
{
|
||||||
@@ -669,5 +670,6 @@ struct ConvIgemmFwdV6r1DlopsNchwKcyxNkhw
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ck_driver
|
} // namespace driver
|
||||||
|
} // namespace ck
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
|
#ifndef CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||||
#define CONVOLUTION_PROBLEM_DESCRIPTOR
|
#define CONVOLUTION_PROBLEM_DESCRIPTOR
|
||||||
|
|
||||||
namespace ck_driver {
|
namespace ck {
|
||||||
|
namespace driver {
|
||||||
|
|
||||||
struct ConvolutionProblemDescriptor
|
struct ConvolutionProblemDescriptor
|
||||||
{
|
{
|
||||||
@@ -75,5 +76,6 @@ struct ConvolutionProblemDescriptor
|
|||||||
std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
|
std::size_t CalculateFlop() const { return 2L * N * K * C * Y * X * Ho * Wo; }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace ck_driver
|
} // namespace driver
|
||||||
|
} // namespace ck
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user