From f885c131d8af88a8235677c6e2c0453373570c2b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 9 Aug 2021 22:13:47 +0000 Subject: [PATCH] tidy --- .../include/utility/amd_address_space.hpp | 6 ++ composable_kernel/include/utility/print.hpp | 48 ---------------- ...licit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp | 3 - .../driver_dynamic_contraction_dlops_v1r2.hpp | 4 -- ...mplicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp | 4 -- ..._gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp | 4 -- .../driver_dynamic_gemm_dlops_v1r2.hpp | 48 ++++++++-------- .../driver_dynamic_gemm_dlops_v1r3.hpp | 56 +++++++++++-------- .../driver_dynamic_gemm_xdlops_v2r3.hpp | 28 +++++----- .../src/conv_fwd_driver_offline.cpp | 16 +++--- host/host_tensor/include/device.hpp | 20 +++---- 11 files changed, 90 insertions(+), 147 deletions(-) diff --git a/composable_kernel/include/utility/amd_address_space.hpp b/composable_kernel/include/utility/amd_address_space.hpp index f9bb6a5133..c5bb1b2cd0 100644 --- a/composable_kernel/include/utility/amd_address_space.hpp +++ b/composable_kernel/include/utility/amd_address_space.hpp @@ -20,6 +20,12 @@ __device__ T* cast_pointer_to_generic_address_space(T CONSTANT* p) return (T*)p; } +template +__host__ __device__ T CONSTANT* cast_pointer_to_constant_address_space(T* p) +{ + return (T CONSTANT*)p; +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/print.hpp b/composable_kernel/include/utility/print.hpp index 0dd646153a..d7d58bbb83 100644 --- a/composable_kernel/include/utility/print.hpp +++ b/composable_kernel/include/utility/print.hpp @@ -11,59 +11,11 @@ namespace ck { template __host__ __device__ void print_array(const char* s, T a) { - using data_type = decltype(a.At(Number<0>{})); constexpr index_t nsize = a.Size(); -#if 0 - if constexpr(is_same{}) - { - printf("%s size %u, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); }); - printf("}\n"); - } -#else printf("%s size %d, {", s, nsize); static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); printf("}\n"); -#endif -} - -template -__host__ __device__ void print_array_v2(const char* s, T a) -{ - using data_type = decltype(a.At(Number<0>{})); - constexpr index_t nsize = a.Size(); - -#if 0 - if constexpr(is_same{}) - { - printf("%s size %u, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); }); - printf("}\n"); - } - else if constexpr(is_same{}) - { - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); - printf("}\n"); - } -#else - printf("%s size %d, {", s, nsize); - static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); - printf("}\n"); -#endif } } // namespace ck diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp index ce94f2071b..d553d2586c 100644 --- a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhwk.hpp @@ -257,9 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_dlops_nhwc_kyxc_nhw const auto K = out_n_ho_wo_k_lengths[I3]; const auto C = wei_k_y_x_c_lengths[I3]; - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - const auto Ho = out_n_ho_wo_k_lengths[I1]; const auto Wo = out_n_ho_wo_k_lengths[I2]; diff --git a/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp index 2f175962c1..b520be5b6a 100644 --- a/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp +++ b/host/driver_offline/include/driver_dynamic_contraction_dlops_v1r2.hpp @@ -194,7 +194,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -221,7 +220,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -248,7 +246,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -275,7 +272,6 @@ driver_dynamic_contraction_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, diff --git a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp index 34b9a54374..693045cd16 100644 --- a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw.hpp @@ -244,7 +244,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -270,7 +269,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -296,7 +294,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -322,7 +319,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_pad dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, diff --git a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp index 4e0f6e9f77..2238b355f9 100644 --- a/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp +++ b/host/driver_offline/include/driver_dynamic_convolution_forward_implicit_gemm_v5r1_dlops_nchw_kcyx_nkhw_outpad.hpp @@ -257,7 +257,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -284,7 +283,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -311,7 +309,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, @@ -338,7 +335,6 @@ struct DriverDynamicConvolutionForwardImplicitGemmDlops_v5r1_nchw_kcyx_nkhw_outp dim3(GridSize), dim3(BlockSize), 0, - 0, wei_e_k_global_desc, p_wei_global, in_e_n_ho_wo_global_desc, diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp index 0ebc68b48a..29a72502d5 100644 --- a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r2.hpp @@ -189,7 +189,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -216,7 +215,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -243,7 +241,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -270,7 +267,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -315,14 +311,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { @@ -343,14 +340,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { @@ -371,14 +369,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else { @@ -399,14 +398,15 @@ __host__ float driver_dynamic_gemm_dlops_v1r2(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space(a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } return ave_time; diff --git a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp index d075eac822..242bcfb28b 100644 --- a/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp +++ b/host/driver_offline/include/driver_dynamic_gemm_dlops_v1r3.hpp @@ -185,7 +185,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -212,7 +211,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -239,7 +237,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -266,7 +263,6 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -311,14 +307,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { @@ -339,14 +338,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { @@ -367,14 +369,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } else { @@ -395,14 +400,17 @@ __host__ float driver_dynamic_gemm_dlops_v1r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, - (void CONSTANT*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + cast_pointer_to_constant_address_space( + a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); } return ave_time; diff --git a/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp index a2f4e28c54..85f5e27b8d 100644 --- a/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_dynamic_gemm_xdlops_v2r3.hpp @@ -153,7 +153,6 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, dim3(grid_size), dim3(BlockSize), 0, - 0, p_a_grid, p_b_grid, p_c_grid, @@ -173,20 +172,19 @@ __host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); - float ave_time = - launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - (void CONSTANT*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(), - (void CONSTANT*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + float ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_c_grid, + cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); #endif return ave_time; } diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 54392f3926..4aac2b5e4f 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -142,10 +142,8 @@ int main(int argc, char* argv[]) std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); - switch(layout) + if(layout == ConvTensorLayout::NCHW) { - case ConvTensorLayout::NCHW: - // NCHW in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(C); in_lengths_host[2] = static_cast(Hi); @@ -158,9 +156,9 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(K); out_lengths_host[2] = static_cast(Ho); out_lengths_host[3] = static_cast(Wo); - break; - case ConvTensorLayout::NHWC: - // NHWC + } + else if(layout == ConvTensorLayout::NHWC) + { in_lengths_host[0] = static_cast(N); in_lengths_host[1] = static_cast(Hi); in_lengths_host[2] = static_cast(Wi); @@ -173,8 +171,10 @@ int main(int argc, char* argv[]) out_lengths_host[1] = static_cast(Ho); out_lengths_host[2] = static_cast(Wo); out_lengths_host[3] = static_cast(K); - break; - default: throw std::runtime_error("wrong! not implemented"); + } + else + { + std::runtime_error("wrong! not implemented"); } Tensor in(in_lengths_host); diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index 2299e14921..e2cba94100 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -34,24 +34,16 @@ struct KernelTimer using device_stream_t = hipStream_t; template -void launch_kernel(F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - hipStream_t stream_id, - Args... args) +void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { + hipStream_t stream_id = nullptr; + hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...); } template -float launch_and_time_kernel(F kernel, - int nrepeat, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - hipStream_t stream_id, - Args... args) +float launch_and_time_kernel( + F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { KernelTimer timer; @@ -66,6 +58,8 @@ float launch_and_time_kernel(F kernel, printf("Warm up\n"); + hipStream_t stream_id = nullptr; + // warm up hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);