From 8efcb80fa5261e97fd10cf6fc216217ec066f134 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 6 Feb 2022 22:32:47 -0600 Subject: [PATCH] GEMM+Bias+ReLU+Add (#76) * tweak conv for odd C * update script * clean up elementwise op * fix build * clean up * added example for gemm+bias+relu+add * added example for gemm+bias+relu * add profiler for gemm_s_shuffle; re-org files * add profiler * fix build * clean up * clean up * clean up * fix build [ROCm/composable_kernel commit: 823657ed120144943b7db87c07fe3e647128db56] --- CMakeLists.txt | 1 + .../element_wise_operation.hpp | 195 ++---- .../threadwise_tensor_slice_transfer.hpp | 10 +- .../threadwise_tensor_slice_transfer_v1r4.hpp | 4 +- .../threadwise_tensor_slice_transfer_v3r1.hpp | 16 +- device_operation/CMakeLists.txt | 111 ++++ ...fle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 12 +- .../include/device_gemm_bias_activation.hpp | 43 ++ .../device_gemm_bias_activation_add.hpp | 47 ++ .../include/device_gemm_xdl_c_shuffle.hpp | 5 +- ...ice_gemm_xdl_c_shuffle_bias_activation.hpp | 349 +++++------ ...gemm_xdl_c_shuffle_bias_activation_add.hpp | 574 ++++++++++++++++++ ...s_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp | 7 +- ...atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp | 0 ..._bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp | 7 +- ..._c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp | 7 +- ...2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp | 0 ...2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp | 0 ...relu_add_f16_f16_f16_km_kn_mn_instance.cpp | 52 ++ ...relu_add_f16_f16_f16_km_nk_mn_instance.cpp | 52 ++ ...relu_add_f16_f16_f16_mk_kn_mn_instance.cpp | 52 ++ ...relu_add_f16_f16_f16_mk_nk_mn_instance.cpp | 57 ++ ...ias_relu_f16_f16_f16_km_kn_mn_instance.cpp | 52 ++ ...ias_relu_f16_f16_f16_km_nk_mn_instance.cpp | 52 ++ ...ias_relu_f16_f16_f16_mk_kn_mn_instance.cpp | 52 ++ ...ias_relu_f16_f16_f16_mk_nk_mn_instance.cpp | 57 ++ ..._shuffle_f16_f16_f16_km_kn_mn_instance.cpp | 52 ++ ..._shuffle_f16_f16_f16_km_nk_mn_instance.cpp | 52 ++ ..._shuffle_f16_f16_f16_mk_kn_mn_instance.cpp | 52 ++ ..._shuffle_f16_f16_f16_mk_nk_mn_instance.cpp | 57 ++ ...gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp | 0 ...gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp | 0 ...gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp | 0 ...gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp | 0 ...gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp | 0 ...gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp | 0 ...gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp | 0 ...gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp | 0 ...l_splitk_f32_f32_f32_km_kn_mn_instance.cpp | 0 ...l_splitk_f32_f32_f32_km_nk_mn_instance.cpp | 0 ...l_splitk_f32_f32_f32_mk_kn_mn_instance.cpp | 0 ...l_splitk_f32_f32_f32_mk_nk_mn_instance.cpp | 0 example/1_gemm_xdl/gemm_xdl.cpp | 53 +- example/2_gemm_xdl_bias_relu/README.md | 61 ++ .../gemm_xdl_bias_relu.cpp | 235 +++++++ .../gemm_xdl_bias_relu_add.cpp | 276 +++------ example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp | 36 +- .../conv2d_fwd_xdl_bias_relu.cpp | 29 +- .../conv2d_fwd_xdl_bias_relu_add.cpp | 42 +- .../conv2d_fwd_xdl_bias_relu_atomic_add.cpp | 18 +- example/CMakeLists.txt | 5 +- host/host_tensor/include/host_gemm.hpp | 17 +- profiler/CMakeLists.txt | 88 +-- .../profile_conv_fwd_bias_relu_add_impl.hpp | 101 ++- .../profile_conv_fwd_bias_relu_impl.hpp | 131 +--- profiler/include/profile_conv_fwd_impl.hpp | 44 +- .../profile_gemm_bias_relu_add_impl.hpp | 286 +++++++++ .../include/profile_gemm_bias_relu_impl.hpp | 264 ++++++++ profiler/include/profile_gemm_impl.hpp | 55 +- profiler/{ => src}/profile_conv_fwd.cpp | 0 .../{ => src}/profile_conv_fwd_bias_relu.cpp | 0 .../profile_conv_fwd_bias_relu_add.cpp | 0 .../profile_conv_fwd_bias_relu_atomic_add.cpp | 0 profiler/{ => src}/profile_gemm.cpp | 9 - profiler/src/profile_gemm_bias_relu.cpp | 148 +++++ profiler/src/profile_gemm_bias_relu_add.cpp | 153 +++++ profiler/{ => src}/profiler.cpp | 26 +- .../include/reference_conv_fwd.hpp | 27 +- .../reference_conv_fwd_bias_activation.hpp | 26 +- ...reference_conv_fwd_bias_activation_add.hpp | 31 +- .../include/reference_gemm.hpp | 132 ++++ .../reference_gemm_bias_activation.hpp | 136 +++++ .../reference_gemm_bias_activation_add.hpp | 144 +++++ script/conv2d_fwd.sh | 46 ++ script/gemm.sh | 20 + script/pool2d_fwd.sh | 46 ++ script/profile_conv.sh | 85 ++- 77 files changed, 3865 insertions(+), 932 deletions(-) create mode 100644 device_operation/CMakeLists.txt create mode 100644 device_operation/include/device_gemm_bias_activation.hpp create mode 100644 device_operation/include/device_gemm_bias_activation_add.hpp rename example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp => device_operation/include/device_gemm_xdl_c_shuffle_bias_activation.hpp (55%) create mode 100644 device_operation/include/device_gemm_xdl_c_shuffle_bias_activation_add.hpp rename device_operation/{ => src}/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp (93%) rename device_operation/{ => src}/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp (100%) rename device_operation/{ => src}/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp (93%) rename device_operation/{ => src}/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp (93%) rename device_operation/{ => src}/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp (100%) rename device_operation/{ => src}/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp (100%) create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp rename device_operation/{ => src}/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp (100%) rename device_operation/{ => src}/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp (100%) create mode 100644 example/2_gemm_xdl_bias_relu/README.md create mode 100644 example/2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp create mode 100644 profiler/include/profile_gemm_bias_relu_add_impl.hpp create mode 100644 profiler/include/profile_gemm_bias_relu_impl.hpp rename profiler/{ => src}/profile_conv_fwd.cpp (100%) rename profiler/{ => src}/profile_conv_fwd_bias_relu.cpp (100%) rename profiler/{ => src}/profile_conv_fwd_bias_relu_add.cpp (100%) rename profiler/{ => src}/profile_conv_fwd_bias_relu_atomic_add.cpp (100%) rename profiler/{ => src}/profile_gemm.cpp (97%) create mode 100644 profiler/src/profile_gemm_bias_relu.cpp create mode 100644 profiler/src/profile_gemm_bias_relu_add.cpp rename profiler/{ => src}/profiler.cpp (64%) rename {host => reference_operation}/include/reference_conv_fwd.hpp (89%) rename {host => reference_operation}/include/reference_conv_fwd_bias_activation.hpp (89%) rename {host => reference_operation}/include/reference_conv_fwd_bias_activation_add.hpp (88%) create mode 100644 reference_operation/include/reference_gemm.hpp create mode 100644 reference_operation/include/reference_gemm_bias_activation.hpp create mode 100644 reference_operation/include/reference_gemm_bias_activation_add.hpp create mode 100755 script/conv2d_fwd.sh create mode 100755 script/gemm.sh create mode 100755 script/pool2d_fwd.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index cb0508fec5..a2af6a812d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -198,6 +198,7 @@ enable_cppcheck( ) add_subdirectory(host) +add_subdirectory(device_operation) add_subdirectory(example) add_subdirectory(profiler) add_subdirectory(test) diff --git a/composable_kernel/include/tensor_operation/element_wise_operation.hpp b/composable_kernel/include/tensor_operation/element_wise_operation.hpp index 306102f4fb..d2054b8301 100644 --- a/composable_kernel/include/tensor_operation/element_wise_operation.hpp +++ b/composable_kernel/include/tensor_operation/element_wise_operation.hpp @@ -7,178 +7,99 @@ namespace element_wise { struct PassThrough { - template - __host__ __device__ void operator()(T& y, const T& x) const - { - y = x; - } + __host__ __device__ void operator()(float& y, const float& x) const { y = x; } - // TODO remove this - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } + __host__ __device__ void operator()(half_t& y, const half_t& x) const { y = x; } }; struct AddRelu { - template - __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const { - T a = x0 + x1; - y = a > 0 ? a : 0; + const float a = x0 + x1; + y = a > 0 ? a : 0; } - // TODO remove this - template - __host__ constexpr float operator()(float v0, T1 v1) const + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const { - float b = v0 + v1; - float c = b > 0 ? b : 0; + const half_t a = x0 + x1; + y = a > 0 ? a : 0; + } +}; - return c; +struct AddHardswish +{ + __host__ __device__ constexpr void operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + y = c; } - // TODO remove this - template - __device__ constexpr float operator()(float v0, T1 v1) const + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const { -#if 0 - float a = v1 + v0; - float b = max(a, float(0)); - - return b; -#else - float b = v1 + v0; - float c = b > 0 ? b : 0; - - return c; -#endif + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + y = c; } }; struct AddReluAdd { - template - __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1, const T& x2) const + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { - T a = x0 + x1; - T b = a > 0 ? a : 0; - y = b + x2; + half_t a = x0 + x1; + half_t b = a > 0 ? a : 0; + y = b + x2; } - // TODO remove this - template - __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1, const float& x2) const { - float b = v0 + v1; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; } - // TODO remove this - template - __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1, const half_t& x2) const { -#if 0 - float a = v1 + v0; - float b = max(a, float(0)); - float c = b + v2; - - return c; -#else - float b = v1 + v2; - float c = (v0 > -v1) ? b + v0 : v2; - - return c; -#endif + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; } }; -} // namespace element_wise -} // namespace tensor_operation -} // namespace ck - -namespace ck { -namespace tensor_operation { -namespace element_wise { - -struct AddLeakyReluAdd +struct AddHardswishAdd { - template - __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1, const float& x2) const { - float a = v0 + v1; - float b = 0.1 * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; } - template - __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const { -#if 0 - // this use not too many registers, but use fp64 mul - float a = v0 + v1; - float b = 0.1 * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#elif 0 - // this spill register - float a = v0 + v1; - float b = float(0.1) * a; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#elif 0 - // this use lots of registers (but no spill) - constexpr float alpha = 0.1; - constexpr float alpha_inv = 1.0 / alpha; - - float a = v2 * alpha_inv; - float b = v1 + v0; - float c = b > 0 ? b : 0; - float d = alpha * (a + c); - - return d; -#elif 1 - // this use lots of registers (but no spill), 89 Tflops - constexpr float alpha = 0.1; - constexpr float alpha_inv = 1.0 / alpha; - - float a = v2 * alpha_inv; - float b = v1 + v0; - float c = max(b, float(0)); - float d = alpha * (a + c); - - return d; -#elif 1 - // this spill registers, 89 Tflops - float a = v0 + v1; - float alpha = 0.1; - - float b; - asm volatile("\n \ - v_mul_f32_e32 %0, %1, %2 \n \ - " - : "=v"(b) - : "s"(alpha), "v"(a)); - - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; -#endif + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; } }; + } // namespace element_wise } // namespace tensor_operation } // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index a58855aa35..f914847192 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -199,9 +199,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 constexpr index_t src_offset = src_desc.CalculateOffset( src_slice_origin_idx + dst_data_idx + i * dst_scalar_step_in_vector); - // apply element-wise operation and type convert - dst_vector.template AsType()(i) = - type_convert(dst_element_op_(src_buf[Number{}])); + SrcData dst_v; + + // apply element-wise operation + dst_element_op_(dst_v, src_buf[Number{}]); + + // apply type convert + dst_vector.template AsType()(i) = type_convert(dst_v); }); const bool is_dst_valid = diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp index c669427896..1ef098f6d5 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v1r4.hpp @@ -293,7 +293,9 @@ struct ThreadwiseTensorSliceTransfer_v1r4 dst_vector.template AsType()(Number<0>{}) = type_convert(dst_v); #else // apply element-wise operation in DstData type - const DstData dst_v = dst_element_op_(src_v, dst0_v, dst1_v); + DstData dst_v; + + dst_element_op_(dst_v, src_v, dst0_v, dst1_v); dst_vector.template AsType()(Number<0>{}) = dst_v; #endif diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp index 5497bb2e3d..438f925306 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp @@ -207,8 +207,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1 // apply SrcElementwiseOperation on src_vector_container static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { - src_vector_container.template AsType()(i) = - src_element_op_(src_vector_container.template AsType()[i]); + SrcData src_v; + + src_element_op_(src_v, src_vector_container.template AsType()[i]); + + src_vector_container.template AsType()(i) = src_v; }); // copy data from src_vector_container into src_thread_scratch_ @@ -452,10 +455,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 auto dst_vector_container = dst_vector_type{ dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; - // apply DstElementwiseOperation on dst_vector_container static_for<0, DstScalarPerVector, 1>{}([&](auto i) { - dst_vector_container.template AsType()(i) = - dst_element_op_(dst_vector_container.template AsType()[i]); + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; }); // copy data from dst_vector_container to dst_buf diff --git a/device_operation/CMakeLists.txt b/device_operation/CMakeLists.txt new file mode 100644 index 0000000000..d9a4ebb499 --- /dev/null +++ b/device_operation/CMakeLists.txt @@ -0,0 +1,111 @@ +include_directories(BEFORE + include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/device/include + ${PROJECT_SOURCE_DIR}/device_operation/include + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/external/rocm/include +) + +# device_gemm_instance +set(DEVICE_GEMM_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; +) + +# device_gemm_bias_relu_instance +set(DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp; +) + +# device_gemm_bias_relu_add_instance +set(DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp; +) + +# device_conv2d_fwd_instance +set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +# device_conv2d_fwd_bias_relu_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +# device_conv2d_fwd_bias_relu_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +# device_conv2d_fwd_bias_relu_atomic_add_instance +set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE + ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; +) + +add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_instance SHARED ${DEVICE_GEMM_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_gemm_bias_relu_add_instance SHARED ${DEVICE_GEMM_BIAS_RELU_ADD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) +add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) + +target_include_directories(device_gemm_instance SYSTEM PUBLIC $) +target_include_directories(device_gemm_bias_relu_instance SYSTEM PUBLIC $) +target_include_directories(device_gemm_bias_relu_add_instance SYSTEM PUBLIC $) +target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) +target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) +target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) +target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) + +target_compile_features(device_gemm_instance PUBLIC) +target_compile_features(device_gemm_bias_relu_instance PUBLIC) +target_compile_features(device_gemm_bias_relu_add_instance PUBLIC) +target_compile_features(device_conv2d_fwd_instance PUBLIC) +target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) +target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) +target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) + +set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_gemm_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_gemm_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_bias_relu_instance LIBRARY DESTINATION lib) +install(TARGETS device_gemm_bias_relu_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) +install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) diff --git a/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp index e9aa4fa42c..6baf1483ac 100644 --- a/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp +++ b/device_operation/include/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -451,14 +451,14 @@ struct } } - using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); - using AGridDesc_K0_M_K1 = remove_cvref_t; - using BGridDesc_K0_N_K1 = remove_cvref_t; - using CGridDesc_M_N = remove_cvref_t; - using C0GridDesc_M_N = remove_cvref_t; - using C1GridDesc_M_N = remove_cvref_t; + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; // GridwiseGemm using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< diff --git a/device_operation/include/device_gemm_bias_activation.hpp b/device_operation/include/device_gemm_bias_activation.hpp new file mode 100644 index 0000000000..95736b1887 --- /dev/null +++ b/device_operation/include/device_gemm_bias_activation.hpp @@ -0,0 +1,43 @@ +#ifndef DEVICE_GEMM_BIAS_ACTIVATION_HPP +#define DEVICE_GEMM_BIAS_ACTIVATION_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasActivation : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasActivationPtr = std::unique_ptr< + DeviceGemmBiasActivation>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_gemm_bias_activation_add.hpp b/device_operation/include/device_gemm_bias_activation_add.hpp new file mode 100644 index 0000000000..d304abaa38 --- /dev/null +++ b/device_operation/include/device_gemm_bias_activation_add.hpp @@ -0,0 +1,47 @@ +#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/include/device_gemm_xdl_c_shuffle.hpp b/device_operation/include/device_gemm_xdl_c_shuffle.hpp index da19b5ec4f..6127e6e6fe 100644 --- a/device_operation/include/device_gemm_xdl_c_shuffle.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle.hpp @@ -424,7 +424,8 @@ struct DeviceGemmXdl_C_Shuffle index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override { return std::make_unique(static_cast(p_a), static_cast(p_b), @@ -454,7 +455,7 @@ struct DeviceGemmXdl_C_Shuffle auto str = std::stringstream(); // clang-format off - str << "DeviceGemmXdl" + str << "DeviceGemmXdl_C_Shuffle" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_bias_activation.hpp similarity index 55% rename from example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp rename to device_operation/include/device_gemm_xdl_c_shuffle_bias_activation.hpp index ce8ea79bd6..47d16546ae 100644 --- a/example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp +++ b/device_operation/include/device_gemm_xdl_c_shuffle_bias_activation.hpp @@ -1,70 +1,81 @@ -#ifndef DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP -#define DEVICE_GEMM_XDL_TWO_EXTRA_SOURCE_REDUCE_HPP +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_HPP #include #include #include "device.hpp" -#include "device_base.hpp" -#include "device_gemm.hpp" +#include "device_gemm_bias_activation.hpp" #include "common_header.hpp" #include "tensor_layout.hpp" #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdlops_v2r5.hpp" +#include "gridwise_gemm_xdlops_v3r2.hpp" namespace ck { namespace tensor_operation { namespace device { -template -struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator +// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_Bias_Activation + : public DeviceGemmBiasActivation { + using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation; + static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; static constexpr auto K1Number = Number{}; - static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( + index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC) { assert(K % K1 == 0); const index_t K0 = K / K1; + // A[K0, M, K1] const auto a_grid_desc_m_k = [&]() { if constexpr(is_same::value) { @@ -83,15 +94,7 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator make_tuple(Sequence<1>{}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return a_grid_desc_k0_m_k1; - } - - static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) - { - assert(K % K1 == 0); - - const index_t K0 = K / K1; - + // B[K0, N, K1] const auto b_grid_desc_k_n = [&]() { if constexpr(is_same::value) { @@ -110,33 +113,36 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return b_grid_desc_k0_n_k1; + // C[M, N] + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // C0[N]: assume a contiguous vector + const auto c0_grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + + return make_tuple( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n, c0_grid_desc_m_n); } - static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) - { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); - } - } + using GridDescs = + decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N(1, 1, 1, 1, 1, 1)); - using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); - using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); - using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - - // hardcoding - // TODO: fix this - using C1GridDesc_M_N = - decltype(make_naive_tensor_descriptor(make_tuple(1, 1), make_tuple(I1, I0))); + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5< + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< BlockSize, ADataType, // TODO: distinguish A/B datatype AccDataType, @@ -146,7 +152,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator BGridDesc_K0_N_K1, CGridDesc_M_N, C0GridDesc_M_N, - C1GridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, @@ -174,20 +179,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator BBlockTransferDstScalarPerVector_K1, false, // BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsAddExtraN, - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector>; - - using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - - using C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C0GridDesc_M_N{})); - - using C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = - decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(C1GridDesc_M_N{})); - - using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; // Argument struct Argument : public BaseArgument @@ -196,7 +191,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator const BDataType* p_b_grid, CDataType* p_c_grid, const CDataType* p_c0_grid, - const CDataType* p_c1_grid, index_t M, index_t N, index_t K, @@ -212,15 +206,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator p_b_grid_{p_b_grid}, p_c_grid_{p_c_grid}, p_c0_grid_{p_c0_grid}, - p_c1_grid_{p_c1_grid}, a_grid_desc_k0_m_k1_{}, b_grid_desc_k0_n_k1_{}, c_grid_desc_m_n_{}, c0_grid_desc_m_n_{}, - c1_grid_desc_m_n_{}, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, block_2_ctile_map_{}, M01_{M01}, N01_{N01}, @@ -228,33 +219,26 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator b_element_op_{b_element_op}, c_element_op_{c_element_op} { - a_grid_desc_k0_m_k1_ = - DeviceGemmXdl_two_extra_source_reduce::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); - b_grid_desc_k0_n_k1_ = - DeviceGemmXdl_two_extra_source_reduce::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); - c_grid_desc_m_n_ = - DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC); + const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N( + M, N, K, StrideA, StrideB, StrideC); - // assume C0 has same layout as C - // TODO: fix this - c0_grid_desc_m_n_ = - DeviceGemmXdl_two_extra_source_reduce::MakeCGridDescriptor_M_N(M, N, StrideC); - - // hardcoding C1 layout - // TODO: fix this - c1_grid_desc_m_n_ = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, I0)); + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; if(GridwiseGemm::CheckValidity( a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) { - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); - c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c0_grid_desc_m_n_); - - c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = - GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c1_grid_desc_m_n_); + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); } @@ -265,16 +249,17 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator const BDataType* p_b_grid_; CDataType* p_c_grid_; const CDataType* p_c0_grid_; - const CDataType* p_c1_grid_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; CGridDesc_M_N c_grid_desc_m_n_; C0GridDesc_M_N c0_grid_desc_m_n_; - C1GridDesc_M_N c1_grid_desc_m_n_; - CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; - Block2CTileMap block_2_ctile_map_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; index_t M01_; index_t N01_; AElementwiseOperation a_element_op_; @@ -285,7 +270,7 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceGemmXdl_two_extra_source_reduce::Argument; + using Argument = DeviceOp::Argument; float Run(const Argument& arg, int nrepeat = 1) { @@ -303,9 +288,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; - - std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) - << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, @@ -328,83 +310,81 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator if(has_main_k0_block_loop) { - const auto kernel = kernel_gemm_xdlops_v2r5< + const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, + remove_reference_t, true>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } else { - const auto kernel = kernel_gemm_xdlops_v2r5< + const auto kernel = kernel_gemm_xdlops_v3r2< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, - remove_reference_t< - DeviceGemmXdl_two_extra_source_reduce::C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, - remove_reference_t, + remove_reference_t, false>; - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.p_c0_grid_, - arg.p_c1_grid_, - arg.a_grid_desc_k0_m_k1_, - arg.b_grid_desc_k0_n_k1_, - arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.block_2_ctile_map_); + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); } return ave_time; @@ -442,7 +422,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator const BDataType* p_b, CDataType* p_c, const CDataType* p_c0, - const CDataType* p_c1, index_t M, index_t N, index_t K, @@ -457,7 +436,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator p_b, p_c, p_c0, - p_c1, M, N, K, @@ -478,7 +456,6 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator const void* p_b, void* p_c, const void* p_c0, - const void* p_c1, index_t M, index_t N, index_t K, @@ -487,13 +464,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + index_t KBatch = 1) override { return std::make_unique(static_cast(p_a), static_cast(p_b), static_cast(p_c), static_cast(p_c0), - static_cast(p_c1), M, N, K, @@ -508,7 +485,7 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator } // polymorphic - std::unique_ptr MakeInvokerPointer() + std::unique_ptr MakeInvokerPointer() override { return std::make_unique(Invoker{}); } @@ -518,7 +495,7 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator auto str = std::stringstream(); // clang-format off - str << "DeviceGemmXdl_two_extra_source_reduce" + str << "DeviceGemmXdl_C_Shuffle_Bias_Activation" << "<" << BlockSize << ", " << MPerBlock << ", " diff --git a/device_operation/include/device_gemm_xdl_c_shuffle_bias_activation_add.hpp b/device_operation/include/device_gemm_xdl_c_shuffle_bias_activation_add.hpp new file mode 100644 index 0000000000..b0e2f61a11 --- /dev/null +++ b/device_operation/include/device_gemm_xdl_c_shuffle_bias_activation_add.hpp @@ -0,0 +1,574 @@ +#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP +#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device.hpp" +#include "device_gemm_bias_activation_add.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v3r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// C[M, N] = activate(A[M, K] * B[K, N] + C0[N]) + C1[M, N] +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add + : public DeviceGemmBiasActivationAdd +{ + using DeviceOp = DeviceGemmXdl_C_Shuffle_Bias_Activation_Add; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + // A[K0, M, K1] + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_k0_m_k1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B[K0, N, K1] + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_k0_n_k1 = + transform_tensor_descriptor(b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C[M, N] + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // C0[N]: assume a contiguous vector + const auto c0_grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + + // C1[M, N]: residual tensor: assume same layout as C + const auto c1_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC1, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC1)); + } + }(); + + return make_tuple(a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n, + c0_grid_desc_m_n, + c1_grid_desc_m_n); + } + + using GridDescs = + decltype(MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N(1, 1, 1, 1, 1, 1, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum_t::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const CDataType* p_c0_grid, + const CDataType* p_c1_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_{p_c0_grid}, + p_c1_grid_{p_c1_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + const auto descs = DeviceOp::MakeGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N_C0_M_N_C1_M_N( + M, N, K, StrideA, StrideB, StrideC, StrideC1); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int nrepeat = 1) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r5 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const CDataType* p_c0, + const CDataType* p_c1, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0, + p_c1, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0, + const void* p_c1, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0), + static_cast(p_c1), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdl_C_Shuffle_Bias_Activation_Add" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp similarity index 93% rename from device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp index dbfa6e2031..00f270a8d3 100644 --- a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp @@ -118,7 +118,12 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_odd_c_f16_ins DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddReluAdd, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp similarity index 100% rename from device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp similarity index 93% rename from device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp index 075eddd117..35a88ac5f1 100644 --- a/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp @@ -120,7 +120,12 @@ using device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_odd_c_f16_instanc DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, MemorySet, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; diff --git a/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp similarity index 93% rename from device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp index cd9ee30627..1e93de9cbb 100644 --- a/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp +++ b/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp @@ -116,7 +116,12 @@ using device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_odd_c_f16_instances = std:: DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 4, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, - DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 2, 8>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 128, 64, 2, 4, 32, 32, 2, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 256, 256, 64, 2, 4, 32, 32, 4, 1, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 32, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 128, 64, 2, 4, 32, 32, 2, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdOddC, 128, 64, 64, 2, 4, 32, 32, 1, 2, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<2, 16, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> // clang-format on >; diff --git a/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp similarity index 100% rename from device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp diff --git a/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp b/device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp similarity index 100% rename from device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp rename to device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c26f66a9ed --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..c0950666b1 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..42c1f72d6e --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3961def81d --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd; + +// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) + c1[m, n] +using device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#########################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#########################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#########################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddReluAdd, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..4927a05ca4 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[k, m] * b[k, n] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..f712f9de11 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[k, m] * b[n, k] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..26af05bbde --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[m, k] * b[k, n] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..901b7a5d64 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using AddRelu = ck::tensor_operation::element_wise::AddRelu; + +// c[m, n] = ReLU(a[m, k] * b[n, k] + c0[n]) +using device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //#####################################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle_Bias_Activation< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, AddRelu, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances( + instances, device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..c82402f5bf --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances = + std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..1609d49e16 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[k, m] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances = + std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Col, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..4afe5e1234 --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,52 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[k, n] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..0793adcabb --- /dev/null +++ b/device_operation/src/device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,57 @@ +#include +#include "config.hpp" +#include "device_gemm_xdl_c_shuffle.hpp" +#include "element_wise_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +// Compilation parameters for a[m, k] * b[n, k] = c[m, n] +using device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances = + std::tuple< + // clang-format off + //#####################|AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#####################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| + //#####################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 32, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 8>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8>, + DeviceGemmXdl_C_Shuffle< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 16, 1, 1, 4>, 8> + // clang-format on + >; + +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances( + std::vector>& instances) +{ + add_device_operation_instances(instances, + device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp b/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp diff --git a/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp b/device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp similarity index 100% rename from device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp rename to device_operation/src/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp diff --git a/example/1_gemm_xdl/gemm_xdl.cpp b/example/1_gemm_xdl/gemm_xdl.cpp index 8211655ca7..d9ed011fbe 100644 --- a/example/1_gemm_xdl/gemm_xdl.cpp +++ b/example/1_gemm_xdl/gemm_xdl.cpp @@ -11,9 +11,9 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_base.hpp" #include "device_gemm_xdl_c_shuffle.hpp" #include "element_wise_operation.hpp" +#include "reference_gemm.hpp" template using S = ck::Sequence; @@ -72,37 +72,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -static void host_verify(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op) -{ - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a_m_k.mDesc.GetLengths()[1]; - - double v = 0; - - for(int k = 0; k < K; ++k) - { - v += static_cast(a_element_op(a_m_k(m, k))) * - static_cast(b_element_op(b_k_n(k, n))); - } - - c_m_n(m, n) = c_element_op(v); - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, - c_m_n.mDesc.GetLengths()[0], - c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); -} +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; int main(int argc, char* argv[]) { @@ -191,6 +162,10 @@ int main(int argc, char* argv[]) b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + // do GEMM auto gemm = DeviceGemmInstance{}; auto invoker = gemm.MakeInvoker(); @@ -203,9 +178,9 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, - AElementOp{}, - BElementOp{}, - CElementOp{}); + a_element_op, + b_element_op, + c_element_op); if(!gemm.IsSupportedArgument(argument)) { @@ -231,7 +206,13 @@ int main(int argc, char* argv[]) if(do_verification) { - host_verify(a_m_k, b_k_n, c_m_n_host_result, AElementOp{}, BElementOp{}, CElementOp{}); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); check_error(c_m_n_host_result, c_m_n_device_result); } diff --git a/example/2_gemm_xdl_bias_relu/README.md b/example/2_gemm_xdl_bias_relu/README.md new file mode 100644 index 0000000000..379f9a2e75 --- /dev/null +++ b/example/2_gemm_xdl_bias_relu/README.md @@ -0,0 +1,61 @@ +# Instructions for ```gemm_xdl_bias_relu_add``` Example + +## Docker script +```bash +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v ${PATH_TO_LOCAL_WORKSPACE}:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash +``` + +## Build ```gemm_xdl_bias_relu_add``` +```bash +mkdir build && cd build +``` + +```bash +# Need to specify target ID, example below is gfx908 +cmake \ +-D BUILD_DEV=OFF \ +-D CMAKE_BUILD_TYPE=Release \ +-D CMAKE_CXX_FLAGS="-DCK_AMD_GPU_GFX908 --amdgpu-target=gfx908 -O3 " \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ +.. +``` + +```bash + make -j gemm_xdl_bias_relu_add +``` + +## Run ```gemm_xdl_bias_relu_add``` +```bash +#arg1: verification (0=no, 1=yes) +#arg2: initialization (0=no init, 1=integer value, 2=decimal value) +#arg3: run kernel # of times (>1) +#arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC +./example/gemm_xdl_bias_relu_add 0 1 5 3840 4096 4096 4096 4096 4096 +``` + +Result (MI100 @ 1087Mhz, 133.5TFlops peak FP16) +``` +a_m_k: dim 2, lengths {3840, 4096}, strides {4096, 1} +b_k_n: dim 2, lengths {4096, 4096}, strides {1, 4096} +c_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c0_m_n: dim 2, lengths {3840, 4096}, strides {4096, 1} +c1_m_n: dim 2, lengths {3840, 4096}, strides {1, 0} +arg.a_grid_desc_k0_m_k1_{512, 3840, 8} +arg.b_grid_desc_k0_n_k1_{512, 4096, 8} +arg.c_grid_desc_m_n_{ 3840, 4096} +arg.c0_grid_desc_m_n_{ 3840, 4096} +arg.c1_grid_desc_m_n_{ 3840, 4096} +launch_and_time_kernel: grid_dim {480, 1, 1}, block_dim {256, 1, 1} +Warm up +Start running 5 times... +Perf: 1.27583 ms, 100.992 TFlops, 73.9688 GB/s +``` diff --git a/example/2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp b/example/2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp new file mode 100644 index 0000000000..4dc8d0b788 --- /dev/null +++ b/example/2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp @@ -0,0 +1,235 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation.hpp" +#include "reference_gemm_bias_activation.hpp" + +template +using S = ck::Sequence; + +using ADataType = ck::half_t; +using BDataType = ck::half_t; +using CDataType = ck::half_t; +using AccDataType = float; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::AddRelu; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBiasActivation; + +int main(int argc, char* argv[]) +{ + bool do_verification = 0; + int init_method = 0; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 10) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + exit(0); + } + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + + auto invoker = gemm.MakeInvoker(); + auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + check_error(c_m_n_host_result, c_m_n_device_result); + } +} diff --git a/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp b/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp index 5b8369c6e9..3ce7e9848b 100644 --- a/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp +++ b/example/3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp @@ -11,107 +11,9 @@ #include "host_tensor_generator.hpp" #include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_base.hpp" -#include "example/3_gemm_xdl_bias_relu_add/include/device_gemm_xdl_two_extra_source_reduce.hpp" - -// C[m, n] = Relu(A[m, k] * B[k, n] + C0[m]) + C1[m, n] -// assume C0 is contiguous in memory -// C0 resides in memory as 1d vector [m], but is represented as 2D matrix [m, n], with stride = -// 0 in the "n" dimension -// assume C1 and C have same layout C - -struct BiasReluAdd -{ - template - __host__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { - float b = v0 + v1; - float c = b > 0 ? b : 0; - float d = c + v2; - - return d; - } - - template - __device__ constexpr float operator()(float v0, T1 v1, T2 v2) const - { -#if 0 - float a = v1 + v0; - float b = a > 0 ? a : 0; - float c = b + v2; - - return c; -#else - float a = v1 + v2; - float b = v2; - - float c = (v0 > -v1) ? a + v0 : v2; - - return c; -#endif - } -}; - -struct DoSomething -{ -#if 1 - // correct result - // no scratch memory, good VGPR allocation (59) - // good perf (101Tflops @ 1089Mhz) - __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const - { - constexpr float alpha = 0.1; - constexpr float beta = 0.2; - constexpr float gamma = 0.3; - - // compiler seems very volatile to the order of these calculation: - // compiler is very eager to read AccVgpr (v0) out prematurely, resulting in register - // over-allocation. Therefore, move v0 calculation to the very end - float a = ck::half_t(beta) * v1 + ck::half_t(gamma) * v2; - float b = a + float(alpha) * v0; - - return b; - } -#elif 0 - float alpha = 0.1; - float beta = 0.2; - float gamma = 0.3; - - // wrong result - // lots of scratch memory - // huge perf drop - __host__ __device__ constexpr float operator()(float v0, ck::half_t v1, ck::half_t v2) const - { - return alpha * v0 + beta * v1 + gamma * v2; - } -#elif 0 - // correct result - // some scratch memory (68 dword) - // some perf drop (94Tflops @ 1089MHz) - // fp64 instructions are used - __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const - { - return 0.1 * v0 + 0.2 * v1 + 0.3 * v2; - } -#elif 1 - // wrong result - // lots of scratch memory - // huge perf drop - __host__ __device__ constexpr auto operator()(float v0, ck::half_t v1, ck::half_t v2) const - { - return float(0.1) * v0 + float(0.2) * v1 + float(0.3) * v2; - } -#endif -}; - -struct PassThrough -{ - template - __host__ __device__ constexpr T operator()(T v) const - { - return v; - } -}; +#include "element_wise_operation.hpp" +#include "device_gemm_xdl_c_shuffle_bias_activation_add.hpp" +#include "reference_gemm_bias_activation_add.hpp" template using S = ck::Sequence; @@ -125,58 +27,58 @@ using ALayout = ck::tensor_layout::gemm::RowMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; using CLayout = ck::tensor_layout::gemm::RowMajor; -using AOp = PassThrough; -using BOp = PassThrough; -#if 1 -using COp = BiasReluAdd; -#else -using COp = DoSomething; -#endif +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; -// Compilation parameters for NT problem // clang-format off -using DeviceGemmInstance = - //#################################################################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| AElementwise| BElementwise| CElementwise| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer| - //#################################################################| Type| Type| Type| Type| | | | Operation| Operation| Operation| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| - //#################################################################| | | | | | | | | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| - //#################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - ck::tensor_operation::device::DeviceGemmXdl_two_extra_source_reduce< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AOp, BOp, COp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>; +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl_C_Shuffle_Bias_Activation_Add< + ADataType, // ADataType + BDataType, // BDataType + CDataType, // CDataType + AccDataType, // AccDataType + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + 256, // BlockSize + 256, // MPerBlock + 128, // NPerBlock + 4, // K0PerBlock + 8, // K1 + 32, // MPerXDL + 32, // NPerXDL + 4, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 8, // ABlockTransferSrcScalarPerVector + 8, // ABlockTransferDstScalarPerVector_K1 + true, // ABlockLdsAddExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 8, // BBlockTransferSrcScalarPerVector + 8, // BBlockTransferDstScalarPerVector_K1 + true, // BBlockLdsAddExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 1, 32, 1, 1, 8>, // CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -template -static void host_verify(const Tensor& a_m_k, - const Tensor& b_k_n, - Tensor& c_m_n, - const Tensor& c0_m_n, - const Tensor& c1_m_n, - const AElementwiseOperation& a_element_op, - const BElementwiseOperation& b_element_op, - const CElementwiseOperation& c_element_op) -{ - auto f_mk_kn_mn = [&](auto m, auto n) { - const int K = a_m_k.mDesc.GetLengths()[1]; - - float acc = 0; - - for(int k = 0; k < K; ++k) - { - acc += static_cast(a_element_op(a_m_k(m, k))) * - static_cast(b_element_op(b_k_n(k, n))); - } - - c_m_n(m, n) = c_element_op(acc, c0_m_n(m, n), c1_m_n(m, n)); - }; - - make_ParallelTensorFunctor(f_mk_kn_mn, - c_m_n.mDesc.GetLengths()[0], - c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); -} - +using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; int main(int argc, char* argv[]) { bool do_verification = 0; @@ -188,9 +90,10 @@ int main(int argc, char* argv[]) ck::index_t N = 4096; ck::index_t K = 4096; - ck::index_t StrideA = 4096; - ck::index_t StrideB = 4096; - ck::index_t StrideC = 4096; + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + ck::index_t StrideC1 = 4096; if(argc == 4) { @@ -198,7 +101,7 @@ int main(int argc, char* argv[]) init_method = std::stoi(argv[2]); nrepeat = std::stoi(argv[3]); } - else if(argc == 10) + else if(argc == 11) { do_verification = std::stoi(argv[1]); init_method = std::stoi(argv[2]); @@ -208,16 +111,17 @@ int main(int argc, char* argv[]) N = std::stoi(argv[5]); K = std::stoi(argv[6]); - StrideA = std::stoi(argv[7]); - StrideB = std::stoi(argv[8]); - StrideC = std::stoi(argv[9]); + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + StrideC1 = std::stoi(argv[10]); } else { printf("arg1: verification (0=no, 1=yes)\n"); printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: run kernel # of times (>1)\n"); - printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, StrideC1\n"); exit(0); } @@ -240,18 +144,17 @@ int main(int argc, char* argv[]) Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); - // C0[m] - Tensor c1_m_n(HostTensorDescriptor( - std::vector({static_cast(M), static_cast(N)}), - std::vector({1, 0}))); + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); - // C1[m ,n] - Tensor c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + // c1_m_n[m ,n] + Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; - std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; switch(init_method) @@ -260,31 +163,31 @@ int main(int argc, char* argv[]) case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - c0_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); - DeviceMem c0_m_n_device_buf(sizeof(CDataType) * c0_m_n.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); - c0_m_n_device_buf.ToDevice(c0_m_n.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); - auto a_element_op = AOp{}; - auto b_element_op = BOp{}; - auto c_element_op = COp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; // do GEMM auto gemm = DeviceGemmInstance{}; @@ -293,7 +196,7 @@ int main(int argc, char* argv[]) auto argument = gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), - static_cast(c0_m_n_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), static_cast(c1_m_n_device_buf.GetDeviceBuffer()), M, N, @@ -301,6 +204,7 @@ int main(int argc, char* argv[]) StrideA, StrideB, StrideC, + StrideC1, a_element_op, b_element_op, c_element_op); @@ -314,9 +218,10 @@ int main(int argc, char* argv[]) float ave_time = invoker.Run(argument, nrepeat); - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + sizeof(CDataType) * M * N; + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N + + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; @@ -329,14 +234,19 @@ int main(int argc, char* argv[]) if(do_verification) { - host_verify(a_m_k, - b_k_n, - c_m_n_host_result, - c0_m_n, - c1_m_n, - PassThrough{}, - PassThrough{}, - c_element_op); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n, + c1_m_n, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); check_error(c_m_n_host_result, c_m_n_device_result); } diff --git a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp index 310de70b25..4c62a7af15 100644 --- a/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp +++ b/example/4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp @@ -74,13 +74,8 @@ using DeviceConvFwdInstance = ck::tensor_operation::device:: 8>; // CBlockTransferScalarPerVector_NWaveNPerXdl // clang-format on -using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; +using ReferenceConvFwdInstance = ck::tensor_operation::host:: + ReferenceConvFwd; int main(int argc, char* argv[]) { @@ -254,20 +249,21 @@ int main(int argc, char* argv[]) if(do_verification) { - auto refConv = ReferenceConvFwdInstance{}; - auto refInvoker = refConv.MakeInvoker(); + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); - auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - refInvoker.Run(refArgument); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + + ref_invoker.Run(ref_argument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp index 79bd332709..aa62e212d0 100644 --- a/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp +++ b/example/5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp @@ -81,7 +81,6 @@ using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation; @@ -268,21 +267,21 @@ int main(int argc, char* argv[]) if(do_verification) { - auto refConv = ReferenceConvFwdInstance{}; - auto refInvoker = refConv.MakeInvoker(); + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); - auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - refInvoker.Run(refArgument); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + InElementOp{}, + WeiElementOp{}, + OutElementOp{}); + ref_invoker.Run(ref_argument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp index 2b1414b05b..a20a8cbb67 100644 --- a/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp +++ b/example/6_conv2d_fwd_xdl_bias_relu_add/conv2d_fwd_xdl_bias_relu_add.cpp @@ -78,7 +78,6 @@ using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd_Bias_Activation_Add; @@ -228,6 +227,10 @@ int main(int argc, char* argv[]) bias_device_buf.ToDevice(bias_k.mData.data()); resi_device_buf.ToDevice(resi_n_k_ho_wo.mData.data()); + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + auto conv = DeviceConvFwdInstance{}; auto invoker = conv.MakeInvoker(); auto argument = @@ -246,9 +249,9 @@ int main(int argc, char* argv[]) conv_filter_dilations, input_left_pads, input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + in_element_op, + wei_element_op, + out_element_op); if(!conv.IsSupportedArgument(argument)) { @@ -275,22 +278,23 @@ int main(int argc, char* argv[]) if(do_verification) { - auto refConv = ReferenceConvFwdInstance{}; - auto refInvoker = refConv.MakeInvoker(); + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); - auto refArgument = refConv.MakeArgument(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - bias_k, - resi_n_k_ho_wo, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); - refInvoker.Run(refArgument); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); out_device_buf.FromDevice(out_n_k_ho_wo_device_result.mData.data()); diff --git a/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp index c47c094385..8f07cf066b 100644 --- a/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp +++ b/example/7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp @@ -65,7 +65,8 @@ void host_reference_calculation(const Tensor& in_n_c_hi_wi, const OutElementOp& out_element_op) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; + float v_acc = 0; + for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) { for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) @@ -77,14 +78,23 @@ void host_reference_calculation(const Tensor& in_n_c_hi_wi, if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) { - v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * - wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); + float v_in; + float v_wei; + + in_element_op(v_in, static_cast(in_n_c_hi_wi(n, c, hi, wi))); + wei_element_op(v_wei, static_cast(wei_k_c_y_x(k, c, y, x))); + + v_acc += v_in * v_wei; } } } } - out_n_k_ho_wo(n, k, ho, wo) += out_element_op(v, bias_k(k)); + float v_out; + + out_element_op(v_out, v_acc, static_cast(bias_k(k))); + + out_n_k_ho_wo(n, k, ho, wo) += v_out; }; make_ParallelTensorFunctor(f_nchw, diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index c25e78bf29..f9474425bc 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -2,8 +2,8 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR} ${PROJECT_SOURCE_DIR}/host/host_tensor/include ${PROJECT_SOURCE_DIR}/host/device/include - ${PROJECT_SOURCE_DIR}/host/include ${PROJECT_SOURCE_DIR}/device_operation/include + ${PROJECT_SOURCE_DIR}/reference_operation/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description @@ -13,6 +13,7 @@ include_directories(BEFORE ) set(GEMM_XDL_SOURCE 1_gemm_xdl/gemm_xdl.cpp) +set(GEMM_XDL_BIAS_RELU_SOURCE 2_gemm_xdl_bias_relu/gemm_xdl_bias_relu.cpp) set(GEMM_XDL_BIAS_RELU_ADD_SOURCE 3_gemm_xdl_bias_relu_add/gemm_xdl_bias_relu_add.cpp) set(CONV2D_FWD_XDL_SOURCE 4_conv2d_fwd_xdl/conv2d_fwd_xdl.cpp) set(CONV2D_FWD_XDL_BIAS_RELU_SOURCE 5_conv2d_fwd_xdl_bias_relu/conv2d_fwd_xdl_bias_relu.cpp) @@ -20,6 +21,7 @@ set(CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURCE 6_conv2d_fwd_xdl_bias_relu_add/conv2d_fw set(CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE 7_conv2d_fwd_xdl_bias_relu_atomic_add/conv2d_fwd_xdl_bias_relu_atomic_add.cpp) add_executable(gemm_xdl ${GEMM_XDL_SOURCE}) +add_executable(gemm_xdl_bias_relu ${GEMM_XDL_BIAS_RELU_SOURCE}) add_executable(gemm_xdl_bias_relu_add ${GEMM_XDL_BIAS_RELU_ADD_SOURCE}) add_executable(conv2d_fwd_xdl ${CONV2D_FWD_XDL_SOURCE}) add_executable(conv2d_fwd_xdl_bias_relu ${CONV2D_FWD_XDL_BIAS_RELU_SOURCE}) @@ -27,6 +29,7 @@ add_executable(conv2d_fwd_xdl_bias_relu_add ${CONV2D_FWD_XDL_BIAS_RELU_ADD_SOURC add_executable(conv2d_fwd_xdl_bias_relu_atomic_add ${CONV2D_FWD_XDL_BIAS_RELU_ATOMIC_ADD_SOURCE}) target_link_libraries(gemm_xdl PRIVATE host_tensor) +target_link_libraries(gemm_xdl_bias_relu PRIVATE host_tensor) target_link_libraries(gemm_xdl_bias_relu_add PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl PRIVATE host_tensor) target_link_libraries(conv2d_fwd_xdl_bias_relu PRIVATE host_tensor) diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index 23a163ad65..211c01c01a 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -17,15 +17,24 @@ void host_gemm_mk_kn_mn(const Tensor& a_m_k, auto f_mk_kn_mn = [&](auto m, auto n) { const int K = a_m_k.mDesc.GetLengths()[1]; - double v = 0; + float v_acc = 0; for(int k = 0; k < K; ++k) { - v += static_cast(a_element_op(a_m_k(m, k))) * - static_cast(b_element_op(b_k_n(k, n))); + float v_a; + float v_b; + + a_element_op(v_a, static_cast(a_m_k(m, k))); + b_element_op(v_b, static_cast(b_k_n(k, n))); + + v_acc += v_a * v_b; } - c_m_n(m, n) = c_element_op(v); + float v_c; + + c_element_op(v_c, v_acc); + + c_m_n(m, n) = v_c; }; make_ParallelTensorFunctor(f_mk_kn_mn, diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index 7de9e1a378..71e795b4d4 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -3,6 +3,7 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/host/host_tensor/include ${PROJECT_SOURCE_DIR}/device/include ${PROJECT_SOURCE_DIR}/device_operation/include + ${PROJECT_SOURCE_DIR}/reference_operation/include ${PROJECT_SOURCE_DIR}/profiler/include ${PROJECT_SOURCE_DIR}/composable_kernel/include ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility @@ -12,87 +13,24 @@ include_directories(BEFORE ${PROJECT_SOURCE_DIR}/external/rocm/include ) -# device_gemm_instance -set(DEVICE_GEMM_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp; -) - -add_library(device_gemm_instance SHARED ${DEVICE_GEMM_INSTANCE_SOURCE}) -target_include_directories(device_gemm_instance SYSTEM PUBLIC $) -target_compile_features(device_gemm_instance PUBLIC) -set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) - -# device_conv2d_fwd_instance -set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp; - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_conv2d_fwd_instance SHARED ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE}) -target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $) -target_compile_features(device_conv2d_fwd_instance PUBLIC) -set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib) - -# device_conv2d_fwd_bias_relu_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_conv2d_fwd_bias_relu_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_INSTANCE_SOURCE}) -target_include_directories(device_conv2d_fwd_bias_relu_instance SYSTEM PUBLIC $) -target_compile_features(device_conv2d_fwd_bias_relu_instance PUBLIC) -set_target_properties(device_conv2d_fwd_bias_relu_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_instance LIBRARY DESTINATION lib) - -# device_conv2d_fwd_bias_relu_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_conv2d_fwd_bias_relu_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ADD_INSTANCE_SOURCE}) -target_include_directories(device_conv2d_fwd_bias_relu_add_instance SYSTEM PUBLIC $) -target_compile_features(device_conv2d_fwd_bias_relu_add_instance PUBLIC) -set_target_properties(device_conv2d_fwd_bias_relu_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_add_instance LIBRARY DESTINATION lib) - -# device_conv2d_fwd_bias_relu_atomic_add_instance -set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE - ${PROJECT_SOURCE_DIR}/device_operation/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp; -) - -add_library(device_conv2d_fwd_bias_relu_atomic_add_instance SHARED ${DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE}) -target_include_directories(device_conv2d_fwd_bias_relu_atomic_add_instance SYSTEM PUBLIC $) -target_compile_features(device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC) -set_target_properties(device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) -install(TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib) - # ck_profiler set(PROFILER_SOURCE - profiler.cpp - profile_gemm.cpp - profile_conv_fwd.cpp - profile_conv_fwd_bias_relu.cpp - profile_conv_fwd_bias_relu_add.cpp - profile_conv_fwd_bias_relu_atomic_add.cpp - ) + src/profiler.cpp + src/profile_gemm.cpp + src/profile_gemm_bias_relu.cpp + src/profile_gemm_bias_relu_add.cpp + src/profile_conv_fwd.cpp + src/profile_conv_fwd_bias_relu.cpp + src/profile_conv_fwd_bias_relu_add.cpp + src/profile_conv_fwd_bias_relu_atomic_add.cpp +) + add_executable(ckProfiler ${PROFILER_SOURCE}) target_link_libraries(ckProfiler PRIVATE host_tensor) target_link_libraries(ckProfiler PRIVATE device_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance) +target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance) target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance) diff --git a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp index d665321879..286323c629 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_add_impl.hpp @@ -3,11 +3,11 @@ #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_conv.hpp" #include "tensor_layout.hpp" #include "device_tensor.hpp" -#include "device_conv_fwd_bias_activation_add.hpp" #include "element_wise_operation.hpp" +#include "device_conv_fwd_bias_activation_add.hpp" +#include "reference_conv_fwd_bias_activation_add.hpp" namespace ck { namespace tensor_operation { @@ -30,56 +30,6 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instan namespace ck { namespace profiler { -template -void host_reference_calculation(const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const Tensor& bias_k, - const Tensor& resi_n_k_ho_wo, - const std::vector& conv_strides, - const std::vector& conv_dilations, - const std::vector& in_left_pads, - const std::vector& /* in_right_pads */, - const InElementOp& in_element_op, - const WeiElementOp& wei_element_op, - const OutElementOp& out_element_op) -{ - auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - double v = 0; - for(int c = 0; c < wei_k_c_y_x.mDesc.GetLengths()[1]; ++c) - { - for(int y = 0; y < wei_k_c_y_x.mDesc.GetLengths()[2]; ++y) - { - int hi = ho * conv_strides[0] + y * conv_dilations[0] - in_left_pads[0]; - for(int x = 0; x < wei_k_c_y_x.mDesc.GetLengths()[3]; ++x) - { - int wi = wo * conv_strides[1] + x * conv_dilations[1] - in_left_pads[1]; - if(hi >= 0 && hi < in_n_c_hi_wi.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_n_c_hi_wi.mDesc.GetLengths()[3]) - { - v += in_element_op(static_cast(in_n_c_hi_wi(n, c, hi, wi))) * - wei_element_op(static_cast(wei_k_c_y_x(k, c, y, x))); - } - } - } - } - - out_n_k_ho_wo(n, k, ho, wo) = out_element_op(v, bias_k(k), resi_n_k_ho_wo(n, k, ho, wo)); - }; - - make_ParallelTensorFunctor(f_nchw, - out_n_k_ho_wo.mDesc.GetLengths()[0], - out_n_k_ho_wo.mDesc.GetLengths()[1], - out_n_k_ho_wo.mDesc.GetLengths()[2], - out_n_k_ho_wo.mDesc.GetLengths()[3])( - std::thread::hardware_concurrency()); -} - template ; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + resi_n_k_ho_wo, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); @@ -240,9 +207,9 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification, conv_filter_dilations, input_left_pads, input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + in_element_op, + wei_element_op, + out_element_op); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp index 955861dcf8..cd68f992e9 100644 --- a/profiler/include/profile_conv_fwd_bias_relu_impl.hpp +++ b/profiler/include/profile_conv_fwd_bias_relu_impl.hpp @@ -3,11 +3,11 @@ #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_conv.hpp" #include "tensor_layout.hpp" #include "device_tensor.hpp" -#include "device_conv_fwd_bias_activation.hpp" #include "element_wise_operation.hpp" +#include "device_conv_fwd_bias_activation.hpp" +#include "reference_conv_fwd_bias_activation.hpp" namespace ck { namespace tensor_operation { @@ -30,84 +30,6 @@ void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances( namespace ck { namespace profiler { -void cpu_conv_bias_relu(ck::half_t* in_ptr, - ck::half_t* weight_ptr, - ck::half_t* output_ptr, - ck::half_t* bias_ptr, - const ck::index_t N, - const ck::index_t K, - const ck::index_t C, - const ck::index_t Y, - const ck::index_t X, - const ck::index_t Hi, - const ck::index_t Wi, - const ck::index_t Ho, - const ck::index_t Wo, - const ck::index_t Stride, - const ck::index_t Dilation, - const ck::index_t Pad) -{ - - const auto in_desc = - HostTensorDescriptor(std::vector{static_cast(N), - static_cast(Hi), - static_cast(Wi), - static_cast(C)}); - const auto wei_desc = - HostTensorDescriptor(std::vector{static_cast(K), - static_cast(Y), - static_cast(X), - static_cast(C)}); - const auto out_desc = - HostTensorDescriptor(std::vector{static_cast(N), - static_cast(Ho), - static_cast(Wo), - static_cast(K)}); - const auto bias_desc = - HostTensorDescriptor(std::vector{static_cast(K)}); - - auto f_k = [&](auto k) { - for(int n = 0; n < N; ++n) - { - for(int ho = 0; ho < Ho; ++ho) - { - for(int wo = 0; wo < Wo; ++wo) - { - double v = 0; - for(int c = 0; c < C; ++c) - { - for(int y = 0; y < Y; ++y) - { - int hi = ho * Stride + y * Dilation - Pad; - for(int x = 0; x < X; ++x) - { - int wi = wo * Stride + x * Dilation - Pad; - if(hi >= 0 && hi < Hi && wi >= 0 && wi < Wi) - { - double in = - in_ptr[in_desc.GetOffsetFromMultiIndex(n, hi, wi, c)]; - double wei = - weight_ptr[wei_desc.GetOffsetFromMultiIndex(k, y, x, c)]; - - v += in * wei; - } - } - } - } - - v += bias_ptr[bias_desc.GetOffsetFromMultiIndex(k)]; - - v = v > 0 ? v : 0; - - output_ptr[out_desc.GetOffsetFromMultiIndex(n, ho, wo, k)] = v; - } - } - } - }; - - make_ParallelTensorFunctor(f_k, K)(std::thread::hardware_concurrency()); -} - template ; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + bias_k, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + ref_invoker.Run(ref_argument); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); @@ -263,9 +196,9 @@ void profile_conv_fwd_bias_relu_impl(int do_verification, conv_filter_dilations, input_left_pads, input_right_pads, - InElementOp{}, - WeiElementOp{}, - OutElementOp{}); + in_element_op, + wei_element_op, + out_element_op); auto invoker_ptr = op_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profile_conv_fwd_impl.hpp b/profiler/include/profile_conv_fwd_impl.hpp index 6e79bf4b4a..1eac6218d2 100644 --- a/profiler/include/profile_conv_fwd_impl.hpp +++ b/profiler/include/profile_conv_fwd_impl.hpp @@ -3,11 +3,11 @@ #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_conv.hpp" #include "tensor_layout.hpp" #include "device_tensor.hpp" #include "device_conv_fwd.hpp" #include "element_wise_operation.hpp" +#include "reference_conv_fwd.hpp" namespace ck { namespace tensor_operation { @@ -105,15 +105,37 @@ void profile_conv_fwd_impl(int do_verification, wei_k_c_y_x.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } + using InElementOp = ck::tensor_operation::element_wise::PassThrough; + using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; + using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto in_element_op = InElementOp{}; + const auto wei_element_op = WeiElementOp{}; + const auto out_element_op = OutElementOp{}; + if(do_verification) { - host_conv_nchw_kcyx_nkhw(in_n_c_hi_wi, - wei_k_c_y_x, - out_n_k_ho_wo_host_result, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads); + using ReferenceConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd; + + auto ref_conv = ReferenceConvFwdInstance{}; + auto ref_invoker = ref_conv.MakeInvoker(); + auto ref_argument = ref_conv.MakeArgument(in_n_c_hi_wi, + wei_k_c_y_x, + out_n_k_ho_wo_host_result, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + + ref_invoker.Run(ref_argument); } DeviceMem in_device_buf(sizeof(InDataType) * in_n_c_hi_wi.mDesc.GetElementSpace()); @@ -177,9 +199,9 @@ void profile_conv_fwd_impl(int do_verification, conv_filter_dilations, input_left_pads, input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}); + in_element_op, + wei_element_op, + out_element_op); auto invoker_ptr = conv_ptr->MakeInvokerPointer(); diff --git a/profiler/include/profile_gemm_bias_relu_add_impl.hpp b/profiler/include/profile_gemm_bias_relu_add_impl.hpp new file mode 100644 index 0000000000..f6625a8b22 --- /dev/null +++ b/profiler/include/profile_gemm_bias_relu_add_impl.hpp @@ -0,0 +1,286 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_bias_activation_add.hpp" +#include "reference_gemm_bias_activation_add.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmBiasReluAddPtr = ck::tensor_operation::device::DeviceGemmBiasActivationAddPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddReluAdd>; + +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_relu_add_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int StrideC1, + int KBatch = 1) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + // c1_m_n[m ,n] + Tensor c1_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + std::cout << "c1_m_n: " << c1_m_n.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + c1_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + c1_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AddReluAdd; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivationAdd; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_m_k, + b_k_n, + c_m_n_host_result, + c0_n, + c1_m_n, + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + DeviceMem c1_m_n_device_buf(sizeof(CDataType) * c1_m_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + c1_m_n_device_buf.ToDevice(c1_m_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_mk_nk_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_kn_mn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_add_f16_f16_f16_km_nk_mn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + static_cast(c1_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideC1, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a: ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0: ", c0_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c1: ", c1_m_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_bias_relu_impl.hpp b/profiler/include/profile_gemm_bias_relu_impl.hpp new file mode 100644 index 0000000000..e403a88d58 --- /dev/null +++ b/profiler/include/profile_gemm_bias_relu_impl.hpp @@ -0,0 +1,264 @@ +#pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm_bias_activation.hpp" +#include "reference_gemm_bias_activation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmBiasReluPtr = ck::tensor_operation::device::DeviceGemmBiasActivationPtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::AddRelu>; + +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances( + std::vector&); +void add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +void profile_gemm_bias_relu_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch = 1) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(is_same::value) + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + // c0_n[n] + Tensor c0_n(HostTensorDescriptor( + std::vector({static_cast(N)}), std::vector({1}))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + std::cout << "c0_n: " << c0_n.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + c0_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + c0_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + } + + // set zero to c_device_buf + c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::AddRelu; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + if(do_verification) + { + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceGemmBiasActivation; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, c0_n, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem c0_n_device_buf(sizeof(CDataType) * c0_n.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + c_device_buf.ToDevice(c_m_n_device_result.mData.data()); + c0_n_device_buf.ToDevice(c0_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_bias_relu_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = gemm_ptr->MakeArgumentPointer( + static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(c0_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + KBatch); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string gemm_name = gemm_ptr->GetTypeString(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * M + + sizeof(CDataType) * M * N + sizeof(CDataType) * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + check_error(c_m_n_host_result, c_m_n_device_result); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c0 : ", c0_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/include/profile_gemm_impl.hpp b/profiler/include/profile_gemm_impl.hpp index 596770190b..9962c6579d 100644 --- a/profiler/include/profile_gemm_impl.hpp +++ b/profiler/include/profile_gemm_impl.hpp @@ -1,4 +1,14 @@ #pragma once +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "device_gemm.hpp" +#include "reference_gemm.hpp" namespace ck { namespace tensor_operation { @@ -15,6 +25,11 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector&); +void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector&); + void add_device_gemm_xdl_f32_f32_f32_mk_kn_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(std::vector&); void add_device_gemm_xdl_f32_f32_f32_km_kn_mn_instances(std::vector&); @@ -86,17 +101,30 @@ void profile_gemm_impl(int do_verification, a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } + // set zero to c_device_buf c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0{}, num_thread); + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + if(do_verification) { - host_gemm_mk_kn_mn(a_m_k, - b_k_n, - c_m_n_host_result, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}, - ck::tensor_operation::element_wise::PassThrough{}); + using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); } DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); @@ -184,6 +212,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -191,6 +222,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -198,6 +232,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemm_ptrs); } else if constexpr(is_same::value && is_same::value && @@ -205,6 +242,9 @@ void profile_gemm_impl(int do_verification, { ck::tensor_operation::device::device_gemm_instance:: add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); + + ck::tensor_operation::device::device_gemm_instance:: + add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemm_ptrs); } } @@ -283,8 +323,7 @@ void profile_gemm_impl(int do_verification, } else { - std::cout << "this device GEMM instance does not support this GEMM problem" - << std::endl; + std::cout << "does not support this GEMM problem" << std::endl; } } diff --git a/profiler/profile_conv_fwd.cpp b/profiler/src/profile_conv_fwd.cpp similarity index 100% rename from profiler/profile_conv_fwd.cpp rename to profiler/src/profile_conv_fwd.cpp diff --git a/profiler/profile_conv_fwd_bias_relu.cpp b/profiler/src/profile_conv_fwd_bias_relu.cpp similarity index 100% rename from profiler/profile_conv_fwd_bias_relu.cpp rename to profiler/src/profile_conv_fwd_bias_relu.cpp diff --git a/profiler/profile_conv_fwd_bias_relu_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_add.cpp similarity index 100% rename from profiler/profile_conv_fwd_bias_relu_add.cpp rename to profiler/src/profile_conv_fwd_bias_relu_add.cpp diff --git a/profiler/profile_conv_fwd_bias_relu_atomic_add.cpp b/profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp similarity index 100% rename from profiler/profile_conv_fwd_bias_relu_atomic_add.cpp rename to profiler/src/profile_conv_fwd_bias_relu_atomic_add.cpp diff --git a/profiler/profile_gemm.cpp b/profiler/src/profile_gemm.cpp similarity index 97% rename from profiler/profile_gemm.cpp rename to profiler/src/profile_gemm.cpp index 37d5b4f2ee..8e1c64ac01 100644 --- a/profiler/profile_gemm.cpp +++ b/profiler/src/profile_gemm.cpp @@ -4,15 +4,6 @@ #include #include #include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor.hpp" -#include "host_tensor_generator.hpp" -#include "host_gemm.hpp" -#include "device_tensor.hpp" -#include "device_base.hpp" -#include "device_gemm_xdl.hpp" #include "profile_gemm_impl.hpp" enum GemmMatrixLayout diff --git a/profiler/src/profile_gemm_bias_relu.cpp b/profiler/src/profile_gemm_bias_relu.cpp new file mode 100644 index 0000000000..a0c7832dc0 --- /dev/null +++ b/profiler/src/profile_gemm_bias_relu.cpp @@ -0,0 +1,148 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_relu_impl.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_relu(int argc, char* argv[]) +{ + if(!(argc == 14 || argc == 15)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + int KBatch = 1; + + if(argc == 15) + KBatch = std::stoi(argv[14]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profile_gemm_bias_relu_add.cpp b/profiler/src/profile_gemm_bias_relu_add.cpp new file mode 100644 index 0000000000..8d5e4e3f7f --- /dev/null +++ b/profiler/src/profile_gemm_bias_relu_add.cpp @@ -0,0 +1,153 @@ +#include +#include +#include +#include +#include +#include +#include "profile_gemm_bias_relu_add_impl.hpp" + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 +}; + +enum GemmDataType +{ + F32_F32_F32, // 0 + F16_F16_F16, // 1 +}; + +int profile_gemm_bias_relu_add(int argc, char* argv[]) +{ + if(!(argc == 15 || argc == 16)) + { + printf("arg1: tensor operation (gemm: GEMM+Bias+ReLU+Add)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, StrideC1\n"); + printf("arg15: split k into mulitiple batch\n"); + exit(1); + } + + const int data_type = static_cast(std::stoi(argv[2])); + const int layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + const int StrideC1 = std::stoi(argv[14]); + + int KBatch = 1; + + if(argc == 16) + KBatch = std::stoi(argv[15]); + + if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) + { + ck::profiler::profile_gemm_bias_relu_add_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + (StrideC1 < 0) ? N : StrideC1); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/profiler.cpp b/profiler/src/profiler.cpp similarity index 64% rename from profiler/profiler.cpp rename to profiler/src/profiler.cpp index a8d3322872..6855d5bdce 100644 --- a/profiler/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -6,6 +6,8 @@ #include int profile_gemm(int, char*[]); +int profile_gemm_bias_relu(int, char*[]); +int profile_gemm_bias_relu_add(int, char*[]); int profile_conv_fwd(int, char*[]); int profile_conv_fwd_bias_relu(int, char*[]); int profile_conv_fwd_bias_relu_add(int, char*[]); @@ -17,6 +19,14 @@ int main(int argc, char* argv[]) { return profile_gemm(argc, argv); } + if(strcmp(argv[1], "gemm_bias_relu") == 0) + { + return profile_gemm_bias_relu(argc, argv); + } + if(strcmp(argv[1], "gemm_bias_relu_add") == 0) + { + return profile_gemm_bias_relu_add(argc, argv); + } else if(strcmp(argv[1], "conv_fwd") == 0) { return profile_conv_fwd(argc, argv); @@ -35,12 +45,16 @@ int main(int argc, char* argv[]) } else { - printf("arg1: tensor operation (gemm: GEMM;\n" - " conv_fwd: ForwardConvolution;\n" - " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU)\n" - " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add)\n" - " conv_fwd_bias_relu_atomic_add: " - "ForwardConvolution+Bias+ReLU+AtomicAdd)\n"); + // clang-format off + printf("arg1: tensor operation (gemm: GEMM\n" + " gemm_bias_relu: GEMM+Bias+ReLU\n" + " gemm_bias_relu_add: GEMM+Bias+ReLU+Add\n" + " conv_fwd: ForwardConvolution\n" + " conv_fwd_bias_relu: ForwardConvolution+Bias+ReLU\n" + " conv_fwd_bias_relu_add: ForwardConvolution+Bias+ReLU+Add\n" + " conv_fwd_bias_relu_atomic_add: ForwardConvolution+Bias+ReLU+AtomicAdd\n"); + // clang-format on + return 0; } } diff --git a/host/include/reference_conv_fwd.hpp b/reference_operation/include/reference_conv_fwd.hpp similarity index 89% rename from host/include/reference_conv_fwd.hpp rename to reference_operation/include/reference_conv_fwd.hpp index a92ed95b3c..f929f3cda5 100644 --- a/host/include/reference_conv_fwd.hpp +++ b/reference_operation/include/reference_conv_fwd.hpp @@ -14,7 +14,6 @@ namespace host { template @@ -68,7 +67,8 @@ struct ReferenceConvFwd : public device::BaseOperator float Run(const Argument& arg) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v = 0; + float v_acc = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) { for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) @@ -82,17 +82,26 @@ struct ReferenceConvFwd : public device::BaseOperator if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) { - v += arg.in_element_op_( - ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * - arg.wei_element_op_( - ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; } } } } - arg.out_n_k_ho_wo_(n, k, ho, wo) = - ck::type_convert(arg.out_element_op_(v)); + float v_out; + + arg.out_element_op_(v_out, v_acc); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; }; make_ParallelTensorFunctor(f_nchw, @@ -101,6 +110,7 @@ struct ReferenceConvFwd : public device::BaseOperator arg.out_n_k_ho_wo_.mDesc.GetLengths()[2], arg.out_n_k_ho_wo_.mDesc.GetLengths()[3])( std::thread::hardware_concurrency()); + return 0; } @@ -160,6 +170,7 @@ struct ReferenceConvFwd : public device::BaseOperator return str.str(); } }; + } // namespace host } // namespace tensor_operation } // namespace ck diff --git a/host/include/reference_conv_fwd_bias_activation.hpp b/reference_operation/include/reference_conv_fwd_bias_activation.hpp similarity index 89% rename from host/include/reference_conv_fwd_bias_activation.hpp rename to reference_operation/include/reference_conv_fwd_bias_activation.hpp index d65bba1a88..8f49b79a1a 100644 --- a/host/include/reference_conv_fwd_bias_activation.hpp +++ b/reference_operation/include/reference_conv_fwd_bias_activation.hpp @@ -15,7 +15,6 @@ namespace host { template @@ -72,7 +71,8 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator float Run(const Argument& arg) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v = 0; + float v_acc = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) { for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) @@ -86,17 +86,26 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) { - v += arg.in_element_op_( - ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * - arg.wei_element_op_( - ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; } } } } - arg.out_n_k_ho_wo_(n, k, ho, wo) = - ck::type_convert(arg.out_element_op_(v, arg.bias_k_(k))); + float v_out; + + arg.out_element_op_(v_out, v_acc, static_cast(arg.bias_k_(k))); + + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; }; make_ParallelTensorFunctor(f_nchw, @@ -166,6 +175,7 @@ struct ReferenceConvFwd_Bias_Activation : public device::BaseOperator return str.str(); } }; + } // namespace host } // namespace tensor_operation } // namespace ck diff --git a/host/include/reference_conv_fwd_bias_activation_add.hpp b/reference_operation/include/reference_conv_fwd_bias_activation_add.hpp similarity index 88% rename from host/include/reference_conv_fwd_bias_activation_add.hpp rename to reference_operation/include/reference_conv_fwd_bias_activation_add.hpp index eb4b708c12..e4e0899416 100644 --- a/host/include/reference_conv_fwd_bias_activation_add.hpp +++ b/reference_operation/include/reference_conv_fwd_bias_activation_add.hpp @@ -15,7 +15,6 @@ namespace host { template @@ -75,7 +74,8 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator float Run(const Argument& arg) { auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { - float v = 0; + float v_acc = 0; + for(int c = 0; c < arg.wei_k_c_y_x_.mDesc.GetLengths()[1]; ++c) { for(int y = 0; y < arg.wei_k_c_y_x_.mDesc.GetLengths()[2]; ++y) @@ -89,23 +89,29 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator if(hi >= 0 && hi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[2] && wi >= 0 && wi < arg.in_n_c_hi_wi_.mDesc.GetLengths()[3]) { - v += arg.in_element_op_( - ck::type_convert(arg.in_n_c_hi_wi_(n, c, hi, wi))) * - arg.wei_element_op_( - ck::type_convert(arg.wei_k_c_y_x_(k, c, y, x))); + float v_in; + float v_wei; + + arg.in_element_op_( + v_in, + static_cast(arg.in_n_c_hi_wi_(n, c, hi, wi))); + arg.wei_element_op_( + v_wei, static_cast(arg.wei_k_c_y_x_(k, c, y, x))); + + v_acc += v_in * v_wei; } } } } - float v2 = ck::type_convert(arg.out_n_k_ho_wo_(n, k, ho, wo)); + float v_out; - arg.out_element_op_(v2, - v, - ck::type_convert(arg.bias_k_(k)), - ck::type_convert(arg.resi_n_k_ho_wo_(n, k, ho, wo))); + arg.out_element_op_(v_out, + v_acc, + static_cast(arg.bias_k_(k)), + static_cast(arg.resi_n_k_ho_wo_(n, k, ho, wo))); - arg.out_n_k_ho_wo_(n, k, ho, wo) = ck::type_convert(v2); + arg.out_n_k_ho_wo_(n, k, ho, wo) = v_out; }; make_ParallelTensorFunctor(f_nchw, @@ -177,6 +183,7 @@ struct ReferenceConvFwd_Bias_Activation_Add : public device::BaseOperator return str.str(); } }; + } // namespace host } // namespace tensor_operation } // namespace ck diff --git a/reference_operation/include/reference_gemm.hpp b/reference_operation/include/reference_gemm.hpp new file mode 100644 index 0000000000..3601fafc28 --- /dev/null +++ b/reference_operation/include/reference_gemm.hpp @@ -0,0 +1,132 @@ +#ifndef REFERENCE_GEMM_HPP +#define REFERENCE_GEMM_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemm : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemm::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemm" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/reference_operation/include/reference_gemm_bias_activation.hpp b/reference_operation/include/reference_gemm_bias_activation.hpp new file mode 100644 index 0000000000..7c9df272c2 --- /dev/null +++ b/reference_operation/include/reference_gemm_bias_activation.hpp @@ -0,0 +1,136 @@ +#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_HPP +#define REFERENCE_GEMM_BIAS_ACTIVATION_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivation : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivation::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, v_acc, static_cast(arg.c0_n_(n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{a_m_k, b_k_n, c_m_n, c0_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivation" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/reference_operation/include/reference_gemm_bias_activation_add.hpp b/reference_operation/include/reference_gemm_bias_activation_add.hpp new file mode 100644 index 0000000000..4d3c5effae --- /dev/null +++ b/reference_operation/include/reference_gemm_bias_activation_add.hpp @@ -0,0 +1,144 @@ +#ifndef REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP +#define REFERENCE_GEMM_BIAS_ACTIVATION_ADD_HPP + +#include +#include +#include "device_base.hpp" +#include "host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceGemmBiasActivationAdd : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : a_m_k_{a_m_k}, + b_k_n_{b_k_n}, + c_m_n_{c_m_n}, + c0_n_{c0_n}, + c1_m_n_{c1_m_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& a_m_k_; + const Tensor& b_k_n_; + Tensor& c_m_n_; + const Tensor& c0_n_; + const Tensor& c1_m_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceGemmBiasActivationAdd::Argument; + + float Run(const Argument& arg) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_m_k_.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + arg.a_element_op_(v_a, static_cast(arg.a_m_k_(m, k))); + arg.b_element_op_(v_b, static_cast(arg.b_k_n_(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + arg.c_element_op_(v_c, + v_acc, + static_cast(arg.c0_n_(n)), + static_cast(arg.c1_m_n_(m, n))); + + arg.c_m_n_(m, n) = v_c; + }; + + make_ParallelTensorFunctor( + f_mk_kn_mn, arg.c_m_n_.mDesc.GetLengths()[0], arg.c_m_n_.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, int) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const Tensor& c0_n, + const Tensor& c1_m_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{ + a_m_k, b_k_n, c_m_n, c0_n, c1_m_n, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "ReferenceGemmBiasActivationAdd" + << std::endl; + // clang-format on + + return str.str(); + } +}; + +} // namespace host +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/script/conv2d_fwd.sh b/script/conv2d_fwd.sh new file mode 100755 index 0000000000..acc91e194f --- /dev/null +++ b/script/conv2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ + $DRIVER $VERIFY $INIT $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 256 64 1 1 1 1 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT 256 64 3 7 7 224 224 2 2 1 1 3 3 3 3 + + N=$5 + +# Resnet50 +######## verify init repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 256 1 1 56 56 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 16 16 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $VERIFY $INIT $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE diff --git a/script/gemm.sh b/script/gemm.sh new file mode 100755 index 0000000000..395db86d09 --- /dev/null +++ b/script/gemm.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +######## verify init repeat M___ N___ K___ StrideA StrideB StrideC StrideC1 +#$DRIVER $VERIFY $INIT $REPEAT 256 256 256 256 256 256 256 +#$DRIVER $VERIFY $INIT $REPEAT 960 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 1920 2048 2048 2048 2048 2048 2048 + $DRIVER $VERIFY $INIT $REPEAT 3840 4096 4096 4096 4096 4096 4096 +#$DRIVER $VERIFY $INIT $REPEAT 7680 8192 8192 8192 8192 8192 8192 +#$DRIVER $VERIFY $INIT $REPEAT 1024 1024 1024 1024 1024 1024 1024 +#$DRIVER $VERIFY $INIT $REPEAT 2048 2048 2048 2048 2048 2048 2048 diff --git a/script/pool2d_fwd.sh b/script/pool2d_fwd.sh new file mode 100755 index 0000000000..10acf5394e --- /dev/null +++ b/script/pool2d_fwd.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +## GPU visibility + export HIP_VISIBLE_DEVICES=0 + + make -j $1 + +DRIVER=example/$1 +VERIFY=$2 +INIT=$3 +REPEAT=$4 + +# test +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT 128 192 3 3 71 71 2 2 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT 128 64 1 1 1 1 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT 256 3 7 7 230 230 2 2 0 0 0 0 + $DRIVER $VERIFY $INIT $REPEAT 256 1024 14 14 14 14 1 1 0 0 0 0 + + N=$5 + +# Resnet50 +######## verify init repeat N__ C___ Y X Hi__ Wi__ Strides LeftPads RightPads +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 1024 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 28 28 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 128 3 3 58 58 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 2048 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 14 14 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 3 3 14 14 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 3 3 30 30 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 256 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 16 16 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 2 2 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 28 28 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 1 1 7 7 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 512 3 3 7 7 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 1 1 56 56 1 1 0 0 0 0 +#$DRIVER $VERIFY $INIT $REPEAT $N 64 3 3 56 56 1 1 1 1 1 1 +#$DRIVER $VERIFY $INIT $REPEAT $N 3 7 7 230 230 2 2 0 0 0 0 diff --git a/script/profile_conv.sh b/script/profile_conv.sh index 578b63e8db..f3a6d2c70c 100755 --- a/script/profile_conv.sh +++ b/script/profile_conv.sh @@ -19,11 +19,89 @@ REPEAT=$9 # test ######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ - $DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 28 28 2 2 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE + + N=${10} + +# Resnet50 from Bing +######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3 +# Resnet50 from Bing +#################### op____________________ datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 256 1 1 56 56 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 56 56 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 128 128 3 3 28 28 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 128 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 512 1 1 28 28 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 28 28 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 14 14 2 2 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#profiler/ckProfiler conv_fwd_bias_relu $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#profiler/ckProfiler conv_fwd_bias_relu_add $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 -#N=${10} # Resnet50 ######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ @@ -49,6 +127,7 @@ REPEAT=$9 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $DESIRED_GRID_SIZE #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $DESIRED_GRID_SIZE +#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 230 230 2 2 1 1 0 0 0 0 $DESIRED_GRID_SIZE # SSD ######## op datatype in_layout wei_layout out_layout verify init log repeat N__ K___ C___ Y X Hi__ Wi__ Strides Dilations LeftPads RightPads Desired_grid_size__ @@ -96,5 +175,3 @@ REPEAT=$9 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 512 3 3 10 10 1 1 1 1 1 1 1 1 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 510 256 3 3 5 5 1 1 1 1 1 1 1 1 #$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT 120 340 256 3 3 3 3 1 1 1 1 1 1 1 1 - -