From 1264925422920f24b3bb4fa34f178e31a23c97b5 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 18 Jul 2021 00:43:05 -0500 Subject: [PATCH] reorganize files to prepare for MIOpen integration (#51) * change olc cmake * adding online compile to fwd-v4r5r2 * update scripts * rename fwd-v4r5r2 to fwd-v6r1 * clean up --- CMakeLists.txt | 60 +- .../driver_dynamic_contraction_v1r1.hpp | 292 -------- .../driver_dynamic_contraction_v1r2.hpp | 264 +++---- .../driver/driver_dynamic_gemm_v1r1.hpp | 387 ---------- ...volution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp | 125 ---- ...volution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp | 0 ...lution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp | 0 ...volution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp | 0 ...volution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp | 0 ...lution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp | 0 ...lution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp | 129 ++++ ...lution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp | 0 ...olution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp} | 21 +- .../tensor_operation/blockwise_gemm_v2r2.hpp | 2 +- .../tensor_operation/blockwise_gemm_v2r3.hpp | 400 +++++----- .../gridwise_dynamic_contraction_v1r1.hpp | 681 ------------------ .../gridwise_dynamic_contraction_v1r2.hpp | 433 ++++++----- .../gridwise_dynamic_gemm_v1r1.hpp | 552 -------------- .../gridwise_dynamic_gemm_v1r3.hpp | 31 +- ...gemm_v2.hpp => threadwise_contraction.hpp} | 155 ++-- .../utility/{config.amd.hpp.in => config.hpp} | 0 .../{float_type.amd.hpp.in => float_type.hpp} | 0 ...ization.amd.hpp.in => synchronization.hpp} | 0 ...ward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp | 379 ---------- ...ward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp | 402 +++++++++++ driver/include/conv_tunables.hpp | 271 ------- ...ward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp | 520 ------------- ...ward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp | 240 ------ ...ward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp | 404 ----------- host/CMakeLists.txt | 4 + host/driver_offline/CMakeLists.txt | 21 + .../conv_bwd_driver_offline.cpp | 0 
.../conv_fwd_driver_offline.cpp | 87 +-- ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 0 ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 0 ...ward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 0 ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 283 ++++++++ ...rd_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp | 0 ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 0 ...icit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp | 240 ++++++ ...icit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp | 305 ++++++++ ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 0 ...ward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp | 0 ...ward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp | 162 ++--- host/driver_online/CMakeLists.txt | 21 + .../driver_online/conv_fwd_driver_online.cpp | 125 ++-- .../conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp | 50 ++ ...tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 73 ++ ...tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp | 73 ++ .../conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp | 42 ++ ...ward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 10 +- ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 10 +- ...plicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp | 10 +- ...ward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp | 425 +++++++++++ .../include/online_driver_common.hpp | 0 host/host_tensor/CMakeLists.txt | 19 + .../host_tensor}/include/conv_common.hpp | 0 .../host_tensor}/include/device.hpp | 44 +- .../host_tensor}/include/device_tensor.hpp | 0 .../host_tensor}/include/host_conv.hpp | 0 .../include/host_conv_bwd_data.hpp | 0 .../host_tensor}/include/host_tensor.hpp | 0 .../include/host_tensor_generator.hpp | 0 {driver => host/host_tensor}/src/device.cpp | 50 +- .../host_tensor}/src/host_tensor.cpp | 0 .../online_compilation}/CMakeLists.txt | 122 ++-- .../addkernels/CMakeLists.txt | 0 .../addkernels/addkernels.cpp | 0 .../addkernels/include_inliner.cpp | 0 .../addkernels/include_inliner.hpp | 0 .../addkernels/source_file_desc.hpp | 0 .../hip_utility/binary_cache.cpp | 0 .../hip_utility/exec_utils.cpp | 0 .../hip_utility/handlehip.cpp | 0 
.../hip_utility/hip_build_utils.cpp | 0 .../hip_utility/hipoc_kernel.cpp | 0 .../hip_utility/hipoc_program.cpp | 0 .../hip_utility/kernel_build_params.cpp | 0 .../hip_utility/kernel_cache.cpp | 0 .../hip_utility/logger.cpp | 0 .../online_compilation}/hip_utility/md5.cpp | 0 .../hip_utility/target_properties.cpp | 0 .../hip_utility/tmp_dir.cpp | 0 .../include/binary_cache.hpp | 0 .../online_compilation}/include/config.h.in | 0 .../online_compilation}/include/env.hpp | 0 .../include/exec_utils.hpp | 0 .../online_compilation}/include/handle.hpp | 0 .../online_compilation}/include/hipCheck.hpp | 0 .../include/hip_build_utils.hpp | 0 .../include/hipoc_kernel.hpp | 0 .../include/hipoc_program.hpp | 0 .../include/hipoc_program_impl.hpp | 0 .../online_compilation}/include/kernel.hpp | 0 .../include/kernel_build_params.hpp | 0 .../include/kernel_cache.hpp | 0 .../online_compilation}/include/logger.hpp | 0 .../include/manage_ptr.hpp | 0 .../online_compilation}/include/md5.hpp | 0 .../include/op_kernel_args.hpp | 0 .../include/simple_hash.hpp | 0 .../include/stringutils.hpp | 0 .../include/target_properties.hpp | 0 .../online_compilation}/include/tmp_dir.hpp | 0 .../include/write_file.hpp | 0 .../online_compilation}/kernel.cpp.in | 0 .../kernel_includes.cpp.in | 0 .../online_compilation}/kernels_batch.cpp.in | 0 script/cmake-rocm.sh | 3 +- script/run.sh | 47 +- 110 files changed, 2992 insertions(+), 4982 deletions(-) delete mode 100644 composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp delete mode 100644 composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp delete mode 100644 composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp rename composable_kernel/include/{kernel_algorithm => problem_transform}/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp (100%) rename composable_kernel/include/{kernel_algorithm => 
problem_transform}/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp (100%) rename composable_kernel/include/{kernel_algorithm => problem_transform}/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp (100%) rename composable_kernel/include/{kernel_algorithm => problem_transform}/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp (100%) rename composable_kernel/include/{kernel_algorithm => problem_transform}/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp (100%) create mode 100644 composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp rename composable_kernel/include/{kernel_algorithm => problem_transform}/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp (100%) rename composable_kernel/include/{kernel_algorithm/transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp => problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp} (90%) delete mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp delete mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp rename composable_kernel/include/tensor_operation/{threadwise_gemm_v2.hpp => threadwise_contraction.hpp} (53%) rename composable_kernel/include/utility/{config.amd.hpp.in => config.hpp} (100%) rename composable_kernel/include/utility/{float_type.amd.hpp.in => float_type.hpp} (100%) rename composable_kernel/include/utility/{synchronization.amd.hpp.in => synchronization.hpp} (100%) delete mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp create mode 100644 composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp delete mode 100644 driver/include/conv_tunables.hpp delete mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp 
delete mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp delete mode 100644 driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp create mode 100644 host/CMakeLists.txt create mode 100644 host/driver_offline/CMakeLists.txt rename driver/conv_bwd_data_driver_v2.cpp => host/driver_offline/conv_bwd_driver_offline.cpp (100%) rename driver/conv_driver_v2.cpp => host/driver_offline/conv_fwd_driver_offline.cpp (86%) rename {driver => host/driver_offline}/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp (100%) rename {driver => host/driver_offline}/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp (100%) rename {driver => host/driver_offline}/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (100%) create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp rename {driver => host/driver_offline}/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp (100%) rename {driver => host/driver_offline}/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp (100%) create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp rename {driver => host/driver_offline}/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp (100%) rename {driver => host/driver_offline}/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp (100%) rename driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp => 
host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp (62%) create mode 100644 host/driver_online/CMakeLists.txt rename driver/conv_driver_v2_olc.cpp => host/driver_online/conv_fwd_driver_online.cpp (80%) create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp rename driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp => host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (99%) rename driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp => host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (98%) rename driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp => host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp (98%) create mode 100644 host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp rename driver/include/olc_driver_common.hpp => host/driver_online/include/online_driver_common.hpp (100%) create mode 100644 host/host_tensor/CMakeLists.txt rename {driver => host/host_tensor}/include/conv_common.hpp (100%) rename {driver => host/host_tensor}/include/device.hpp (61%) rename {driver => host/host_tensor}/include/device_tensor.hpp (100%) rename {driver => host/host_tensor}/include/host_conv.hpp (100%) rename {driver => host/host_tensor}/include/host_conv_bwd_data.hpp (100%) rename {driver => host/host_tensor}/include/host_tensor.hpp (100%) rename {driver => 
host/host_tensor}/include/host_tensor_generator.hpp (100%) rename {driver => host/host_tensor}/src/device.cpp (50%) rename {driver => host/host_tensor}/src/host_tensor.cpp (100%) rename {driver => host/online_compilation}/CMakeLists.txt (50%) rename {driver/olCompiling => host/online_compilation}/addkernels/CMakeLists.txt (100%) rename {driver/olCompiling => host/online_compilation}/addkernels/addkernels.cpp (100%) rename {driver/olCompiling => host/online_compilation}/addkernels/include_inliner.cpp (100%) rename {driver/olCompiling => host/online_compilation}/addkernels/include_inliner.hpp (100%) rename {driver/olCompiling => host/online_compilation}/addkernels/source_file_desc.hpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/binary_cache.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/exec_utils.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/handlehip.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/hip_build_utils.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/hipoc_kernel.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/hipoc_program.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/kernel_build_params.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/kernel_cache.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/logger.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/md5.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/target_properties.cpp (100%) rename {driver/olCompiling => host/online_compilation}/hip_utility/tmp_dir.cpp (100%) rename {driver/olCompiling => host/online_compilation}/include/binary_cache.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/config.h.in (100%) rename {driver/olCompiling => 
host/online_compilation}/include/env.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/exec_utils.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/handle.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/hipCheck.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/hip_build_utils.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/hipoc_kernel.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/hipoc_program.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/hipoc_program_impl.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/kernel.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/kernel_build_params.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/kernel_cache.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/logger.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/manage_ptr.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/md5.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/op_kernel_args.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/simple_hash.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/stringutils.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/target_properties.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/tmp_dir.hpp (100%) rename {driver/olCompiling => host/online_compilation}/include/write_file.hpp (100%) rename {driver/olCompiling => host/online_compilation}/kernel.cpp.in (100%) rename {driver/olCompiling => host/online_compilation}/kernel_includes.cpp.in (100%) rename {driver/olCompiling => host/online_compilation}/kernels_batch.cpp.in (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 51d57d016f..0cf342bb45 
100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,14 +6,14 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") include(TargetFlags) include(AddKernels) -#c++ +## C++ enable_language(CXX) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -#OpenMP +## OpenMP if(CMAKE_CXX_COMPILER_ID MATCHES "Clang") # workaround issue hipcc in rocm3.5 cannot find openmp set(OpenMP_CXX "${CMAKE_CXX_COMPILER}") @@ -35,56 +35,8 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") link_libraries(${OpenMP_gomp_LIBRARY}) link_libraries(${OpenMP_pthread_LIBRARY}) -#GPU backend -if(DEVICE_BACKEND STREQUAL "AMD") - find_package(HIP REQUIRED) -endif() - -# -include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/composable_kernel/include - ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description - ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation - ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm - ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver - ${PROJECT_SOURCE_DIR}/external/half/include - ${PROJECT_SOURCE_DIR}/driver/include - ${PROJECT_BINARY_DIR}/composable_kernel/include/utility -) - -if(DEVICE_BACKEND STREQUAL "AMD") - include_directories(BEFORE - ${PROJECT_SOURCE_DIR}/external/rocm/include - ) -endif() - -if(DEVICE_BACKEND STREQUAL "AMD") - configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp") - configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp") - configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp") -endif() - 
-add_subdirectory(driver) - -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") - -message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}") - -if(DEVICE_BACKEND STREQUAL "AMD") - set(CONV_V2_SOURCE driver/conv_driver_v2.cpp) - set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp) - set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp) -endif() - -add_executable(conv_driver_v2 ${CONV_V2_SOURCE}) -add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE}) -add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE}) - -target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/) - -target_link_libraries(conv_driver_v2 PRIVATE modConv) -target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv) -target_link_libraries(conv_driver_v2_olc PRIVATE modConv) - +## HIP +find_package(HIP REQUIRED) +message(STATUS "Build with HIP ${hip_VERSION}") +add_subdirectory(host) diff --git a/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp b/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp deleted file mode 100644 index 0252f9487a..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp +++ /dev/null @@ -1,292 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP -#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_v1r1.hpp" - -namespace ck { - -template -__host__ float -driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid, - const FloatAB* p_b_grid, - FloatC* p_c_grid, - const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc, - const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc, - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = 
Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - // GEMM - using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - CGlobalMemoryDataOperation, - AGKGM0GM1GridDesc, - BGKGN0GN1GridDesc, - CGM0GM1GN0GN1GridDesc, - GM1PerBlockGM11, - GN1PerBlockGN11, - KPerBlock, - M1PerThread, - N1PerThread, - KPerThread, - M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - M1N1ThreadClusterM11, - M1N1ThreadClusterN11, - ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_GM11, - AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_GN11, - BThreadTransferSrcResetCoordinateAfterRun, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; - - const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0); - - if(!GridwiseContraction::CheckValidity( - a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc)) - { - throw std::runtime_error( - "wrong! 
GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting"); - } - - const auto a_gk_gm0_gm10_gm11_grid_desc = - GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc); - const auto b_gk_gn0_gn10_gn11_grid_desc = - GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc); - - using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc); - using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc); - - // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc - const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = - GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); - - using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); - - // c_blockid_to_gm10_gn10_block_cluster_adaptor - const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = - GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); - - using CBlockIdToGM10GN10BlockClusterAdaptor = - decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor); - - const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc); - - const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K); - - const bool has_double_tail_k_block_loop = - GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K); - - { - std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0) - << ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", " - << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", " - << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0) - << ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", " - << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", " - << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl; - - std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ " - 
<< c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl; - } - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_contraction_v1r1< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_contraction_v1r1< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_contraction_v1r1< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - 
a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); - } - else - { - const auto kernel = kernel_dynamic_contraction_v1r1< - GridwiseContraction, - FloatAB, - FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(grid_size), - dim3(BlockSize), - 0, - 0, - p_a_grid, - p_b_grid, - p_c_grid, - a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); - } - - return ave_time; -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp b/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp index 20d7c2ef89..2f68fec7e3 100644 --- a/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp +++ b/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp @@ -13,19 +13,19 @@ template {}; // GEMM - using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - CGlobalMemoryDataOperation, - AGKGM0GM1GridDesc, - BGKGN0GN1GridDesc, - CGM0GM1GN0GN1GridDesc, - GM1PerBlockGM11, - GN1PerBlockGN11, - KPerBlock, - M1PerThread, - N1PerThread, - KPerThread, - M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - M1N1ThreadClusterM11, - M1N1ThreadClusterN11, - ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - ABlockTransferSrcVectorTensorContiguousDimOrder, - BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - 
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - BBlockTransferSrcVectorTensorContiguousDimOrder, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; + using GridwiseContraction = + GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + CGlobalMemoryDataOperation, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM100, + BM10BN10ThreadClusterBN100, + BM10BN10ThreadClusterBM101, + BM10BN10ThreadClusterBN101, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + 
AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; - const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); if(!GridwiseContraction::CheckValidity( - a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc)) + a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1)) { - throw std::runtime_error( - "wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting"); + throw std::runtime_error("wrong! " + "GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_" + "GM0_GM1_GN0_GN1 has invalid setting"); } - const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = - GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc); - const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = - GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc); + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = + GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); - using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc); - using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc); + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1); - // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc - const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = - GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); + // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1); - using 
CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); - // c_blockid_to_gm10_gn10_block_cluster_adaptor - const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = - GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); + // c_grid_block_cluster_blockid_to_gm10_gn10 + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1); - using CBlockIdToGM10GN10BlockClusterAdaptor = - decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(c_grid_block_cluster_blockid_to_gm10_gn10); - const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc); + const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1); const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); @@ -151,41 +155,41 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); { - std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{" - << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", " - << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", " - << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", " - << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", " - << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; + std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{" + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", " + << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl; - std::cout << 
"b_gk0_gn0_gn10_gn11_gk1_grid_desc{" - << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", " - << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", " - << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", " - << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", " - << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; + std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{" + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", " + << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl; - std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", " - << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl; + std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", " + << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl; } float ave_time = 0; if(has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r1< + const auto kernel = kernel_dynamic_contraction_v1r2< GridwiseContraction, FloatAB, FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + 
remove_reference_t, + remove_reference_t, + remove_reference_t, true, true>; @@ -198,21 +202,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_gk0_gm0_gm10_gm11_gk1_grid_desc, - b_gk0_gn0_gn10_gn11_gk1_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); } else if(has_main_k_block_loop && !has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r1< + const auto kernel = kernel_dynamic_contraction_v1r2< GridwiseContraction, FloatAB, FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, true, false>; @@ -225,21 +229,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_gk0_gm0_gm10_gm11_gk1_grid_desc, - b_gk0_gn0_gn10_gn11_gk1_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); } else if(!has_main_k_block_loop && has_double_tail_k_block_loop) { - const auto kernel = kernel_dynamic_contraction_v1r1< + const auto kernel = kernel_dynamic_contraction_v1r2< GridwiseContraction, FloatAB, FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, false, true>; @@ -252,21 +256,21 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_gk0_gm0_gm10_gm11_gk1_grid_desc, - b_gk0_gn0_gn10_gn11_gk1_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - 
c_blockid_to_gm10_gn10_block_cluster_adaptor); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); } else { - const auto kernel = kernel_dynamic_contraction_v1r1< + const auto kernel = kernel_dynamic_contraction_v1r2< GridwiseContraction, FloatAB, FloatC, - remove_reference_t, - remove_reference_t, - remove_reference_t, - remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, false, false>; @@ -279,10 +283,10 @@ driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, p_a_grid, p_b_grid, p_c_grid, - a_gk0_gm0_gm10_gm11_gk1_grid_desc, - b_gk0_gn0_gn10_gn11_gk1_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10); } return ave_time; diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp deleted file mode 100644 index 1b52d368fe..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp +++ /dev/null @@ -1,387 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_V1 -#define CK_DRIVER_DYNAMIC_GEMM_V1 - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm_v1r1.hpp" - -namespace ck { - -template -__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global, - const FloatAB* p_b_global, - FloatC* p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - AGlobalIteratorHacks, - BGlobalIteratorHacks, - CGlobalIteratorHacks, - AGlobalMoveSliceWindowIteratorHacks, - 
BGlobalMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto M = a_k_m_global_desc.GetLength(I1); - const auto N = b_k_n_global_desc.GetLength(I1); - const auto K = a_k_m_global_desc.GetLength(I0); - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - // GEMM - using gridwise_gemm = - GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1; - - const auto GridSize = (M / MPerBlock) * (N / NPerBlock); - - const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(!has_main_k_block_loop && 
has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); - DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); - DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); - DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); - - a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); - b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); - c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); - c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void 
__CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = kernel_dynamic_gemm_v1r1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void 
__CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp deleted file mode 100644 index ff2d4254c6..0000000000 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,125 +0,0 @@ -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -__host__ __device__ constexpr auto -transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( - const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, - const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, - const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); - const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); - const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); - - const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); - const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); - - const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); - const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); - - const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); - const auto X = 
wei_k_c_y_x_grid_desc.GetLength(I3); - - const auto ConvStrideH = conv_strides[I0]; - const auto ConvStrideW = conv_strides[I1]; - - const auto ConvDilationH = conv_dilations[I0]; - const auto ConvDilationW = conv_dilations[I1]; - - const auto InLeftPadH = in_left_pads[I0]; - const auto InLeftPadW = in_left_pads[I1]; - - const auto InRightPadH = in_right_pads[I0]; - const auto InRightPadW = in_right_pads[I1]; - - // weight tensor - const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor( - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), - make_tuple(make_unmerge_transform(make_tuple(I1, K)), - make_pass_through_transform(C * Y * X)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1, 2>{}, Sequence<0>{})); - - // input tensor - const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( - in_n_c_hi_wi_grid_desc, - make_tuple(make_pass_through_transform(N), - make_pass_through_transform(C), - make_pad_transform(Hi, InLeftPadH, InRightPadH), - make_pad_transform(Wi, InLeftPadW, InRightPadW)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - constexpr auto N0 = Number{}; - const auto N1 = N / N0; - - const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( - in_n_c_hip_wip_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(N0, N1)), - make_pass_through_transform(C), - make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), - make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); - - const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( - in_n0_n1_c_y_ho_x_wo_grid_desc, - make_tuple(make_merge_transform(make_tuple(C, Y, X)), - 
make_pass_through_transform(N0), - make_merge_transform(make_tuple(N1, Ho, Wo))), - make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // output tensor - const auto out_n_k_howo_grid_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); - - const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( - out_n_k_howo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(Number{}, N1)), - make_unmerge_transform(make_tuple(I1, K)), - make_pass_through_transform(Ho * Wo)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); - - const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( - out_n0_n1_1_k_howo_grid_desc, - make_tuple(make_pass_through_transform(I1), - make_pass_through_transform(K), - make_pass_through_transform(Number{}), - make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), - make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - return make_tuple( - wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp rename to composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp 
b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp rename to composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp rename to composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp rename to composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp 
b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..5814e66766 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const 
auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + 
transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp similarity index 100% rename from composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp rename to composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp similarity index 90% rename from composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp rename to 
composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp index a3f0a3268f..957ca02723 100644 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/problem_transform/transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -1,5 +1,5 @@ -#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP -#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V6R1_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" #include "dynamic_tensor_descriptor.hpp" @@ -17,10 +17,10 @@ template + typename N0Type, + typename C0Type> __host__ __device__ constexpr auto -transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( +transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, @@ -28,8 +28,8 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - Number, - Number) + const N0Type& N0, + const C0Type& C0) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -61,9 +61,6 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( const auto InRightPadH = in_right_pads[I0]; const auto InRightPadW = in_right_pads[I1]; - constexpr auto N0 = Number{}; - constexpr auto C0 = Number{}; - const auto N1 = N / N0; const auto C1 = C / C0; @@ -109,7 +106,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( 
out_n_k_howo_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(Number{}, N1)), + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), make_unmerge_transform(make_tuple(I1, K)), make_pass_through_transform(Ho * Wo)), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), @@ -119,7 +116,7 @@ transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( out_n0_n1_1_k_howo_grid_desc, make_tuple(make_pass_through_transform(I1), make_pass_through_transform(K), - make_pass_through_transform(Number{}), + make_pass_through_transform(N0), make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp index 97fbc0bbaf..89cf53abce 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp @@ -4,7 +4,7 @@ #include "common_header.hpp" #include "tensor_adaptor.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_gemm_v2.hpp" +#include "threadwise_contraction.hpp" namespace ck { diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp index c57134762b..e3ba21494a 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp @@ -4,43 +4,43 @@ #include "common_header.hpp" #include "tensor_adaptor.hpp" #include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" -#include "threadwise_gemm_v2.hpp" +#include "threadwise_contraction.hpp" namespace ck { -// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) 
* B[K, BN0, BN1] // A and B are visable to the whole block, C is distributed among each thread // Assume: // 1. A: -// 1. AK0MK1BlockDesc is known at compile-time +// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time // 2. ABlockBuffer is DynamicBuffer // 2. B: -// 1. BK0NK1BlockDesc is known at compile-time +// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time // 2. BBlockBuffer is DynamicBuffer // 3. C: -// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time // 2. CThreadBuffer is StaticBuffer // Also assume: -// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) template ::type = false> -struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2 +struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 { using AIndex = MultiIndex<3>; using BIndex = MultiIndex<3>; @@ -51,138 +51,144 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2 static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; - static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0); - static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2); - static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1); - static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); - static constexpr index_t M100 = M1N1ThreadClusterM100; - static constexpr index_t N100 = M1N1ThreadClusterN100; + static constexpr index_t BM100 = BM10BN10ThreadClusterBM100; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN100; - static constexpr index_t M101 = M1N1ThreadClusterM101; - static constexpr index_t N101 = 
M1N1ThreadClusterN101; + static constexpr index_t BM101 = BM10BN10ThreadClusterBM101; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN101; - static constexpr index_t M11 = M1PerThreadM11; - static constexpr index_t N11 = N1PerThreadN11; + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; - static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; - static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + static constexpr index_t BM1 = + BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11; + static constexpr index_t BN1 = + BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11; - static constexpr index_t M0 = M / M1; - static constexpr index_t N0 = N / N1; + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; __host__ __device__ static constexpr auto - MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc) + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) { - const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), + const auto a_block_bk0_bm0_bm1_bk1 = transform_dynamic_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return a_k0_m0_m1_k1_block_desc; + return a_block_bk0_bm0_bm1_bk1; } __host__ __device__ static constexpr auto - MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc) + 
MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) { - const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{})), + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_dynamic_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - return b_k0_n0_n1_k1_block_desc; + return b_block_desc_bk0_bn0_bn1_bk1; } - __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() { - // upper: [M0, M100, M101, M11, N0, N100, N101, N11] - // lower: [M, N] - constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = make_single_stage_tensor_adaptor( make_tuple(make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{})), + Number{}, Number{}, Number{}, Number{})), make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}, Number{}))), + Number{}, Number{}, Number{}, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); - return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; } __host__ __device__ static constexpr auto - MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + 
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() { - // upper: [M0, M100, M101, M11, N0, N100, N101, N11] - // lower: [M0, M1, N0, N1] - constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = make_single_stage_tensor_adaptor( - make_tuple(make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{}), + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Number{}, Number{}, Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); - return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; } - __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() { - return Sequence{}; + return Sequence{}; } - static constexpr auto a_k0_m0_m1_k1_block_desc_ = - MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{}); - static constexpr auto b_k0_n0_n1_k1_block_desc_ = - MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{}); + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); public: - __device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2() - : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + 
__device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( get_thread_local_1d_id())}, a_thread_copy_{ make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, b_thread_copy_{ make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} { - static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && - BK0NK1BlockDesc::IsKnownAtCompileTime(), + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(BlockSize == M101 * M100 * N101 * N100, + static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, "wrong! blocksize and cluster size not consistent"); - static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); - static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), "wrong! 
K dimension not consistent"); // TODO: remove this restriction - static_assert(M0 == 2 && N0 == 2, "wrong"); + static_assert(BM0 == 2 && BN0 == 2, "wrong"); } - __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) { - // lower: [M0, M1, N0, N1] - // upper: [M0, M100, M101, M11, N0, N100, N101, N11] - constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); - // lower: [M0, M100, M101, M11, N0, N100, N101, N11] - // upper: [Tid, M0, M11, N0, N11] + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] constexpr auto adaptor1 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), - make_pass_through_transform(M0), - make_pass_through_transform(M11), - make_pass_through_transform(N0), - make_pass_through_transform(N11)), + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), make_tuple( Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); @@ -192,201 +198,203 @@ struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2 return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); } - template - __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc, const ABlockBuffer& a_block_buf, const BBlockBuffer& 
b_block_buf, CThreadBuffer& c_thread_buf) const { - static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); // TODO: remove this restriction - static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && - CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + static_assert(BM0 == 2 && BN0 == 2 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, "wrong"); auto a_thread_buf = make_static_buffer( - a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize()); + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); auto b_thread_buf = make_static_buffer( - b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize()); + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); - constexpr auto threadwise_gemm = - ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1, - Sequence<1, M1PerThreadM11>, - Sequence<1, N1PerThreadN11>>{}; + constexpr auto threadwise_contraction = + ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>>{}; // read A_sub_0 - a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I0, I0, I0), a_block_buf, - a_k0_m0_m1_k1_thread_desc_, + a_thread_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I0, I0, I0), a_thread_buf); // read B_sub_0 - b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, make_tuple(I0, I0, I0, I0), b_block_buf, - b_k0_n0_n1_k1_thread_desc_, + b_thread_desc_bk0_bn0_bn1_bk1_, make_tuple(I0, I0, I0, I0), b_thread_buf); // read B_sub_1 - b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, 
make_tuple(I0, I1, I0, I0), b_block_buf, - b_k0_n0_n1_k1_thread_desc_, + b_thread_desc_bk0_bn0_bn1_bk1_, make_tuple(I0, I1, I0, I0), b_thread_buf); // read A_sub_1 - a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I1, I0, I0), a_block_buf, - a_k0_m0_m1_k1_thread_desc_, + a_thread_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I1, I0, I0), a_thread_buf); // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I0, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I1, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); - // loop over rest of k - static_for{}([&](auto k) { + // loop over rest of bk0 + static_for{}([&](auto bk0) { // read A_sub_0 - a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, - make_tuple(k, I0, I0, I0), + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I0, I0, I0), a_block_buf, - a_k0_m0_m1_k1_thread_desc_, + a_thread_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I0, I0, I0), a_thread_buf); // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(I1, I0, I0, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); // read B_sub_0 - b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, - 
make_tuple(k, I0, I0, I0), + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I0, I0, I0), b_block_buf, - b_k0_n0_n1_k1_thread_desc_, + b_thread_desc_bk0_bn0_bn1_bk1_, make_tuple(I0, I0, I0, I0), b_thread_buf); // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0, I0), - c_thread_buf, - make_tuple(I1, I0, I1, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); // read B_sub_1 - b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, - make_tuple(k, I1, I0, I0), + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I1, I0, I0), b_block_buf, - b_k0_n0_n1_k1_thread_desc_, + b_thread_desc_bk0_bn0_bn1_bk1_, make_tuple(I0, I1, I0, I0), b_thread_buf); // read A_sub_1 - a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, - make_tuple(k, I1, I0, I0), + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I1, I0, I0), a_block_buf, - a_k0_m0_m1_k1_thread_desc_, + a_thread_desc_bk0_bm0_bm1_bk1_, make_tuple(I0, I1, I0, I0), a_thread_buf); // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I0, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I0, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0, I0), - c_thread_buf, - make_tuple(I0, I0, I1, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); }); // C_sub_10 += transpose(A_sub_1) * B_sub_0 - 
threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0, I0), - b_thread_buf, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - make_tuple(I1, I0, I0, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(a_thread_buf, - make_tuple(I0, I1, I0, I0), - b_thread_buf, - make_tuple(I0, I1, I0, I0), - c_thread_buf, - make_tuple(I1, I0, I1, I0)); + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); } private: - // A[K0, M0, M1, K1] - static constexpr auto a_k0_m0_m1_k1_thread_desc_ = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{}, Number{})); + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number{}, Number{}, Number{})); - // B[K0, N0, N1, K1] - static constexpr auto b_k0_n0_n1_k1_thread_desc_ = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(Number{}, Number{}, Number{}, Number{})); + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number{}, Number{}, Number{})); using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< FloatA, FloatA, - decltype(a_k0_m0_m1_k1_block_desc_), - decltype(a_k0_m0_m1_k1_thread_desc_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThreadBM11, BK1>, // 
SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< FloatB, FloatB, - decltype(b_k0_n0_n1_k1_block_desc_), - decltype(b_k0_n0_n1_k1_thread_desc_), - Sequence, // SliceLengths - Sequence<0, 1, 2, 3>, // DimAccessOrder - Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths - Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder CIndex c_thread_origin_data_idx_; diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp deleted file mode 100644 index 915a8e28d4..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp +++ /dev/null @@ -1,681 +0,0 @@ -#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP -#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP - -#include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2r2.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" - -namespace ck { - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_contraction_v1r1( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc, - const BGKGN0GN10GN11GridDesc 
b_gk_gn0_gn10_gn11_grid_desc, - const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - const CBlockIdToGM10GN10BlockClusterAdaptor - c_blockid_to_gm10_gn10_block_cluster_adaptor) -{ - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -} - -template -struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1 -{ - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - // GM0 and GN0 need to known at compile-time - static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0); - static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2); - - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_gk_gm0_gm10_gm11_block_desc = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_gk_gn0_gn10_gn11_block_desc = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto 
b_block_aligned_space_size = math::integer_least_multiple( - b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); - } - - __host__ __device__ static constexpr bool - CheckValidity(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc, - const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc, - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) - { - static_assert(is_known_at_compile_time>::value && - is_known_at_compile_time>::value, - "wrong! GM0 and GN0 need to be known at compile-time"); - - const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2); - const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2); - const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0); - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - - return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) && - GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) && - GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) && - GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) && - GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) && - GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) && - GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) && - GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) && - GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) && - (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0)); - } - - __host__ __device__ static constexpr index_t - CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) - { - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); - - constexpr index_t GM11 = GM1PerBlockGM11; - constexpr index_t GN11 = GN1PerBlockGN11; - - const index_t GM10 = GM1 / GM11; - const index_t GN10 = GN1 / GN11; - - const index_t grid_size = GM10 * GN10; - - return grid_size; - } - - __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK) - { - 
const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1; - - return has_main_k_block_loop; - } - - __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK) - { - const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0; - - return has_double_tail_k_block_loop; - } - - __host__ __device__ static constexpr auto - MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc) - { - const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0); - const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2); - - const auto GM11 = Number{}; - const auto GM10 = GM1 / GM11; - - const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor( - a_gk_gm0_gm1_grid_desc, - make_tuple(make_pass_through_transform(GK), - make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - return a_gk_gm0_gm10_gm11_grid_desc; - } - - __host__ __device__ static constexpr auto - MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc) - { - const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0); - const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2); - - const auto GN11 = Number{}; - const auto GN10 = GN1 / GN11; - - const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( - b_gk_gn0_gn1_grid_desc, - make_tuple(make_pass_through_transform(GK), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - return b_gk_gn0_gn10_gn11_grid_desc; - } - - __host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor( - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) - { - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = 
c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - constexpr auto BM = GM0 * GM11; - constexpr auto BN = GN0 * GN11; - - constexpr auto BM1 = - Number{}; - constexpr auto BN1 = - Number{}; - - constexpr auto BM0 = BM / BM1; - constexpr auto BN0 = BN / BN1; - - const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( - c_gm0_gm1_gn0_gn1_grid_desc, - make_tuple(make_pass_through_transform(GM0), - make_unmerge_transform(make_tuple(GM10, GM11)), - make_pass_through_transform(GN0), - make_unmerge_transform(make_tuple(GN10, GN11))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( - c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_merge_transform(make_tuple(GM0, GM11)), - make_pass_through_transform(GN10), - make_merge_transform(make_tuple(GN0, GN11))), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor( - c_gm10_bm_gn10_bn_grid_desc, - make_tuple(make_pass_through_transform(GM10), - make_unmerge_transform(make_tuple(BM0, BM1)), - make_pass_through_transform(GN10), - make_unmerge_transform(make_tuple(BN0, BN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - - return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc; - } - - __host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor( - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) - { - const auto GM1 = 
c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); - - constexpr auto GM11 = Number{}; - constexpr auto GN11 = Number{}; - - const auto GM10 = GM1 / GM11; - const auto GN10 = GN1 / GN11; - - const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(GM10, GN10))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); - - return c_blockid_to_gm10_gn10_block_cluster_adaptor; - } - - using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{})); - using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{})); - using CGM10BM0BM1GN10BN0BN1GridDesc = - decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{})); - using CBlockIdToGM10GN10BlockClusterAdaptor = - decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{})); - - template - __device__ static void - Run(const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - FloatAB* __restrict__ p_shared_block, - const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc, - const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc, - const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor, - integral_constant, - integral_constant) - { - const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize()); - - const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0); - - // divide block work by [GM10, GN10] - const auto c_gm10_gn10_block_cluster_idx = - 
c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex( - make_multi_index(get_block_1d_id())); - - // HACK: this force index data into SGPR - const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); - const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); - - // lds max alignment - // part of them should be moved into blockwise-gemm - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_gk_gm0_gm10_gm11_block_desc = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_gk_gn0_gn10_gn11_block_desc = - make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}), max_lds_align); - - // A matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0 * Number{}), max_lds_align); - - // B matrix in LDS memory for blockwise GEMM - // be careful of LDS alignment - constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0 * Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< - BlockSize, - InMemoryDataOperation::Set, - Sequence, - ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_gk_gm0_gm10_gm11_grid_desc), - decltype(a_gk_gm0_gm10_gm11_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - ABlockTransferSrcVectorDim, - 3, - ABlockTransferSrcScalarPerVector, - 
ABlockTransferDstScalarPerVector_GM11, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>(a_gk_gm0_gm10_gm11_grid_desc, - make_multi_index(0, 0, igm10, 0), - a_gk_gm0_gm10_gm11_block_desc, - make_multi_index(0, 0, 0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< - BlockSize, - InMemoryDataOperation::Set, - Sequence, - BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_gk_gn0_gn10_gn11_grid_desc), - decltype(b_gk_gn0_gn10_gn11_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2, 3>, - BBlockTransferSrcVectorDim, - 3, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_GN11, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>(b_gk_gn0_gn10_gn11_grid_desc, - make_multi_index(0, 0, ign10, 0), - b_gk_gn0_gn10_gn11_block_desc, - make_multi_index(0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS - // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS - // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in - // register - const auto blockwise_gemm = - BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; - constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); - - constexpr auto c_bm0_bm1_bn0_bn1_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), 
max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize()); - - ThreadwiseDynamicTensorSliceSet_v1{} - .Run(c_bm0_bm1_bn0_bn1_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; - constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = - AGridMoveSliceWindowIteratorHacks{}; - constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = - BGridMoveSliceWindowIteratorHacks{}; - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_aligned_space_size, - a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_aligned_space_size, - b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead( - a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); - - 
a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_gk_gm0_gm10_gm11_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_gk_gn0_gn10_gn11_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc, - a_block_even_buf, - b_block_even_buf, - c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf); - - // odd iteration - a_blockwise_copy.MoveSrcSliceWindow( - a_gk_gm0_gm10_gm11_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow( - b_gk_gn0_gn10_gn11_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); - - // LDS double buffer: GEMM on current data - 
blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < GK - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc, - a_block_slice_copy_step, - a_k_m0_m1_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc, - b_block_slice_copy_step, - b_k_n0_n1_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead( - a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr index_t M11 = - M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; - constexpr 
index_t N11 = - N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; - - constexpr index_t M10 = GM1PerBlockGM11 / M11; - constexpr index_t N10 = GN1PerBlockGN11 / N11; - - constexpr index_t M111 = M1PerThreadM111; - constexpr index_t N111 = N1PerThreadN111; - - constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2( - make_tuple(I1, - Number{}, - Number{}, - I1, - Number{}, - Number{})); - - const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block = - blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); - - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc), - decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc), - Sequence<1, - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0], - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1], - 1, - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2], - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - make_multi_index(igm10, - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0], - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1], - ign10, - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2], - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])} - .Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0), - c_thread_buf, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_grid_buf, - CGridIteratorHacks{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp index af25458b6c..f47e85e0bd 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp +++ 
b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp @@ -5,8 +5,8 @@ #include "dynamic_multi_index_transform_helper.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2r2.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "blockwise_gemm_v2r3.hpp" +#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" #include "threadwise_dynamic_tensor_slice_transfer.hpp" #include "threadwise_dynamic_tensor_slice_set.hpp" @@ -15,10 +15,10 @@ namespace ck { template __global__ void @@ -29,11 +29,10 @@ __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc, - const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc, - const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - const CBlockIdToGM10GN10BlockClusterAdaptor - c_blockid_to_gm10_gn10_block_cluster_adaptor) + const AGridDesc_GK0_GM0_GM10_GM11_GK1 a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1 b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10 c_grid_block_cluster_blockid_to_gm10_gn10) { constexpr index_t shared_block_size = GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); @@ -44,10 +43,10 @@ __global__ void p_b_grid, p_c_grid, p_shared_block, - a_gk0_gm0_gm10_gm11_gk1_grid_desc, - b_gk0_gn0_gn10_gn11_gk1_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, integral_constant{}, integral_constant{}); } @@ -57,19 +56,19 @@ template -struct 
GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 +struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1 { static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -100,9 +99,9 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 static constexpr auto I3 = Number<3>{}; // GM0 and GN0 need to known at compile-time - static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0); - static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2); - static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3); + static constexpr auto GM0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I0); + static constexpr auto GN0 = CGridDesc_GM0_GM1_GN0_GN1{}.GetLength(I2); + static constexpr auto GK1 = AGridDesc_GK0_GM0_GM1_GK1{}.GetLength(I3); __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -113,61 +112,62 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc = + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}, GK1), + make_tuple(Number{}, GM0, I1, Number{}, GK1), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc = + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}, GK1), + make_tuple(Number{}, GN0, I1, Number{}, GK1), max_lds_align); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); constexpr auto 
b_block_aligned_space_size = math::integer_least_multiple( - b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); } __host__ __device__ static constexpr bool - CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc, - const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc, - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + CheckValidity(const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1, + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1, + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) { static_assert(is_known_at_compile_time>::value && is_known_at_compile_time>::value, "wrong! GM0 and GN0 need to be known at compile-time"); - const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2); - const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2); - const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) && - GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) && - GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) && - GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) && - GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) && - GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) && - GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) && - GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) && - GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) && - GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) && - (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0)); + return ( + (GM0 
== c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I0) && + GM1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1) && + GN0 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I2) && + GN1 == c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3) && + GM0 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I1) && + GM1 == a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2) && + GN0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I1) && + GN1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2) && + GK0 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0) && + GK1 == b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I3)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % GK0PerBlock == 0)); } __host__ __device__ static constexpr index_t - CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + CalculateGridSize(const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) { - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); constexpr index_t GM11 = GM1PerBlockGM11; constexpr index_t GN11 = GN1PerBlockGN11; @@ -182,29 +182,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) { - const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1; + const bool has_main_k_block_loop = (GK0 + GK0PerBlock) / (2 * GK0PerBlock) > 1; return has_main_k_block_loop; } __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) { - const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0; + const bool has_double_tail_k_block_loop = (GK0 / GK0PerBlock) % 2 == 0; return has_double_tail_k_block_loop; } - __host__ __device__ static constexpr auto - MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc) + __host__ __device__ static constexpr auto 
MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1) { - const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); - const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2); + const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + const auto GM1 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I2); const auto GM11 = Number{}; const auto GM10 = GM1 / GM11; - const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor( - a_gk0_gm0_gm1_gk1_grid_desc, + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = transform_dynamic_tensor_descriptor( + a_grid_desc_gk0_gm0_gm1_gk1, make_tuple(make_pass_through_transform(GK0), make_pass_through_transform(GM0), make_unmerge_transform(make_tuple(GM10, GM11)), @@ -212,20 +212,20 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - return a_gk0_gm0_gm10_gm11_gk1_grid_desc; + return a_grid_desc_gk0_gm0_gm10_gm11_gk1; } - __host__ __device__ static constexpr auto - MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc) + __host__ __device__ static constexpr auto MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1) { - const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0); - const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2); + const auto GK0 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I0); + const auto GN1 = b_grid_desc_gk0_gn0_gn1_gk1.GetLength(I2); const auto GN11 = Number{}; const auto GN10 = GN1 / GN11; - const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor( - b_gk0_gn0_gn1_gk1_grid_desc, + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = transform_dynamic_tensor_descriptor( + b_grid_desc_gk0_gn0_gn1_gk1, make_tuple(make_pass_through_transform(GK0), 
make_pass_through_transform(GN0), make_unmerge_transform(make_tuple(GN10, GN11)), @@ -233,14 +233,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); - return b_gk0_gn0_gn10_gn11_gk1_grid_desc; + return b_grid_desc_gk0_gn0_gn10_gn11_gk1; } - __host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor( - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + __host__ __device__ static constexpr auto MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) { - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); constexpr auto GM11 = Number{}; constexpr auto GN11 = Number{}; @@ -252,15 +252,15 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 constexpr auto BN = GN0 * GN11; constexpr auto BM1 = - Number{}; + Number{}; constexpr auto BN1 = - Number{}; + Number{}; constexpr auto BM0 = BM / BM1; constexpr auto BN0 = BN / BN1; const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( - c_gm0_gm1_gn0_gn1_grid_desc, + c_grid_desc_gm0_gm1_gn0_gn1, make_tuple(make_pass_through_transform(GM0), make_unmerge_transform(make_tuple(GM10, GM11)), make_pass_through_transform(GN0), @@ -277,7 +277,7 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor( + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = transform_dynamic_tensor_descriptor( 
c_gm10_bm_gn10_bn_grid_desc, make_tuple(make_pass_through_transform(GM10), make_unmerge_transform(make_tuple(BM0, BM1)), @@ -286,14 +286,14 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); - return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc; + return c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; } - __host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor( - const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + __host__ __device__ static constexpr auto MakeCGridBlockCluster_BlockId_To_GM10_GN10( + const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1) { - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); constexpr auto GM11 = Number{}; constexpr auto GN11 = Number{}; @@ -301,22 +301,22 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 const auto GM10 = GM1 / GM11; const auto GN10 = GN1 / GN11; - const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor( + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(GM10, GN10))), make_tuple(Sequence<0, 1>{}), make_tuple(Sequence<0>{})); - return c_blockid_to_gm10_gn10_block_cluster_adaptor; + return c_grid_block_cluster_blockid_to_gm10_gn10; } - using AGK0GM0GM10GM11GK1GridDesc = - decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{})); - using BGK0GN0GN10GN11GK1GridDesc = - decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{})); - using CGM10BM0BM1GN10BN0BN1GridDesc = - decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{})); - using 
CBlockIdToGM10GN10BlockClusterAdaptor = - decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{})); + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(AGridDesc_GK0_GM0_GM1_GK1{})); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(BGridDesc_GK0_GN0_GN1_GK1{})); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(CGridDesc_GM0_GM1_GN0_GN1{})); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(MakeCGridBlockCluster_BlockId_To_GM10_GN10(CGridDesc_GM0_GM1_GN0_GN1{})); template __device__ static void @@ -324,25 +324,25 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, FloatAB* __restrict__ p_shared_block, - const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc, - const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc, - const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor, + const AGridDesc_GK0_GM0_GM10_GM11_GK1& a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const BGridDesc_GK0_GN0_GN10_GN11_GK1& b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1& c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const CGridBlockCluster_BlockId_To_GM10_GN10& c_grid_block_cluster_blockid_to_gm10_gn10, integral_constant, integral_constant) { const auto a_global_buf = make_dynamic_buffer( - p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize()); + p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); const auto b_global_buf = make_dynamic_buffer( - p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize()); + p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, 
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize()); + p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize()); - const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0); + const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0); // divide block work by [GM10, GN10] const auto c_gm10_gn10_block_cluster_idx = - c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex( + c_grid_block_cluster_blockid_to_gm10_gn10.CalculateBottomIndex( make_multi_index(get_block_1d_id())); // HACK: this force index data into SGPR @@ -356,46 +356,46 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc = + constexpr auto a_block_desc_gk0_gm0_gm10_gm11_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0, I1, Number{}, GK1), + make_tuple(Number{}, GM0, I1, Number{}, GK1), max_lds_align); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment - constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc = + constexpr auto b_block_desc_gk0_gn0_gn10_gn11_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0, I1, Number{}, GK1), + make_tuple(Number{}, GN0, I1, Number{}, GK1), max_lds_align); // A matrix in LDS memory for blockwise GEMM // be careful of LDS alignment - constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); + constexpr auto a_block_desc_gk0_bm_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); // B matrix in LDS memory for blockwise GEMM // be careful of LDS alignment - constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); + constexpr auto 
b_block_desc_gk0_bn_gk1 = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); - static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() == - a_gk0_bm_gk1_block_desc.GetElementSpaceSize() && - b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() == - b_gk0_bn_gk1_block_desc.GetElementSpaceSize(), + static_assert(a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize() == + a_block_desc_gk0_bm_gk1.GetElementSpaceSize() && + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize() == + b_block_desc_gk0_bn_gk1.GetElementSpaceSize(), "wrong!"); // A matrix blockwise copy auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperation::Set, - Sequence, + Sequence, ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, ABlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc), - decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc), + decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1), + decltype(a_block_desc_gk0_gm0_gm10_gm11_gk1), ABlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3, 4>, ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths @@ -403,23 +403,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder false, - true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + true>(a_grid_desc_gk0_gm0_gm10_gm11_gk1, make_multi_index(0, 0, igm10, 0, 0), - a_gk0_gm0_gm10_gm11_gk1_block_desc, + a_block_desc_gk0_gm0_gm10_gm11_gk1, make_multi_index(0, 0, 0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< BlockSize, InMemoryDataOperation::Set, - Sequence, + Sequence, BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, 
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, BBlockTransferThreadClusterArrangeOrder, FloatAB, FloatAB, - decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc), - decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc), + decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1), + decltype(b_block_desc_gk0_gn0_gn10_gn11_gk1), BBlockTransferSrcAccessOrder, Sequence<0, 1, 2, 3, 4>, BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths @@ -427,102 +427,103 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder false, - true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + true>(b_grid_desc_gk0_gn0_gn10_gn11_gk1, make_multi_index(0, 0, ign10, 0, 0), - b_gk0_gn0_gn10_gn11_gk1_block_desc, + b_block_desc_gk0_gn0_gn10_gn11_gk1, make_multi_index(0, 0, 0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS + // a_mtx[GK0PerBlock, GM1PerBlockGM11] is in LDS // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2{}; + BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_block_desc_gk0_bm_gk1), + decltype(b_block_desc_gk0_bn_gk1), + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM100, + BM10BN10ThreadClusterBN100, + BM10BN10ThreadClusterBM101, + BM10BN10ThreadClusterBN101, + BM1PerThreadBM11, + BN1PerThreadBN11>{}; - constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths = - decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + constexpr auto c_thread_tensor_lengths_bm0_bm1_bn0_bn1 = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); - constexpr 
auto c_bm0_bm1_bn0_bn1_thread_desc = + constexpr auto c_thread_desc_bm0_bm1_bn0_bn1 = make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)); + sequence_to_tuple_of_number(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_aligned_space_size = math::integer_least_multiple( - a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize(), max_lds_align); constexpr auto b_block_aligned_space_size = math::integer_least_multiple( - b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize(), max_lds_align); FloatAB* p_a_block_double = p_shared_block; FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; // register allocation for output auto c_thread_buf = make_static_buffer( - c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize()); + c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize()); ThreadwiseDynamicTensorSliceSet_v1{} - .Run(c_bm0_bm1_bn0_bn1_thread_desc, + decltype(c_thread_desc_bm0_bm1_bn0_bn1), + decltype(c_thread_tensor_lengths_bm0_bm1_bn0_bn1)>{} + .Run(c_thread_desc_bm0_bm1_bn0_bn1, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0); auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize()); + p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, 
b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize()); + p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); auto a_block_odd_buf = make_dynamic_buffer( p_a_block_double + a_block_aligned_space_size, - a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize()); + a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize()); auto b_block_odd_buf = make_dynamic_buffer( p_b_block_double + b_block_aligned_space_size, - b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize()); + b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize()); // LDS double buffer: preload data into LDS { a_blockwise_copy.RunRead( - a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead( - b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); - a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); } if constexpr(HasMainKBlockLoop) { - index_t k_block_data_begin = 0; + index_t gk0_block_on_grid = 0; // LDS double buffer: main body // use Do-While loop instead of For loop to simplify control flow do { // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, AGridMoveSliceWindowIteratorHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, BGridMoveSliceWindowIteratorHacks{}); @@ -530,25 +531,25 @@ struct 
GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead( - b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); // LDS double buffer: GEMM on current data - blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc, + blockwise_gemm.Run(c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); // odd iteration - a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, AGridMoveSliceWindowIteratorHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, BGridMoveSliceWindowIteratorHacks{}); @@ -556,29 +557,29 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // LDS doubel buffer: load next data from device mem a_blockwise_copy.RunRead( - a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead( - b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); // LDS double buffer: GEMM on current data 
blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf); + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_even_buf); - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < GK0 - 2 * KPerBlock); + gk0_block_on_grid += 2 * GK0PerBlock; + } while(gk0_block_on_grid < GK0 - 2 * GK0PerBlock); } // LDS double buffer: tail if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left { - a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_block_slice_copy_step, AGridMoveSliceWindowIteratorHacks{}); - b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_block_slice_copy_step, BGridMoveSliceWindowIteratorHacks{}); @@ -586,23 +587,23 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // LDS double buffer: load last data from device mem a_blockwise_copy.RunRead( - a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + a_grid_desc_gk0_gm0_gm10_gm11_gk1, a_global_buf, AGridIteratorHacks{}); b_blockwise_copy.RunRead( - b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + b_grid_desc_gk0_gn0_gn10_gn11_gk1, b_global_buf, BGridIteratorHacks{}); // LDS double buffer: GEMM on 2nd-last data blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); // LDS double 
buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf); + a_blockwise_copy.RunWrite(a_block_desc_gk0_gm0_gm10_gm11_gk1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_gk0_gn0_gn10_gn11_gk1, b_block_odd_buf); __syncthreads(); // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + c_thread_desc_bm0_bm1_bn0_bn1, a_block_odd_buf, b_block_odd_buf, c_thread_buf); } else // if has 1 iteration left { @@ -610,61 +611,51 @@ struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 // LDS double buffer: GEMM on last data blockwise_gemm.Run( - c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + c_thread_desc_bm0_bm1_bn0_bn1, a_block_even_buf, b_block_even_buf, c_thread_buf); } // output: register to global memory { - constexpr index_t M11 = - M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; - constexpr index_t N11 = - N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; - - constexpr index_t M10 = GM1PerBlockGM11 / M11; - constexpr index_t N10 = GN1PerBlockGN11 / N11; - - constexpr index_t M111 = M1PerThreadM111; - constexpr index_t N111 = N1PerThreadN111; - - constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc = + constexpr auto c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1 = make_dynamic_naive_tensor_descriptor_packed_v2( make_tuple(I1, - Number{}, - Number{}, + Number{}, + Number{}, I1, - Number{}, - Number{})); + Number{}, + Number{})); - const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block = - blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + const auto c_thread_origin_on_block_bm0_bm1_bn0_bn1 = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); 
ThreadwiseDynamicTensorSliceTransfer_v1r3< FloatAcc, FloatC, - decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc), - decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc), + decltype(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1), + decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1), Sequence<1, - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0], - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I0], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I1], 1, - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2], - c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>, + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I2], + c_thread_tensor_lengths_bm0_bm1_bn0_bn1[I3]>, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, - true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - make_multi_index(igm10, - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0], - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1], - ign10, - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2], - c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])} - .Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc, + false>{c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + make_multi_index(igm10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I0], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I1], + ign10, + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I2], + c_thread_origin_on_block_bm0_bm1_bn0_bn1[I3])} + .Run(c_thread_desc_gm10_bm0_bm1_gn10_bn0_bn1, make_tuple(I0, I0, I0, I0, I0, I0), c_thread_buf, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, c_grid_buf, CGridIteratorHacks{}); } diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp deleted file mode 100644 index 8e8af1a12a..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp +++ /dev/null @@ -1,552 +0,0 @@ -#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP 
-#define CK_GRIDWISE_DYNAMIC_GEMM_HPP - -#include "common_header.hpp" -#include "dynamic_multi_index_transform_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "blockwise_gemm_v2.hpp" -#include "blockwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_transfer.hpp" -#include "threadwise_dynamic_tensor_slice_set.hpp" - -namespace ck { - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc a_k_m_global_desc, - const BGlobalDesc b_k_n_global_desc, - const CGlobalDesc c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc c_block_cluster_desc) -{ - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc, - integral_constant{}, - integral_constant{}); -} -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER -// pass tensor descriptor by __CONSTANT__ void pointer -// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to -// non-modifiable parameter address space, so compiler can enable corresponding optimization -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const void __CONSTANT__* p_a_k_m_global_desc, - const void __CONSTANT__* p_b_k_n_global_desc, - const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, - const void __CONSTANT__* p_c_block_cluster_desc) -{ - // first cast void __CONSTANT__ void* to void* - // second cast void* to 
Desc* - // the copy constructor of tensor descriptor doesn't take address_space(4) - const auto a_k_m_global_desc = - *reinterpret_cast((const void*)p_a_k_m_global_desc); - const auto b_k_n_global_desc = - *reinterpret_cast((const void*)p_b_k_n_global_desc); - const auto c_m0_m1_n0_n1_global_desc = - *reinterpret_cast((const void*)p_c_m0_m1_n0_n1_global_desc); - - const auto c_block_cluster_desc = - *reinterpret_cast((const void*)p_c_block_cluster_desc); - - GridwiseGemm::Run(p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc, - integral_constant{}, - integral_constant{}); -} -#endif - -template -struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}), max_lds_align); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB); - } - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const 
CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - FloatAB* __restrict__ p_shared_block, - integral_constant, - integral_constant) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto a_global_buf = make_dynamic_buffer( - p_a_global, a_k_m_global_desc.GetElementSpaceSize()); - const auto b_global_buf = make_dynamic_buffer( - p_b_global, b_k_n_global_desc.GetElementSpaceSize()); - auto c_global_buf = make_dynamic_buffer( - p_c_global, c_m0_m1_n0_n1_global_desc.GetElementSpaceSize()); - - const auto K = a_k_m_global_desc.GetLength(I0); - const auto M = a_k_m_global_desc.GetLength(I1); - const auto N = b_k_n_global_desc.GetLength(I1); - - // divide block work by [M, N] - const auto block_work_idx = - c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - // HACK: this force m/n_block_data_idx_on_global into SGPR - const index_t m_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - - const index_t n_block_data_idx_on_global = - __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(Number{}, - Number{}, - Number{}, - Number{}); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}), max_lds_align); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( - make_tuple(Number{}, Number{}), max_lds_align); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - ABlockTransferThreadSliceLengths_K_M, - ABlockTransferThreadClusterLengths_K_M, - 
ABlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(a_k_m_global_desc), - decltype(a_k_m_block_desc), - ABlockTransferSrcAccessOrder, - Sequence<0, 1>, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_M, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true>( - a_k_m_global_desc, - make_multi_index(0, m_block_data_idx_on_global), - a_k_m_block_desc, - make_multi_index(0, 0)); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseDynamicTensorSliceTransfer_v4, - BBlockTransferThreadSliceLengths_K_N, - BBlockTransferThreadClusterLengths_K_N, - BBlockTransferThreadClusterArrangeOrder, - FloatAB, - FloatAB, - decltype(b_k_n_global_desc), - decltype(b_k_n_block_desc), - BBlockTransferSrcAccessOrder, - Sequence<0, 1>, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_N, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true>( - b_k_n_global_desc, - make_multi_index(0, n_block_data_idx_on_global), - b_k_n_block_desc, - make_multi_index(0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlocl, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - // sanity check - static_assert( - MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 && - NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0, - "wrong!"); - - constexpr index_t M0PerThread = - MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10); - constexpr index_t N0PerThread = - NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10); - - constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( - a_k_m_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple( - Number{}, - Number{}))), - 
make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( - b_k_n_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple( - Number{}, - Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - constexpr auto c_m0_m1_n0_n1_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, - Number{}, - Number{}, - Number{})); - - const auto blockwise_gemm = - BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2{}; - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); - - FloatAB* p_a_block_double = p_shared_block; - FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size; - - // register allocation for output - auto c_thread_buf = make_static_buffer( - c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); - - ThreadwiseDynamicTensorSliceSet_v1< - FloatAcc, - decltype(c_m0_m1_n0_n1_thread_desc), - Sequence>{} - .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); - - // hack to control index calculation when iterating over A and B matrix for threadwise copy - constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{}; - constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{}; - - // hack to control index calculation when move slice window for A and B matrix for - // threadwise copy - constexpr auto a_k_m_global_move_slice_window_iterator_hack = - AGlobalMoveSliceWindowIteratorHacks{}; 
- constexpr auto b_k_n_global_move_slice_window_iterator_hack = - BGlobalMoveSliceWindowIteratorHacks{}; - - auto a_block_even_buf = make_dynamic_buffer( - p_a_block_double, a_k_m_block_desc.GetElementSpaceSize()); - auto b_block_even_buf = make_dynamic_buffer( - p_b_block_double, b_k_n_block_desc.GetElementSpaceSize()); - - auto a_block_odd_buf = make_dynamic_buffer( - p_a_block_double + a_block_space_size, a_k_m_block_desc.GetElementSpaceSize()); - auto b_block_odd_buf = make_dynamic_buffer( - p_b_block_double + b_block_space_size, b_k_n_block_desc.GetElementSpaceSize()); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks); - b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks); - - a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf); - } - - if constexpr(HasMainKBlockLoop) - { - index_t k_block_data_begin = 0; - - // LDS double buffer: main body - // use Do-While loop instead of For loop to simplify control flow - do - { - // even iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, - a_block_slice_copy_step, - a_k_m_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, - b_block_slice_copy_step, - b_k_n_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf); - - // odd 
iteration - a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, - a_block_slice_copy_step, - a_k_m_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, - b_block_slice_copy_step, - b_k_n_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunRead( - a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks); - b_blockwise_copy.RunRead( - b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf); - b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_even_buf); - - k_block_data_begin += 2 * KPerBlock; - } while(k_block_data_begin < K - 2 * KPerBlock); - } - - // LDS double buffer: tail - if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left - { - a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, - a_block_slice_copy_step, - a_k_m_global_move_slice_window_iterator_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, - b_block_slice_copy_step, - b_k_n_global_move_slice_window_iterator_hack); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunRead(a_k_m_global_desc, a_global_buf, a_k_m_global_iterator_hacks); - b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf); - b_blockwise_copy.RunWrite(b_k_n_block_desc, b_block_odd_buf); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf); - } - 
else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf); - } - - // output: register to global memory - { - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; - - const auto c_thread_data_idx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id()); - - ThreadwiseDynamicTensorSliceTransfer_v1r3< - FloatAcc, - FloatC, - decltype(c_m0_m1_n0_n1_thread_desc), - decltype(c_m0_m1_n0_n1_global_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_m0_m1_n0_n1_global_desc, - make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0], - c_thread_data_idx_on_block[I1], - n_block_data_idx_on_global / N1 + c_thread_data_idx_on_block[I2], - c_thread_data_idx_on_block[I3])} - .Run(c_m0_m1_n0_n1_thread_desc, - make_tuple(I0, I0, I0, I0), - c_thread_buf, - c_m0_m1_n0_n1_global_desc, - c_global_buf, - c_m0_m1_n0_n1_global_tensor_iterator_hacks); - } - } - - template - __device__ static void Run(const FloatAB* __restrict__ p_a_global, - const FloatAB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - integral_constant, - integral_constant) - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - Run(p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc, - 
p_shared_block, - integral_constant{}, - integral_constant{}); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp index c070cac826..20f91140db 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp @@ -435,21 +435,22 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3 // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in // register const auto blockwise_gemm = - BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2{}; + BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM1100, + M11N11ThreadClusterN1100, + M11N11ThreadClusterM1101, + M11N11ThreadClusterN1101, + M1PerThreadM111, + N1PerThreadN111>{}; constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction.hpp similarity index 53% rename from composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp rename to composable_kernel/include/tensor_operation/threadwise_contraction.hpp index 748a888999..995c871c5e 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_contraction.hpp @@ -1,40 +1,44 @@ -#ifndef CK_THREADWISE_GEMM_V2_HPP -#define CK_THREADWISE_GEMM_V2_HPP +#ifndef CK_THREADWISE_CONTRACTION_HPP +#define CK_THREADWISE_CONTRACTION_HPP #include "common_header.hpp" #include "math.hpp" namespace ck { -// C[M0, M1, N0, N1] += A[K, M0, M1] 
* B[K, N0, N1] +// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1] // Tensor element can be vectorized data // Assume: -// 1. ADesc, BDesc, CDesc are known at compile-time +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time template ::type = false> struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 { __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1() { - static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && - CDesc::IsKnownAtCompileTime(), + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths // TODO remove this restriction - static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2, + static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2, "wrong!"); } @@ -70,28 +74,31 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto K = KLengths{}[I0]; - constexpr auto M0 = MLengths{}[I0]; - constexpr auto M1 = MLengths{}[I1]; - constexpr auto N0 = NLengths{}[I0]; - constexpr auto N1 = NLengths{}[I1]; + constexpr auto TK = TKLengths{}[I0]; + constexpr auto TM0 = TMLengths{}[I0]; + constexpr auto TM1 = TMLengths{}[I1]; + constexpr auto TN0 = TNLengths{}[I0]; + constexpr auto TN1 = TNLengths{}[I1]; constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr auto 
c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, K, 1>{}([&](auto k) { - static_for<0, M0, 1>{}([&](auto m0) { - static_for<0, M1, 1>{}([&](auto m1) { - static_for<0, N0, 1>{}([&](auto n0) { - static_for<0, N1, 1>{}([&](auto n1) { + static_for<0, TK, 1>{}([&](auto tk) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { constexpr index_t a_offset = - ADesc{}.CalculateOffset(a_origin_idx + make_multi_index(k, m0, m1)); + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); constexpr index_t b_offset = - BDesc{}.CalculateOffset(b_origin_idx + make_multi_index(k, n0, n1)); - constexpr index_t c_offset = CDesc{}.CalculateOffset( - c_origin_idx + make_multi_index(m0, m1, n0, n1)); + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); amd_inner_product_dlop( a_buf[Number{}], @@ -105,35 +112,39 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 } }; -// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1] +// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1] // Tensor element can be vectorized data // Assume: -// 1. ADesc, BDesc, CDesc are known at compile-time +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time // 2. 
AOriginIdx, BOriginIdx, COriginIdx are known at compile-time template ::type = false> -struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 +struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 { - __device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1() + __device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() { - static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && - CDesc::IsKnownAtCompileTime(), + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths // TODO remove this restriction - static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2, + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, "wrong!"); } @@ -169,43 +180,45 @@ struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr index_t K0 = KLengths{}[I0]; - constexpr index_t K1 = KLengths{}[I1]; - constexpr index_t M0 = MLengths{}[I0]; - constexpr index_t M1 = MLengths{}[I1]; - constexpr index_t N0 = NLengths{}[I0]; - constexpr index_t N1 = NLengths{}[I1]; + constexpr index_t TK0 = TKLengths{}[I0]; + constexpr index_t TK1 = TKLengths{}[I1]; + constexpr index_t TM0 = TMLengths{}[I0]; + constexpr index_t TM1 = TMLengths{}[I1]; + constexpr index_t TN0 = TNLengths{}[I0]; + constexpr index_t TN1 = TNLengths{}[I1]; constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); constexpr 
auto c_origin_idx = to_multi_index(COriginIdx{}); - static_for<0, K0, 1>{}([&](auto k0) { - static_for<0, M0, 1>{}([&](auto m0) { - static_for<0, M1, 1>{}([&](auto m1) { - static_for<0, N0, 1>{}([&](auto n0) { - static_for<0, N1, 1>{}([&](auto n1) { + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { - vector_type a_vec; - vector_type b_vec; + vector_type a_vec; + vector_type b_vec; - static_for<0, K1, 1>{}([&](auto k1) { - constexpr index_t a_offset = ADesc{}.CalculateOffset( - a_origin_idx + make_multi_index(k0, m0, m1, k1)); + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); - constexpr index_t b_offset = BDesc{}.CalculateOffset( - b_origin_idx + make_multi_index(k0, n0, n1, k1)); + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); - a_vec.template AsType()(k1) = a_buf[Number{}]; - - b_vec.template AsType()(k1) = b_buf[Number{}]; + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; }); - using a_vector_t = typename vector_type::type; - using b_vector_t = typename vector_type::type; + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; - constexpr index_t c_offset = CDesc{}.CalculateOffset( - c_origin_idx + make_multi_index(m0, m1, n0, n1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); amd_inner_product_dlop( a_vec.template AsType()[I0], diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.hpp similarity index 100% rename from 
composable_kernel/include/utility/config.amd.hpp.in rename to composable_kernel/include/utility/config.hpp diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.hpp similarity index 100% rename from composable_kernel/include/utility/float_type.amd.hpp.in rename to composable_kernel/include/utility/float_type.hpp diff --git a/composable_kernel/include/utility/synchronization.amd.hpp.in b/composable_kernel/include/utility/synchronization.hpp similarity index 100% rename from composable_kernel/include/utility/synchronization.amd.hpp.in rename to composable_kernel/include/utility/synchronization.hpp diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp deleted file mode 100644 index fc27016624..0000000000 --- a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,379 +0,0 @@ -#include "common_header.hpp" -#include "type_helper.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_contraction_v1r1.hpp" -#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; -using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; -using FloatAcc = typename get_type_from_type_id(CK_PARAM_CONV_COMPTYPE)>::type; - -constexpr index_t BlockSize = CK_PARAM_BlockSize; -constexpr index_t N0 = CK_PARAM_N0; - -constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11; -constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; -constexpr index_t KPerBlock = CK_PARAM_KPerBlock; -constexpr index_t M1PerThread = CK_PARAM_M1PerThread; -constexpr index_t N1PerThread = 
CK_PARAM_N1PerThread; -constexpr index_t KPerThread = CK_PARAM_KPerThread; -constexpr index_t M1N1ThreadClusterM10 = CK_PARAM_M1N1ThreadClusterM10; -constexpr index_t M1N1ThreadClusterN10 = CK_PARAM_M1N1ThreadClusterN10; -constexpr index_t M1N1ThreadClusterM11 = CK_PARAM_M1N1ThreadClusterM11; -constexpr index_t M1N1ThreadClusterN11 = CK_PARAM_M1N1ThreadClusterN11; - -using ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = - Sequence; -using ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = - Sequence; -using ABlockTransferThreadClusterArrangeOrder = - Sequence; -using ABlockTransferSrcAccessOrder = Sequence; - -constexpr index_t ABlockTransferSrcVectorDim = CK_PARAM_ABlockTransferSrcVectorDim; -constexpr index_t ABlockTransferSrcScalarPerVector = CK_PARAM_ABlockTransferSrcScalarPerVector; -constexpr index_t ABlockTransferDstScalarPerVector_GM11 = - CK_PARAM_ABlockTransferDstScalarPerVector_GM11; -constexpr bool AThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_AThreadTransferSrcResetCoordinateAfterRun); - -using BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = - Sequence; -using BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = - Sequence; -using BBlockTransferThreadClusterArrangeOrder = - Sequence; -using BBlockTransferSrcAccessOrder = Sequence; - -constexpr index_t BBlockTransferSrcVectorDim = CK_PARAM_BBlockTransferSrcVectorDim; -constexpr index_t BBlockTransferSrcScalarPerVector = CK_PARAM_BBlockTransferSrcScalarPerVector; -constexpr index_t BBlockTransferDstScalarPerVector_GN11 = - CK_PARAM_BBlockTransferDstScalarPerVector_GN11; -constexpr bool BThreadTransferSrcResetCoordinateAfterRun = - static_cast(CK_PARAM_BThreadTransferSrcResetCoordinateAfterRun); - -using CThreadTransferSrcDstAccessOrder = Sequence; -constexpr index_t CThreadTransferSrcDstVectorDim = CK_PARAM_CThreadTransferSrcDstVectorDim; -constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; - -constexpr bool HasMainKBlockLoop 
= static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); -constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); - -extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare( - int n, - int c, - int hi, - int wi, - int k, - int y, - int x, - int convStrideH, - int convStrideW, - int convDilationY, - int convDilationX, - int leftPadH, - int leftPadW, - int rightPadH, - int rightPadW, - void* p_a_gk_gm0_gm10_gm11_grid_desc, - void* p_b_gk_gn0_gn10_gn11_grid_desc, - void* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - void* p_c_blockid_to_gm10_gn10_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - const index_t ho = (hi + leftPadH + rightPadH - convDilationY * (y - 1) - 1) / convStrideH + 1; - const index_t wo = (wi + leftPadW + rightPadW - convDilationX * (x - 1) - 1) / convStrideW + 1; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, c, hi, wi)); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(k, c, y, x)); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(n, k, ho, wo)); - - const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(convStrideH, convStrideW), - make_tuple(convDilationY, convDilationX), - make_tuple(leftPadH, leftPadW), - make_tuple(rightPadH, rightPadW)); - - const auto a_gk_gm0_gm1_grid_desc = descs[I0]; - const auto b_gk_gn0_gn1_grid_desc = descs[I1]; - const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; - - using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc); - using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc); - using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc); - - using AGridIteratorHacks = 
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>; - - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>; - - using 
GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, /* ToDo tunable */ - AGKGM0GM1GridDesc, - BGKGN0GN1GridDesc, - CGM0GM1GN0GN1GridDesc, - GM1PerBlockGM11, - GN1PerBlockGN11, - KPerBlock, - M1PerThread, - N1PerThread, - KPerThread, - M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - M1N1ThreadClusterM11, - M1N1ThreadClusterN11, - ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_GM11, - AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_GN11, - BThreadTransferSrcResetCoordinateAfterRun, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; - - auto a_gk_gm0_gm10_gm11_grid_desc = - GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc); - auto b_gk_gn0_gn10_gn11_grid_desc = - GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc); - auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = - GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); - auto c_blockid_to_gm10_gn10_block_cluster_adaptor = - GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); - - if(hipThreadIdx_x == 0) - { - *static_cast(p_a_gk_gm0_gm10_gm11_grid_desc) = - a_gk_gm0_gm10_gm11_grid_desc; - 
*static_cast(p_b_gk_gn0_gn10_gn11_grid_desc) = - b_gk_gn0_gn10_gn11_grid_desc; - *static_cast( - p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc) = c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc; - *static_cast( - p_c_blockid_to_gm10_gn10_block_cluster_adaptor) = - c_blockid_to_gm10_gn10_block_cluster_adaptor; - }; -}; - -extern "C" __global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) -#endif - dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw( - const FloatAB* __restrict__ p_a_grid, - const FloatAB* __restrict__ p_b_grid, - FloatC* __restrict__ p_c_grid, - const void __CONSTANT__* p_a_gk_gm0_gm10_gm11_grid_desc, - const void __CONSTANT__* p_b_gk_gn0_gn10_gn11_grid_desc, - const void __CONSTANT__* p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - const void __CONSTANT__* p_c_blockid_to_gm10_gn10_block_cluster_adaptor) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - constexpr auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - constexpr auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); - constexpr auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); - - constexpr auto descs = - transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1), - make_tuple(1, 1)); - - constexpr auto a_gk_gm0_gm1_grid_desc = descs[I0]; - constexpr auto b_gk_gn0_gn1_grid_desc = descs[I1]; - constexpr auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; - - using AGKGM0GM1GridDesc = decltype(a_gk_gm0_gm1_grid_desc); - using BGKGN0GN1GridDesc = decltype(b_gk_gn0_gn1_grid_desc); - using CGM0GM1GN0GN1GridDesc = decltype(c_gm0_gm1_gn0_gn1_grid_desc); - - using AGridIteratorHacks = 
decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}))); - - using BGridIteratorHacks = - decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}))); - - using CGridIteratorHacks = decltype(make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); - - using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0>; - using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>; - - using 
GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1< - BlockSize, - FloatAB, - FloatAcc, - FloatC, - InMemoryDataOperation::Set, /* ToDo tunable */ - AGKGM0GM1GridDesc, - BGKGN0GN1GridDesc, - CGM0GM1GN0GN1GridDesc, - GM1PerBlockGM11, - GN1PerBlockGN11, - KPerBlock, - M1PerThread, - N1PerThread, - KPerThread, - M1N1ThreadClusterM10, - M1N1ThreadClusterN10, - M1N1ThreadClusterM11, - M1N1ThreadClusterN11, - ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_GM11, - AThreadTransferSrcResetCoordinateAfterRun, - BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_GN11, - BThreadTransferSrcResetCoordinateAfterRun, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - AGridIteratorHacks, - BGridIteratorHacks, - CGridIteratorHacks, - AGridMoveSliceWindowIteratorHacks, - BGridMoveSliceWindowIteratorHacks>; - - using AGKGM0GM10GM11GridDesc = - decltype(GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc)); - using BGKGN0GN10GN11GridDesc = - decltype(GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc)); - using CGM10BM0BM1GN10BN0BN1GridDesc = decltype( - GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc)); - using CBlockIdToGM10GN10BlockClusterAdaptor = - decltype(GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor( - c_gm0_gm1_gn0_gn1_grid_desc)); - - const auto a_gk_gm0_gm10_gm11_grid_desc = *reinterpret_cast( - (const 
void*)p_a_gk_gm0_gm10_gm11_grid_desc); - const auto b_gk_gn0_gn10_gn11_grid_desc = *reinterpret_cast( - (const void*)p_b_gk_gn0_gn10_gn11_grid_desc); - const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = - *reinterpret_cast( - (const void*)p_c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); - const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = - *reinterpret_cast( - (const void*)p_c_blockid_to_gm10_gn10_block_cluster_adaptor); - - constexpr index_t shared_block_size = - GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); - - __shared__ FloatAB p_shared_block[shared_block_size]; - - GridwiseContraction::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_gk_gm0_gm10_gm11_grid_desc, - b_gk_gn0_gn10_gn11_grid_desc, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, - c_blockid_to_gm10_gn10_block_cluster_adaptor, - integral_constant{}, - integral_constant{}); -}; diff --git a/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp new file mode 100644 index 0000000000..93a3bb39a0 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp @@ -0,0 +1,402 @@ +#include "common_header.hpp" +#include "type_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_contraction_v1r2.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" + +using namespace ck; + +using FloatAB = typename get_type_from_type_id(CK_PARAM_IN_WEI_DATATYPE)>::type; +using FloatAcc = typename get_type_from_type_id(CK_PARAM_ACC_DATATYPE)>::type; +using FloatC = typename get_type_from_type_id(CK_PARAM_OUT_DATATYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BlockSize; + +constexpr auto GN0 = Number{}; +constexpr auto GK1 = Number{}; + +constexpr index_t GM1PerBlockGM11 = 
CK_PARAM_GM1PerBlockGM11; +constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11; +constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock; +constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11; +constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11; +constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread; +constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100; +constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100; +constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101; +constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101; + +using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferThreadClusterArrangeOrder = Sequence<1, 2, 3, 0, 4>; +using ABlockTransferSrcAccessOrder = Sequence<3, 2, 1, 0, 4>; +using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = + Sequence; +using ABlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferThreadClusterArrangeOrder = Sequence<0, 4, 1, 2, 3>; +using BBlockTransferSrcAccessOrder = Sequence<4, 3, 2, 0, 1>; +using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = + Sequence; +using BBlockTransferSrcVectorTensorContiguousDimOrder = Sequence<0, 1, 2, 3, 4>; + +using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>; +constexpr index_t CThreadTransferSrcDstVectorDim = 5; +constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector; + +constexpr bool 
HasMainKBlockLoop = static_cast(CK_PARAM_HAS_MAIN_KBLOCK_LOOP); +constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP); + +extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare( + index_t N, + index_t C, + index_t Hi, + index_t Wi, + index_t K, + index_t Y, + index_t X, + index_t ConvStrideH, + index_t ConvStrideW, + index_t ConvDilationH, + index_t ConvDilationW, + index_t InLeftPadH, + index_t InLeftPadW, + index_t InRightPadH, + index_t InRightPadW, + void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, + void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1, + void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + void* p_c_grid_block_cluster_blockid_to_gm10_gn10) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + const index_t Ho = + (Hi + InLeftPadH + InRightPadH - ConvDilationH * (Y - 1) - 1) / ConvStrideH + 1; + const index_t Wo = + (Wi + InLeftPadW + InRightPadW - ConvDilationW * (X - 1) - 1) / ConvStrideW + 1; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C, Hi, Wi)); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C, Y, X)); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); + + const auto descs = transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(ConvStrideH, ConvStrideW), + make_tuple(ConvDilationH, ConvDilationW), + make_tuple(InLeftPadH, InLeftPadW), + make_tuple(InRightPadH, InRightPadW), + GN0, + GK1); + + const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = 
decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridIteratorHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowIteratorHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM100, + BM10BN10ThreadClusterBN100, + BM10BN10ThreadClusterBM101, + BM10BN10ThreadClusterBN101, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + 
BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = + GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1); + auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = + GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1); + auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1); + auto c_grid_block_cluster_blockid_to_gm10_gn10 = + GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1); + + if(hipThreadIdx_x == 0) + { + *static_cast( + p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1; + *static_cast( + p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1; + *static_cast( + p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1; + *static_cast( + p_c_grid_block_cluster_blockid_to_gm10_gn10) = + c_grid_block_cluster_blockid_to_gm10_gn10; + }; +}; + +extern "C" __global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1, + const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1, + const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + 
constexpr auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + constexpr auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 3, 3)); + constexpr auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28)); + + constexpr auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + make_tuple(1, 1), + GN0, + GK1); + + constexpr auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + constexpr auto b_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + constexpr auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + using AGridDesc_GK0_GM0_GM1_GK1 = decltype(a_grid_desc_gk0_gm0_gm1_gk1); + using BGridDesc_GK0_GN0_GN1_GK1 = decltype(b_grid_desc_gk0_gn0_gn1_gk1); + using CGridDesc_GM0_GM1_GN0_GN1 = decltype(c_grid_desc_gm0_gm1_gn0_gn1); + + using AGridIteratorHacks = + decltype(make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using BGridIteratorHacks = decltype(make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + 
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}))); // 4-: GK1 + + using CGridIteratorHacks = decltype(make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}))); // 5-: GN1 + + using AGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0>; + + using BGridMoveSliceWindowIteratorHacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>; + + using GridwiseContraction = + GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + AGridDesc_GK0_GM0_GM1_GK1, + BGridDesc_GK0_GN0_GN1_GK1, + CGridDesc_GM0_GM1_GN0_GN1, + GM1PerBlockGM11, + 
GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM100, + BM10BN10ThreadClusterBN100, + BM10BN10ThreadClusterBM101, + BM10BN10ThreadClusterBN101, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + using AGridDesc_GK0_GM0_GM10_GM11_GK1 = + decltype(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1( + a_grid_desc_gk0_gm0_gm1_gk1)); + using BGridDesc_GK0_GN0_GN10_GN11_GK1 = + decltype(GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1( + b_grid_desc_gk0_gn0_gn1_gk1)); + using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = + decltype(GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1( + c_grid_desc_gm0_gm1_gn0_gn1)); + using CGridBlockCluster_BlockId_To_GM10_GN10 = + decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10( + c_grid_desc_gm0_gm1_gn0_gn1)); + + const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = + *reinterpret_cast( + (const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1); + const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = + 
*reinterpret_cast( + (const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1); + const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 = + *reinterpret_cast( + (const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1); + const auto c_grid_block_cluster_blockid_to_gm10_gn10 = + *reinterpret_cast( + (const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10); + + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_gk0_gm0_gm10_gm11_gk1, + b_grid_desc_gk0_gn0_gn10_gn11_gk1, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1, + c_grid_block_cluster_blockid_to_gm10_gn10, + integral_constant{}, + integral_constant{}); +}; diff --git a/driver/include/conv_tunables.hpp b/driver/include/conv_tunables.hpp deleted file mode 100644 index 0275a95f9a..0000000000 --- a/driver/include/conv_tunables.hpp +++ /dev/null @@ -1,271 +0,0 @@ -#ifndef CONV_TUNABLES_HPP -#define CONV_TUNABLES_HPP - -#include "config.hpp" - -struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw -{ - ck::index_t BlockSize; // usually not tunable - - ck::index_t MPerBlock; - ck::index_t NPerBlock; - ck::index_t KPerBlock; - - ck::index_t M1PerThread; - ck::index_t N1PerThread; - ck::index_t KPerThread; - - ck::index_t M1N1ThreadClusterM10; - ck::index_t M1N1ThreadClusterN10; - ck::index_t M1N1ThreadClusterM11; - ck::index_t M1N1ThreadClusterN11; - - std::array ABlockTransferThreadSliceLengths_K_M0_M1; - std::array ABlockTransferThreadClusterLengths_K_M0_M1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - ck::index_t ABlockTransferSrcVectorDim; - ck::index_t ABlockTransferSrcScalarPerVector; - ck::index_t ABlockTransferDstScalarPerVector_M1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K_N0_N1; - std::array 
BBlockTransferThreadClusterLengths_K_N0_N1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - ck::index_t BBlockTransferSrcVectorDim; - ck::index_t BBlockTransferSrcScalarPerVector; - ck::index_t BBlockTransferDstScalarPerVector_N1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - ck::index_t CThreadTransferSrcDstVectorDim; - ck::index_t CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = { - 256, 128, 128, 8, 4, 4, 1, - 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, - {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, - {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, - 5, 1}; - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw -{ - ck::index_t BlockSize; // usually not tunable - - ck::index_t MPerBlock; - ck::index_t NPerBlock; - ck::index_t KPerBlock; - - ck::index_t MPerWave; - ck::index_t NPerWave; - ck::index_t K1; - - ck::index_t MRepeat; - ck::index_t NRepeat; - - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - ck::index_t ABlockTransferSrcVectorDim; - ck::index_t ABlockTransferSrcScalarPerVector; - ck::index_t ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - ck::index_t BBlockTransferSrcVectorDim; - ck::index_t BBlockTransferSrcScalarPerVector; - ck::index_t BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - ck::index_t CThreadTransferSrcDstVectorDim; - ck::index_t 
CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw - default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 1, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 1, // BBlockTransferSrcVectorDim - 1, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; - -struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk -{ - ck::index_t BlockSize; // usually not tunable - - ck::index_t MPerBlock; - ck::index_t NPerBlock; - ck::index_t KPerBlock; - - ck::index_t MPerWave; - ck::index_t NPerWave; - ck::index_t K1; - - ck::index_t MRepeat; - ck::index_t NRepeat; - - std::array ABlockTransferThreadSliceLengths_K0_M_K1; - std::array ABlockTransferThreadClusterLengths_K0_M_K1; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - ck::index_t ABlockTransferSrcVectorDim; - ck::index_t ABlockTransferSrcScalarPerVector; - ck::index_t ABlockTransferDstScalarPerVector_K1; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array 
BBlockTransferThreadSliceLengths_K0_N_K1; - std::array BBlockTransferThreadClusterLengths_K0_N_K1; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - ck::index_t BBlockTransferSrcVectorDim; - ck::index_t BBlockTransferSrcScalarPerVector; - ck::index_t BBlockTransferDstScalarPerVector_K1; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - ck::index_t CThreadTransferSrcDstVectorDim; - ck::index_t CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk - default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { - 256, // BlockSize - 128, // MPerBlock, - 128, // NPerBlock, - 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, - 4, // K1, - 2, // MRepeat, - 2, // NRepeat, - {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, - {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, - {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // ABlockTransferSrcAccessOrder, - 2, // ABlockTransferSrcVectorDim - 4, // ABlockTransferSrcScalarPerVector, - 4, // ABlockTransferDstScalarPerVector_K1, - false, // AThreadTransferSrcResetCoordinateAfterRun, - {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, - {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, - {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, - {1, 0, 2}, // BBlockTransferSrcAccessOrder, - 2, // BBlockTransferSrcVectorDim - 4, // BBlockTransferSrcScalarPerVector - 4, // BBlockTransferDstScalarPerVector_K1 - false, // BThreadTransferSrcResetCoordinateAfterRun - {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder - 7, // CThreadTransferSrcDstVectorDim, - 1 // CThreadTransferDstScalarPerVector -}; - -struct tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw -{ - ck::index_t BlockSize; - - ck::index_t GM1PerBlockGM11; - ck::index_t GN1PerBlockGN11; - ck::index_t KPerBlock; - - ck::index_t M1PerThread; - ck::index_t N1PerThread; - 
ck::index_t KPerThread; - - ck::index_t M1N1ThreadClusterM10; - ck::index_t M1N1ThreadClusterN10; - ck::index_t M1N1ThreadClusterM11; - ck::index_t M1N1ThreadClusterN11; - - std::array ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11; - std::array ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11; - std::array ABlockTransferThreadClusterArrangeOrder; - std::array ABlockTransferSrcAccessOrder; - ck::index_t ABlockTransferSrcVectorDim; - ck::index_t ABlockTransferSrcScalarPerVector; - ck::index_t ABlockTransferDstScalarPerVector_GM11; - bool AThreadTransferSrcResetCoordinateAfterRun; - - std::array BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11; - std::array BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11; - std::array BBlockTransferThreadClusterArrangeOrder; - std::array BBlockTransferSrcAccessOrder; - ck::index_t BBlockTransferSrcVectorDim; - ck::index_t BBlockTransferSrcScalarPerVector; - ck::index_t BBlockTransferDstScalarPerVector_GN11; - bool BThreadTransferSrcResetCoordinateAfterRun; - - std::array CThreadTransferSrcDstAccessOrder; - ck::index_t CThreadTransferSrcDstVectorDim; - ck::index_t CThreadTransferDstScalarPerVector; -}; - -static tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw = { - 256, - 128, - 32, - 8, - 4, - 4, - 1, - 2, - 2, - 8, - 8, - {4, 1, 1, 1}, - {2, 1, 1, 128}, - {3, 2, 1, 0}, - {3, 2, 1, 0}, - 0, - 4, - 1, - false, - {1, 4, 1, 1}, - {8, 1, 1, 32}, - {0, 3, 2, 1}, - {0, 3, 2, 1}, - 3, - 1, - 1, - false, - {3, 4, 5, 0, 1, 2}, - 5, - 1}; - -static inline int -conv_hw_out_size(int hw_in_size, int leftPad, int rightPad, int dilation, int yx_size, int stride) -{ - return (hw_in_size + leftPad + rightPad - dilation * (yx_size - 1) - 1) / stride + 1; -} - -#endif diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 
c3640d675c..0000000000 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,520 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_v1r2.hpp" - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); - -#if 0 - // cdata = 16, BlockSize = 64, 
16x64x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlockM1 = 16; - constexpr index_t GemmNPerBlockN1 = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmM1PerThreadM111 = 2; - constexpr index_t GemmN1PerThreadN111 = 2; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 2; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2; -#elif 0 - // cdata = 32, BlockSize = 64, 16x128x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlockM1 = 16; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmM1PerThreadM111 = 2; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 2; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 
1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2; -#elif 0 - // cdata = 64, BlockSize = 64, 16x256x2 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlockM1 = 16; - constexpr index_t GemmNPerBlockN1 = 256; - constexpr index_t GemmKPerBlock = 2; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 1; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 2; - constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#elif 0 - // cdata = 64, BlockSize = 64, 16x256x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlockM1 = 16; - constexpr index_t GemmNPerBlockN1 = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t 
GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 1; - constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#elif 0 - // cdata = 64, BlockSize = 128, 32x256x4 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlockM1 = 32; - constexpr index_t GemmNPerBlockN1 = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 2; - constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#elif 0 - // 
cdata = 64, BlockSize = 128, 32x256x8 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlockM1 = 32; - constexpr index_t GemmNPerBlockN1 = 256; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 2; - constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; - constexpr index_t 
GemmABlockTransferDstScalarPerVector_M1 = 1; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlockM1 = 128; - constexpr index_t GemmNPerBlockN1 = 128; - constexpr index_t GemmKPerBlock = 16; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - - using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>; - using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2; - - using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; -#endif - -#if 1 - const auto descs = - transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - -#if 0 - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = - 
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; -#else - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); - - constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; -#endif - -#else - const auto descs = - transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 
0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; -#endif - - const auto wei_gemmk_gemmm_grid_desc = descs[I0]; - const auto in_gemmk_gemmn_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_dynamic_gemm_v1r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperation::Set, - decltype(wei_gemmk_gemmm_grid_desc), - decltype(in_gemmk_gemmn_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlockM1, - GemmNPerBlockN1, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, - GemmABlockTransferThreadSliceLengths_K_M0_M1, - GemmABlockTransferThreadClusterLengths_K_M0_M1, - Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder - 0, // ABlockTransferSrcVectorDim - GemmABlockTransferSrcScalarPerVector_K, - GemmABlockTransferDstScalarPerVector_M1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_K_N0_N1, - GemmBBlockTransferThreadClusterLengths_K_N0_N1, - Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder - Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder - 0, // BBlockTransferSrcVectorDim - GemmBBlockTransferSrcScalarPerVector_K, - GemmBBlockTransferDstScalarPerVector_N1, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 2, // 
CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_M11, - decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), - decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), - decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), - decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk_gemmm_grid_desc, - in_gemmk_gemmn_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, - out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, - wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, - in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp deleted file mode 100644 index d00314c8d9..0000000000 --- 
a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,240 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_contraction_v1r1.hpp" - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - -#if 1 - // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32] - constexpr index_t BlockSize = 256; - - constexpr index_t N0 = 4; - - constexpr index_t GemmGM1PerBlockGM11 = 128; - constexpr index_t GemmGN1PerBlockGN11 = 32; - constexpr index_t 
GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - - using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1; - - using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16] - constexpr index_t BlockSize = 256; - - constexpr index_t N0 = 8; - - constexpr index_t GemmGM1PerBlockGM11 = 128; - constexpr index_t GemmGN1PerBlockGN11 = 16; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - - using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4; - constexpr index_t 
GemmABlockTransferDstScalarPerVector_GM11 = 1; - - using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; -#endif - - const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - const auto wei_gk_gm0_gm1_grid_desc = descs[I0]; - const auto in_gk_gn0_gn1_grid_desc = descs[I1]; - const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0>{})); - - constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); - - constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0>{}; - - constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_dynamic_contraction_v1r1< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperation::Set, - decltype(wei_gk_gm0_gm1_grid_desc), - decltype(in_gk_gn0_gn1_grid_desc), - decltype(out_gm0_gm1_gn0_gn1_grid_desc), - GemmGM1PerBlockGM11, - GemmGN1PerBlockGN11, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, - GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, - GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, - Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder - Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder - 0, // ABlockTransferSrcVectorDim - GemmABlockTransferSrcScalarPerVector_GK, - GemmABlockTransferDstScalarPerVector_GM11, - false, // don't move back 
src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, - GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, - Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder - Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder - 3, // BBlockTransferSrcVectorDim - GemmBBlockTransferSrcScalarPerVector_GN11, - GemmBBlockTransferDstScalarPerVector_GN11, - false, // don't move back src coordinate after threadwise copy - Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder - 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_BN1, - decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks), - decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks), - decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks), - decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks), - decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>( - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gk_gm0_gm1_grid_desc, - in_gk_gn0_gn1_grid_desc, - out_gm0_gm1_gn0_gn1_grid_desc, - wei_gk_gm0_gm10_gm11_grid_iterator_hacks, - in_gk_gn0_gn10_gn11_grid_iterator_hacks, - out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks, - wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks, - in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks, - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp 
b/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 914de6e81b..0000000000 --- a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,404 +0,0 @@ -#include "device.hpp" -#include "host_tensor.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp" - -#include "olc_driver_common.hpp" -#include "conv_tunables.hpp" - -#include "handle.hpp" - -namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw { - -template -static std::string get_network_config_string_from_types() -{ - std::string out; - - out += static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()) + - static_cast(Driver::get_typeid_from_type()); - - return (out); -}; - -static std::string -get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt) -{ - std::string out("TUN_"); - - out += std::to_string(pt->BlockSize) + "_"; - - out += std::to_string(pt->GM1PerBlockGM11) + "x" + std::to_string(pt->GN1PerBlockGN11) + "x" + - std::to_string(pt->KPerBlock) + "_"; - out += std::to_string(pt->M1PerThread) + "x" + std::to_string(pt->N1PerThread) + "x" + - std::to_string(pt->KPerThread) + "_"; - out += std::to_string(pt->M1N1ThreadClusterM10) + "x" + - std::to_string(pt->M1N1ThreadClusterN10) + "x" + - std::to_string(pt->M1N1ThreadClusterM11) + "x" + - std::to_string(pt->M1N1ThreadClusterN11) + "_"; - - out += std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "x" + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]) + "_"; - - out += 
std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]) + "_"; - - out += std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "x" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "x" + - std::to_string(pt->ABlockTransferSrcAccessOrder[3]) + "_"; - - out += std::to_string(pt->ABlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->ABlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->ABlockTransferDstScalarPerVector_GM11) + "_"; - out += std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "x" + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]); - - out += std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]) + "_"; - - out += std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "x" + - 
std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "x" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + "x" + - std::to_string(pt->BBlockTransferSrcAccessOrder[3]) + "_"; - - out += std::to_string(pt->BBlockTransferSrcVectorDim) + "_"; - out += std::to_string(pt->BBlockTransferSrcScalarPerVector) + "_"; - out += std::to_string(pt->BBlockTransferDstScalarPerVector_GN11) + "_"; - out += std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "x" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]) + "_"; - - out += std::to_string(pt->CThreadTransferSrcDstVectorDim) + "_"; - out += std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -template -static std::string get_definition_string_from_types() -{ - std::string out; - - out += " -DCK_PARAM_IN_WEI_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_CONV_COMPTYPE=" + std::to_string(Driver::get_typeid_from_type()) + - " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); - - return (out); -}; - -static std::string -get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* pt) -{ - std::string out; - - out += " -DCK_PARAM_BlockSize=" + std::to_string(pt->BlockSize); - - out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(pt->GM1PerBlockGM11) + - " -DCK_PARAM_GN1PerBlockGN11=" 
+ std::to_string(pt->GN1PerBlockGN11) + - " -DCK_PARAM_KPerBlock=" + std::to_string(pt->KPerBlock); - out += " -DCK_PARAM_M1PerThread=" + std::to_string(pt->M1PerThread) + - " -DCK_PARAM_N1PerThread=" + std::to_string(pt->N1PerThread) + - " -DCK_PARAM_KPerThread=" + std::to_string(pt->KPerThread); - - out += " -DCK_PARAM_M1N1ThreadClusterM10=" + std::to_string(pt->M1N1ThreadClusterM10) + - " -DCK_PARAM_M1N1ThreadClusterN10=" + std::to_string(pt->M1N1ThreadClusterN10) + - " -DCK_PARAM_M1N1ThreadClusterM11=" + std::to_string(pt->M1N1ThreadClusterM11) + - " -DCK_PARAM_M1N1ThreadClusterN11=" + std::to_string(pt->M1N1ThreadClusterN11); - - out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11=" + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[0]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[1]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[2]) + "," + - std::to_string(pt->ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11[3]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11=" + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[2]) + "," + - std::to_string(pt->ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11[3]); - - out += " -DCK_PARAM_ABlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[2]) + "," + - std::to_string(pt->ABlockTransferThreadClusterArrangeOrder[3]); - - out += " -DCK_PARAM_ABlockTransferSrcAccessOrder=" + - std::to_string(pt->ABlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[1]) + "," + - 
std::to_string(pt->ABlockTransferSrcAccessOrder[2]) + "," + - std::to_string(pt->ABlockTransferSrcAccessOrder[3]); - - out += - " -DCK_PARAM_ABlockTransferSrcVectorDim=" + std::to_string(pt->ABlockTransferSrcVectorDim); - out += " -DCK_PARAM_ABlockTransferSrcScalarPerVector=" + - std::to_string(pt->ABlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_ABlockTransferDstScalarPerVector_GM11=" + - std::to_string(pt->ABlockTransferDstScalarPerVector_GM11); - out += " -DCK_PARAM_AThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->AThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11=" + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[0]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[1]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[2]) + "," + - std::to_string(pt->BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11[3]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11=" + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[2]) + "," + - std::to_string(pt->BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11[3]); - - out += " -DCK_PARAM_BBlockTransferThreadClusterArrangeOrder=" + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[0]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[1]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[2]) + "," + - std::to_string(pt->BBlockTransferThreadClusterArrangeOrder[3]); - - out += " -DCK_PARAM_BBlockTransferSrcAccessOrder=" + - std::to_string(pt->BBlockTransferSrcAccessOrder[0]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[1]) + "," + - std::to_string(pt->BBlockTransferSrcAccessOrder[2]) + 
"," + - std::to_string(pt->BBlockTransferSrcAccessOrder[3]); - - out += - " -DCK_PARAM_BBlockTransferSrcVectorDim=" + std::to_string(pt->BBlockTransferSrcVectorDim); - out += " -DCK_PARAM_BBlockTransferSrcScalarPerVector=" + - std::to_string(pt->BBlockTransferSrcScalarPerVector); - out += " -DCK_PARAM_BBlockTransferDstScalarPerVector_GN11=" + - std::to_string(pt->BBlockTransferDstScalarPerVector_GN11); - out += " -DCK_PARAM_BThreadTransferSrcResetCoordinateAfterRun=" + - std::to_string(pt->BThreadTransferSrcResetCoordinateAfterRun); - - out += " -DCK_PARAM_CThreadTransferSrcDstAccessOrder=" + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[0]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[1]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[2]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[3]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[4]) + "," + - std::to_string(pt->CThreadTransferSrcDstAccessOrder[5]); - - out += " -DCK_PARAM_CThreadTransferSrcDstVectorDim=" + - std::to_string(pt->CThreadTransferSrcDstVectorDim); - out += " -DCK_PARAM_CThreadTransferDstScalarPerVector=" + - std::to_string(pt->CThreadTransferDstScalarPerVector); - - return (out); -}; - -} // namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw - -template -void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc( - olCompile::Handle* handle, - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - const tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable, - ck::index_t nrepeat) -{ - using namespace ck; - using namespace detail_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw; - using size_t = std::size_t; - - constexpr index_t N0 = 4; 
// this could not be a tunable so far - - //////////////////////////////////////////////////////////////////////////////////////////////////////////// - // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, - // hasDoubleTailKBlockLoop - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - - const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - const auto a_gk_gm0_gm1_grid_desc = descs[I0]; - const auto c_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; - - const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); - const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); - const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0); - - const index_t grid_size = (GM1 / tunable->GM1PerBlockGM11) * (GN1 / tunable->GN1PerBlockGN11); - const bool hasMainKBlockLoop = ((GK + tunable->KPerBlock) / (2 * tunable->KPerBlock) > 1); - const bool hasDoubleTailKBlockLoop = ((GK / tunable->KPerBlock) % 2 == 0); - - /////////////////////////////////////////////////////////////////////////////////////////////////////////// - - // these buffers are usually provided by the user application - DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - 
in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - // these are workspace buffers that should be expressed to the user by the corresponding - // workspace API - DeviceMem workspace_buf(4096); - - void* a_gk_gm0_gm10_gm11_grid_desc_dev_buf = workspace_buf.GetDeviceBuffer(); - void* b_gk_gn0_gn10_gn11_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); - void* c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); - void* c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf = - static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); - - const std::vector vld = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd1 = {static_cast(tunable->BlockSize), 1, 1}; - const std::vector vgd2 = {static_cast(grid_size * tunable->BlockSize), 1, 1}; - - std::string program_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp"; - std::string algo_name = "implicit_gemm_conv_fwd_v4r4_nchw"; - - std::string param = " -std=c++17 "; - std::string network_config; - - param += get_definition_string_from_types() + - " -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) + - " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) + - " -DCK_PARAM_N0=" + std::to_string(N0) + " " + - get_definition_string_from_tunable(tunable); - network_config = get_network_config_string_from_types() + "_V" + - std::to_string(hasDoubleTailKBlockLoop) + "_" + std::to_string(N0) + "_" + - get_network_config_string_from_tunable(tunable); - - std::vector kernel1_times; - std::vector kernel2_times; - - for(index_t i = 0; i < nrepeat; ++i) - { - KernelTimer timer1, timer2; - std::string kernel_name; - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_prepare"; - auto 
network_config_1 = network_config + "_1"; - - timer1.Start(); - handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( - static_cast(in_n_c_hi_wi_lengths[I0]), - static_cast(in_n_c_hi_wi_lengths[I1]), - static_cast(in_n_c_hi_wi_lengths[I2]), - static_cast(in_n_c_hi_wi_lengths[I3]), - static_cast(wei_k_c_y_x_lengths[I0]), - static_cast(wei_k_c_y_x_lengths[I2]), - static_cast(wei_k_c_y_x_lengths[I3]), - conv_strides[I0], - conv_strides[I1], - conv_dilations[I0], - conv_dilations[I1], - in_left_pads[I0], - in_left_pads[I1], - in_right_pads[I0], - in_right_pads[I1], - a_gk_gm0_gm10_gm11_grid_desc_dev_buf, - b_gk_gn0_gn10_gn11_grid_desc_dev_buf, - c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf, - c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf); - timer2.End(); - - kernel_name = "dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw"; - auto network_config_2 = network_config + "_2"; - - timer2.Start(); - handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( - reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), - reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), - reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), - (const void*)(a_gk_gm0_gm10_gm11_grid_desc_dev_buf), - (const void*)(b_gk_gn0_gn10_gn11_grid_desc_dev_buf), - (const void*)(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc_dev_buf), - (const void*)(c_blockid_to_gm10_gn10_block_cluster_adaptor_dev_buf)); - timer2.End(); - - kernel1_times.push_back(timer1.GetElapsedTime()); - kernel2_times.push_back(timer2.GetElapsedTime()); - } - - { - auto ave_time1 = Driver::get_effective_average(kernel1_times); - auto ave_time2 = Driver::get_effective_average(kernel2_times); - - const auto N = in_n_c_hi_wi_lengths[I0]; - const auto C = in_n_c_hi_wi_lengths[I1]; - - const auto K = out_n_k_ho_wo_lengths[I1]; - const auto Ho = out_n_k_ho_wo_lengths[I2]; - const auto Wo = out_n_k_ho_wo_lengths[I3]; - - const auto Y = 
wei_k_c_y_x_lengths[I2]; - const auto X = wei_k_c_y_x_lengths[I3]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); - - std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " - << ave_time2 << "), " << perf << " TFlop/s" << std::endl; - }; - - // copy result back to host - out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/CMakeLists.txt b/host/CMakeLists.txt new file mode 100644 index 0000000000..c9779398a6 --- /dev/null +++ b/host/CMakeLists.txt @@ -0,0 +1,4 @@ +add_subdirectory(host_tensor) +add_subdirectory(online_compilation) +add_subdirectory(driver_offline) +add_subdirectory(driver_online) diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt new file mode 100644 index 0000000000..85bd31fbca --- /dev/null +++ b/host/driver_offline/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver + ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/external/half/include +) + +set(CONV_FWD_DRIVER_OFFLINE_SOURCE conv_fwd_driver_offline.cpp) +set(CONV_BWD_DRIVER_OFFLINE_SOURCE conv_bwd_driver_offline.cpp) + +add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) + +target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) diff --git a/driver/conv_bwd_data_driver_v2.cpp 
b/host/driver_offline/conv_bwd_driver_offline.cpp similarity index 100% rename from driver/conv_bwd_data_driver_v2.cpp rename to host/driver_offline/conv_bwd_driver_offline.cpp diff --git a/driver/conv_driver_v2.cpp b/host/driver_offline/conv_fwd_driver_offline.cpp similarity index 86% rename from driver/conv_driver_v2.cpp rename to host/driver_offline/conv_fwd_driver_offline.cpp index 3574431556..405d6e7c40 100644 --- a/driver/conv_driver_v2.cpp +++ b/host/driver_offline/conv_fwd_driver_offline.cpp @@ -13,34 +13,28 @@ #include "host_conv.hpp" #include "device_tensor.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 0 -#define USE_CONV_FWD_V4R4_NHWC 0 +#define USE_CONV_FWD_V4R4_NCHW 1 #define USE_CONV_FWD_V4R4R2_NHWC 0 -#define USE_CONV_FWD_V4R5_NCHW 0 -#define USE_CONV_FWD_V4R5R2_NCHW 0 +#define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 enum ConvForwardAlgo { V4R4NCHW, // 0 - V4R4NHWC, // 1 - V4R4R2NHWC, // 2 - V4R5NCHW, // 3 - V4R5R2NCHW, // 4 - V5R1NCHW, // 5 - V4R4R2XDLNCHW, // 6 - V4R4R4XDLNHWC // 7 + 
V4R4R2NHWC, // 1 + V6R1NCHW, // 2 + V5R1NCHW, // 3 + V4R4R2XDLNCHW, // 4 + V4R4R4XDLNHWC // 5 }; int main(int argc, char* argv[]) @@ -132,7 +126,7 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #endif -#if 0 +#if 1 using in_data_t = float; using acc_data_t = float; using out_data_t = float; @@ -323,32 +317,6 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R4_NHWC - if(algo == ConvForwardAlgo::V4R4NHWC) - { - if(layout != ConvTensorLayout::NHWC) - { - throw std::runtime_error("wrong! layout"); - } - - const auto tmp = f_make_for_device_nhwc(); - - device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - #if USE_CONV_FWD_V4R4R2_NHWC if(algo == ConvForwardAlgo::V4R4R2NHWC) { @@ -376,8 +344,8 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R5_NCHW - if(algo == ConvForwardAlgo::V4R5NCHW) +#if USE_CONV_FWD_V6R1_NCHW + if(algo == ConvForwardAlgo::V6R1NCHW) { if(layout != ConvTensorLayout::NCHW) { @@ -386,7 +354,7 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); - device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(tmp[I0], tmp[I1], @@ -402,33 +370,6 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R5R2_NCHW - if(algo == ConvForwardAlgo::V4R5R2NCHW) - { - if(layout != ConvTensorLayout::NCHW) - { - throw std::runtime_error("wrong! 
layout"); - } - - const auto tmp = f_make_for_device_nchw(); - - device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in, - wei, - out_device, - nrepeat); - } -#endif - #if USE_CONV_FWD_V5R1_NCHW if(algo == ConvForwardAlgo::V5R1NCHW) { diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file 
mode 100644 index 0000000000..5890b12e00 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,283 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 0 + constexpr 
index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t 
GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using 
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#endif + + const auto descs = +#if 1 + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad +#else + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 +#endif + ( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + for(index_t i = 0; i < 5; ++i) + { +#if 0 + float ave_time = launch_kernel_dynamic_gemm_xdlops_v1 +#else + float ave_time = launch_kernel_dynamic_gemm_xdlops_v2 +#endif + , + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK, + 
GemmABlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy, which will be fused + // with MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp 
b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..bb37ac309f --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,240 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * 
wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + 
constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 
0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1>, + 2, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + 
decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..c1e63664e5 --- /dev/null +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,305 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void 
device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr 
index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + 
constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + 
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp similarity index 100% rename from driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp similarity index 62% rename from 
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp rename to host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp index 702ddc9e8f..0b45350234 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -1,7 +1,7 @@ #include #include "device.hpp" #include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_contraction_v1r2.hpp" template -void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( +void device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, const OutLengths& out_n_k_ho_wo_lengths, @@ -43,11 +43,11 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - const auto in_n_c_hi_wi_desc = + const auto in_desc_n_c_hi_wi = make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = + const auto wei_desc_k_c_y_x = make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = + const auto out_desc_n_k_ho_wo = make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); #if 1 @@ -58,32 +58,32 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( constexpr index_t GN0 = 4; constexpr index_t GK1 = 1; - constexpr index_t GemmGM1PerBlockGM11 = 128; - constexpr index_t GemmGN1PerBlockGN11 = 32; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + 
constexpr index_t GK0PerBlock = 8; - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t BM10BN10ThreadClusterBM100 = 8; + constexpr index_t BM10BN10ThreadClusterBN100 = 8; + constexpr index_t BM10BN10ThreadClusterBM101 = 2; + constexpr index_t BM10BN10ThreadClusterBN101 = 2; - using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; + using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; - using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; - using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using 
GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; #elif 1 // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 // cdata = 64, BlockSize = 256 @@ -92,48 +92,48 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( constexpr index_t GN0 = 4; constexpr index_t GK1 = 2; - constexpr index_t GemmGM1PerBlockGM11 = 128; - constexpr index_t GemmGN1PerBlockGN11 = 32; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GM1PerBlockGM11 = 128; + constexpr index_t GN1PerBlockGN11 = 32; + constexpr index_t GK0PerBlock = 8; - constexpr index_t GemmM1PerThreadM111 = 4; - constexpr index_t GemmN1PerThreadN111 = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t BM1PerThreadBM11 = 4; + constexpr index_t BN1PerThreadBN11 = 4; + constexpr index_t BK0PerThread = 1; - constexpr index_t GemmM11N11ThreadClusterM1101 = 2; - constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - constexpr index_t GemmM11N11ThreadClusterM1100 = 8; - constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t BM10BN10ThreadClusterBM100 = 8; + constexpr index_t BM10BN10ThreadClusterBN100 = 8; + constexpr index_t BM10BN10ThreadClusterBM101 = 2; + constexpr index_t BM10BN10ThreadClusterBN101 = 2; - using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; - using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; + using ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; - using 
GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; - using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; + using ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; - using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; - using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + using BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; + using BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; - using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; - using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; + using BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; - constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; + constexpr index_t CThreadTransferDstScalarPerVector_BN1 = 1; #endif const auto descs = - transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}, - Number{}); + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_desc_k_c_y_x, + in_desc_n_c_hi_wi, + out_desc_n_k_ho_wo, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + Number{}); - const auto wei_gk0_gm0_gm1_gk1_grid_desc = descs[I0]; - const auto in_gk0_gn0_gn1_gk1_grid_desc = descs[I1]; - const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; + const auto wei_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto 
in_grid_desc_gk0_gn0_gn1_gk1 = descs[I1]; + const auto out_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto wei_grid_iterator_hacks = @@ -189,36 +189,36 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( TAcc, TOut, InMemoryDataOperation::Set, - decltype(wei_gk0_gm0_gm1_gk1_grid_desc), - decltype(in_gk0_gn0_gn1_gk1_grid_desc), - decltype(out_gm0_gm1_gn0_gn1_grid_desc), - GemmGM1PerBlockGM11, - GemmGN1PerBlockGN11, - GemmKPerBlock, - GemmM1PerThreadM111, - GemmN1PerThreadN111, - GemmKPerThread, - GemmM11N11ThreadClusterM1100, - GemmM11N11ThreadClusterN1100, - GemmM11N11ThreadClusterM1101, - GemmM11N11ThreadClusterN1101, - GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, - GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + decltype(wei_grid_desc_gk0_gm0_gm1_gk1), + decltype(in_grid_desc_gk0_gn0_gn1_gk1), + decltype(out_grid_desc_gm0_gm1_gn0_gn1), + GM1PerBlockGM11, + GN1PerBlockGN11, + GK0PerBlock, + BM1PerThreadBM11, + BN1PerThreadBN11, + BK0PerThread, + BM10BN10ThreadClusterBM100, + BM10BN10ThreadClusterBN100, + BM10BN10ThreadClusterBM101, + BM10BN10ThreadClusterBN101, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder - GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, - GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder - GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, - GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + 
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder - GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim - GemmCThreadTransferDstScalarPerVector_BN1, + CThreadTransferDstScalarPerVector_BN1, decltype(wei_grid_iterator_hacks), decltype(in_grid_iterator_hacks), decltype(out_grid_iterator_hacks), @@ -227,9 +227,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - wei_gk0_gm0_gm1_gk1_grid_desc, - in_gk0_gn0_gn1_gk1_grid_desc, - out_gm0_gm1_gn0_gn1_grid_desc, + wei_grid_desc_gk0_gm0_gm1_gk1, + in_grid_desc_gk0_gn0_gn1_gk1, + out_grid_desc_gm0_gm1_gn0_gn1, wei_grid_iterator_hacks, in_grid_iterator_hacks, out_grid_iterator_hacks, @@ -238,7 +238,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( nrepeat); float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + in_desc_n_c_hi_wi, wei_desc_k_c_y_x, out_desc_n_k_ho_wo) / (std::size_t(1000) * 1000 * 1000) / ave_time; std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; diff --git a/host/driver_online/CMakeLists.txt b/host/driver_online/CMakeLists.txt new file mode 100644 index 0000000000..2ae05e0ba5 --- /dev/null +++ b/host/driver_online/CMakeLists.txt @@ -0,0 +1,21 @@ +include_directories(BEFORE + include + 
${PROJECT_BINARY_DIR}/host/online_compilation/include + ${PROJECT_SOURCE_DIR}/host/online_compilation/include + ${PROJECT_SOURCE_DIR}/host/host_tensor/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include + ${PROJECT_SOURCE_DIR}/composable_kernel/include/utility + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description + ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation + ${PROJECT_SOURCE_DIR}/composable_kernel/include/problem_transform + ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver + ${PROJECT_SOURCE_DIR}/external/rocm/include + ${PROJECT_SOURCE_DIR}/external/half/include +) + +set(CONV_FWD_DRIVER_ONLINE_SOURCE conv_fwd_driver_online.cpp) + +add_executable(conv_fwd_driver_online ${CONV_FWD_DRIVER_ONLINE_SOURCE}) + +target_link_libraries(conv_fwd_driver_online PRIVATE host_tensor) +target_link_libraries(conv_fwd_driver_online PRIVATE online_compilation) diff --git a/driver/conv_driver_v2_olc.cpp b/host/driver_online/conv_fwd_driver_online.cpp similarity index 80% rename from driver/conv_driver_v2_olc.cpp rename to host/driver_online/conv_fwd_driver_online.cpp index 14e3e95205..3b25f5d039 100644 --- a/driver/conv_driver_v2_olc.cpp +++ b/host/driver_online/conv_fwd_driver_online.cpp @@ -12,26 +12,22 @@ #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" - -#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" - -#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" -#include "olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" - -#define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V4R5_NCHW 1 -#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1 -#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1 - -#include "conv_tunables.hpp" #include "handle.hpp" #include "hipCheck.hpp" +#include 
"online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" +#include "online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V6R1_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NCHW 1 +#define USE_CONV_FWD_V4R4_XDLOPS_NHWC 1 enum ConvForwardAlgo { V4R4NCHW, // 0 - V4R5NCHW, // 1 + V6R1NCHW, // 1 V4R4XDLNCHW, // 2 V4R4XDLNHWC // 3 }; @@ -94,15 +90,17 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #if 1 - constexpr index_t in_vector_size = 1; - using in_data_t = float; - using acc_data_t = float; - using out_data_t = float; + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; #elif 1 - constexpr index_t in_vector_size = 16; - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); @@ -230,9 +228,9 @@ int main(int argc, char* argv[]) tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw* tunable = &default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw; - device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc( + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( handle, tmp[I0], tmp[I1], @@ -249,8 +247,8 @@ int main(int argc, char* argv[]) } #endif -#if USE_CONV_FWD_V4R5_NCHW - if(algo == ConvForwardAlgo::V4R5NCHW) +#if USE_CONV_FWD_V6R1_NCHW + if(algo == ConvForwardAlgo::V6R1NCHW) { if(layout != ConvTensorLayout::NCHW) { @@ -259,12 +257,11 @@ int main(int argc, 
char* argv[]) const auto tmp = f_make_for_device_nchw(); - tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw* tunable = - &default_tunable_dyn_conv_fwd_v4r5_nchw_kcyx_nkhw; + const auto tunable = tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw{}; - device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw_olc( + online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( handle, tmp[I0], tmp[I1], @@ -294,22 +291,22 @@ int main(int argc, char* argv[]) tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw* tunable = &default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw; - device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc( - handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); } #endif @@ -326,22 +323,22 @@ int main(int argc, char* argv[]) tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk* tunable = &default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk; - device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc( - handle, - tmp[I0], - tmp[I1], - tmp[I2], - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - in, - wei, - out_device, - tunable, - nrepeat); + online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(handle, + tmp[I0], + tmp[I1], + tmp[I2], + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + in, + wei, + out_device, + tunable, + nrepeat); } #endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp new file mode 
100644 index 0000000000..05ee9846b8 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,50 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw +{ + int32_t BlockSize; + + int32_t MPerBlock; + int32_t NPerBlock; + int32_t KPerBlock; + + int32_t M1PerThread; + int32_t N1PerThread; + int32_t KPerThread; + + int32_t M1N1ThreadClusterM10; + int32_t M1N1ThreadClusterN10; + int32_t M1N1ThreadClusterM11; + int32_t M1N1ThreadClusterN11; + + std::array ABlockTransferThreadSliceLengths_K_M0_M1; + std::array ABlockTransferThreadClusterLengths_K_M0_M1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int32_t ABlockTransferSrcVectorDim; + int32_t ABlockTransferSrcScalarPerVector; + int32_t ABlockTransferDstScalarPerVector_M1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K_N0_N1; + std::array BBlockTransferThreadClusterLengths_K_N0_N1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int32_t BBlockTransferSrcVectorDim; + int32_t BBlockTransferSrcScalarPerVector; + int32_t BBlockTransferDstScalarPerVector_N1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int32_t CThreadTransferSrcDstVectorDim; + int32_t CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw default_tunable_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw = { + 256, 128, 128, 8, 4, 4, 1, + 8, 8, 2, 2, {4, 1, 1}, {2, 1, 128}, {2, 1, 0}, + {2, 1, 0}, 0, 4, 1, false, {4, 1, 1}, {2, 1, 128}, + {0, 1, 2}, {0, 1, 2}, 2, 1, 1, false, {3, 4, 5, 0, 1, 2}, + 5, 1}; +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 
index 0000000000..7681438d95 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw +{ + int32_t BlockSize; + + int32_t MPerBlock; + int32_t NPerBlock; + int32_t KPerBlock; + + int32_t MPerWave; + int32_t NPerWave; + int32_t K1; + + int32_t MRepeat; + int32_t NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int32_t ABlockTransferSrcVectorDim; + int32_t ABlockTransferSrcScalarPerVector; + int32_t ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int32_t BBlockTransferSrcVectorDim; + int32_t BBlockTransferSrcScalarPerVector; + int32_t BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int32_t CThreadTransferSrcDstVectorDim; + int32_t CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw + default_tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 1, // 
ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {0, 2, 1}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 1, // BBlockTransferSrcVectorDim + 1, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {3, 0, 1, 2, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..a4fd8095c4 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,73 @@ +#ifndef CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP +#define CONV_TUNABLE_FWD_V4R4_XDLOPS_NHWC_KYXC_NHWK_HPP + +struct tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk +{ + int32_t BlockSize; + + int32_t MPerBlock; + int32_t NPerBlock; + int32_t KPerBlock; + + int32_t MPerWave; + int32_t NPerWave; + int32_t K1; + + int32_t MRepeat; + int32_t NRepeat; + + std::array ABlockTransferThreadSliceLengths_K0_M_K1; + std::array ABlockTransferThreadClusterLengths_K0_M_K1; + std::array ABlockTransferThreadClusterArrangeOrder; + std::array ABlockTransferSrcAccessOrder; + int32_t ABlockTransferSrcVectorDim; + int32_t ABlockTransferSrcScalarPerVector; + int32_t ABlockTransferDstScalarPerVector_K1; + bool AThreadTransferSrcResetCoordinateAfterRun; + + std::array BBlockTransferThreadSliceLengths_K0_N_K1; + std::array BBlockTransferThreadClusterLengths_K0_N_K1; + std::array BBlockTransferThreadClusterArrangeOrder; + std::array BBlockTransferSrcAccessOrder; + int32_t 
BBlockTransferSrcVectorDim; + int32_t BBlockTransferSrcScalarPerVector; + int32_t BBlockTransferDstScalarPerVector_K1; + bool BThreadTransferSrcResetCoordinateAfterRun; + + std::array CThreadTransferSrcDstAccessOrder; + int32_t CThreadTransferSrcDstVectorDim; + int32_t CThreadTransferDstScalarPerVector; +}; + +static tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk + default_tunable_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk = { + 256, // BlockSize + 128, // MPerBlock, + 128, // NPerBlock, + 4, // KPerBlock, + 32, // MPerWave, + 32, // NPerWave, + 4, // K1, + 2, // MRepeat, + 2, // NRepeat, + {1, 2, 4}, // ABlockTransferThreadSliceLengths_K0_M_K1, + {4, 64, 1}, // ABlockTransferThreadClusterLengths_K0_M_K1, + {1, 0, 2}, // ABlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector, + 4, // ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + {1, 2, 4}, // BBlockTransferThreadSliceLengths_K0_N_K1, + {4, 64, 1}, // BBlockTransferThreadClusterLengths_K0_N_K1, + {1, 0, 2}, // BBlockTransferThreadClusterArrangeOrder, + {1, 0, 2}, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_K1 + false, // BThreadTransferSrcResetCoordinateAfterRun + {2, 3, 0, 1, 7, 5, 4, 6}, // CThreadTransferSrcDstAccessOrder + 7, // CThreadTransferSrcDstVectorDim, + 1 // CThreadTransferDstScalarPerVector +}; +#endif diff --git a/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp b/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..f307e22f53 --- /dev/null +++ b/host/driver_online/include/conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,42 @@ +#ifndef CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP +#define CONV_TUNABLE_FWD_V6R1_NCHW_KCYX_NKHW_HPP + +struct 
tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw +{ + int32_t BlockSize = 256; + + int32_t GN0 = 4; + int32_t GK1 = 1; + + int32_t GM1PerBlockGM11 = 128; + int32_t GN1PerBlockGN11 = 32; + int32_t GK0PerBlock = 8; + + int32_t BM1PerThreadBM11 = 4; + int32_t BN1PerThreadBN11 = 4; + int32_t BK0PerThread = 1; + + int32_t BM10BN10ThreadClusterBM100 = 2; + int32_t BM10BN10ThreadClusterBN100 = 2; + int32_t BM10BN10ThreadClusterBM101 = 8; + int32_t BM10BN10ThreadClusterBN101 = 8; + + std::array ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = {4, 1, 1, 1, 1}; + std::array ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = { + 2, 1, 1, 128, 1}; + std::array ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { + 4, 1, 1, 1, 1}; + std::array ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = { + 1, 1, 1, 1, 1}; + + std::array BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = {1, 4, 1, 1, 1}; + std::array BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = { + 8, 1, 1, 32, 1}; + std::array BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { + 1, 1, 1, 1, 1}; + std::array BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = { + 1, 1, 1, 1, 1}; + + int32_t CThreadTransferDstScalarPerVector = 1; +}; +#endif diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 99% rename from driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 94a9bcc06d..f852c4dc6f 100644 --- a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -1,13 +1,11 @@ 
#include "device.hpp" #include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" - -#include "olc_driver_common.hpp" -#include "conv_tunables.hpp" - -#include "handle.hpp" +#include "conv_tunable_fwd_v4r4_nchw_kcyx_nkhw.hpp" namespace detail_dyn_conv_fwd_v4r4_nchw_kcyx_nkhw { @@ -211,7 +209,7 @@ template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_olc( +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( olCompile::Handle* handle, const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp similarity index 98% rename from driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp rename to host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 2b653dbae1..703f8592b8 100644 --- a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -1,12 +1,10 @@ #include "device.hpp" #include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" - -#include "olc_driver_common.hpp" -#include "conv_tunables.hpp" - -#include "handle.hpp" +#include "conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp" namespace detail_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw { @@ -208,7 +206,7 @@ template -void 
device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_olc( +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( olCompile::Handle* handle, const InLengths& in_n_c_hi_wi_lengths, const WeiLengths& wei_k_c_y_x_lengths, diff --git a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp similarity index 98% rename from driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp rename to host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp index 073e90bd63..2f4787d350 100644 --- a/driver/include/olc_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -1,13 +1,11 @@ #include "device.hpp" #include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" #include "dynamic_tensor_descriptor.hpp" #include "dynamic_tensor_descriptor_helper.hpp" #include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" - -#include "olc_driver_common.hpp" -#include "conv_tunables.hpp" - -#include "handle.hpp" +#include "conv_tunable_fwd_v4r4_xdlops_nhwc_kyxc_nhwk.hpp" namespace detail_dyn_conv_fwd_v4r4_xdlops_nhwc_kyxc_nhwk { @@ -209,7 +207,7 @@ template -void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_olc( +void online_device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk( olCompile::Handle* handle, const InLengths& in_n_hi_wi_c_lengths, const WeiLengths& wei_k_y_x_c_lengths, diff --git a/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp 
b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..2ee2680f5c --- /dev/null +++ b/host/driver_online/include/online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,425 @@ +#include "device.hpp" +#include "host_tensor.hpp" +#include "handle.hpp" +#include "online_driver_common.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp" +#include "conv_tunable_fwd_v6r1_nchw_kcyx_nkhw.hpp" + +namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw { + +template +static std::string get_network_config_string_from_types() +{ + std::string out("DAT_"); + + out += static_cast(Driver::get_typeid_from_type()) + + static_cast(Driver::get_typeid_from_type()) + + static_cast(Driver::get_typeid_from_type()); + + return (out); +}; + +static std::string +get_network_config_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable) +{ + std::string out("TUN_"); + + out += std::to_string(tunable.BlockSize) + "_"; + + out += std::to_string(tunable.GN0) + "x" + std::to_string(tunable.GK1) + "_"; + + out += std::to_string(tunable.GM1PerBlockGM11) + "x" + std::to_string(tunable.GN1PerBlockGN11) + + "x" + std::to_string(tunable.GK0PerBlock) + "_"; + + out += std::to_string(tunable.BM1PerThreadBM11) + "x" + + std::to_string(tunable.BN1PerThreadBN11) + "x" + std::to_string(tunable.BK0PerThread) + + "_"; + + out += std::to_string(tunable.BM10BN10ThreadClusterBM100) + "x" + + std::to_string(tunable.BM10BN10ThreadClusterBN100) + "x" + + std::to_string(tunable.BM10BN10ThreadClusterBM101) + "x" + + std::to_string(tunable.BM10BN10ThreadClusterBN101) + "_"; + + out += std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" 
+ + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_"; + + out += + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "x" + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "x" + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "x" + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "x" + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]) + "_"; + + out += std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + + "x" + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + + "x" + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + + "x" + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + + "x" + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + "_"; + + out += std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + + "x" + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + + "x" + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + + "x" + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + + "x" + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]) + + "_"; + + out += std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" + + 
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_"; + + out += + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "x" + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "x" + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "x" + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "x" + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]) + "_"; + + out += std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + + "x" + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + + "x" + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + + "x" + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + + "x" + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + "_"; + + out += std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + + "x" + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + + "x" + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + + "x" + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + + "x" + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]) + + "_"; + + out += std::to_string(tunable.CThreadTransferDstScalarPerVector); + + return (out); +}; + +template +static std::string get_definition_string_from_types() +{ + std::string out; + + out += " -DCK_PARAM_IN_WEI_DATATYPE=" + 
std::to_string(Driver::get_typeid_from_type()) + + " -DCK_PARAM_ACC_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()) + + " -DCK_PARAM_OUT_DATATYPE=" + std::to_string(Driver::get_typeid_from_type()); + + return (out); +}; + +static std::string +get_definition_string_from_tunable(const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable) +{ + std::string out; + + out += " -DCK_PARAM_BlockSize=" + std::to_string(tunable.BlockSize); + + out += " -DCK_PARAM_GN0=" + std::to_string(tunable.GN0); + out += " -DCK_PARAM_GK1=" + std::to_string(tunable.GK1); + + out += " -DCK_PARAM_GM1PerBlockGM11=" + std::to_string(tunable.GM1PerBlockGM11) + + " -DCK_PARAM_GN1PerBlockGN11=" + std::to_string(tunable.GN1PerBlockGN11) + + " -DCK_PARAM_GK0PerBlock=" + std::to_string(tunable.GK0PerBlock); + + out += " -DCK_PARAM_BM1PerThreadBM11=" + std::to_string(tunable.BM1PerThreadBM11) + + " -DCK_PARAM_BN1PerThreadBN11=" + std::to_string(tunable.BN1PerThreadBN11) + + " -DCK_PARAM_BK0PerThread=" + std::to_string(tunable.BK0PerThread); + + out += " -DCK_PARAM_BM10BN10ThreadClusterBM100=" + + std::to_string(tunable.BM10BN10ThreadClusterBM100) + + " -DCK_PARAM_BM10BN10ThreadClusterBN100=" + + std::to_string(tunable.BM10BN10ThreadClusterBN100) + + " -DCK_PARAM_BM10BN10ThreadClusterBM101=" + + std::to_string(tunable.BM10BN10ThreadClusterBM101) + + " -DCK_PARAM_BM10BN10ThreadClusterBN101=" + + std::to_string(tunable.BM10BN10ThreadClusterBN101); + + out += " -DCK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(tunable.ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1[4]); + + out += + " 
-DCK_PARAM_ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[0]) + "," + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[1]) + "," + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[2]) + "," + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[3]) + "," + + std::to_string(tunable.ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1[4]); + + out += " -DCK_PARAM_ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + + "," + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + + "," + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + + "," + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + + "," + + std::to_string(tunable.ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]); + + out += " -DCK_PARAM_ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1=" + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[0]) + + "," + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[1]) + + "," + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[2]) + + "," + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[3]) + + "," + + std::to_string(tunable.ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1[4]); + + out += " -DCK_PARAM_BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + 
std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(tunable.BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1[4]); + + out += + " -DCK_PARAM_BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[0]) + "," + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[1]) + "," + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[2]) + "," + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[3]) + "," + + std::to_string(tunable.BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1[4]); + + out += " -DCK_PARAM_BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + + "," + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + + "," + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + + "," + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + + "," + + std::to_string(tunable.BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]); + + out += " -DCK_PARAM_BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1=" + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[0]) + + "," + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[1]) + + "," + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[2]) + + "," + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[3]) + + "," + + std::to_string(tunable.BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1[4]); + + out += " 
-DCK_PARAM_CThreadTransferDstScalarPerVector=" + + std::to_string(tunable.CThreadTransferDstScalarPerVector); + + return (out); +}; + +} // namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw + +template +void online_device_dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw( + olCompile::Handle* handle, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + const tunable_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw& tunable, + ck::index_t nrepeat) +{ + using namespace ck; + using namespace detail_dyn_conv_fwd_v6r1_nchw_kcyx_nkhw; + using size_t = std::size_t; + + //////////////////////////////////////////////////////////////////////////////////////////////////////////// + // The follow codes are only used for computing the grid_size, hasMainKBlockLoop, + // hasDoubleTailKBlockLoop + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + + const auto descs = + transform_forward_convolution_into_contraction_v6r1_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + tunable.GN0, + tunable.GK1); + + const auto a_grid_desc_gk0_gm0_gm1_gk1 = descs[I0]; + const auto c_grid_desc_gm0_gm1_gn0_gn1 = descs[I2]; + + const auto GM1 = c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I1); + const auto GN1 = 
c_grid_desc_gm0_gm1_gn0_gn1.GetLength(I3); + const auto GK = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0); + + const index_t grid_size = (GM1 / tunable.GM1PerBlockGM11) * (GN1 / tunable.GN1PerBlockGN11); + const bool hasMainKBlockLoop = ((GK + tunable.GK0PerBlock) / (2 * tunable.GK0PerBlock) > 1); + const bool hasDoubleTailKBlockLoop = ((GK / tunable.GK0PerBlock) % 2 == 0); + + /////////////////////////////////////////////////////////////////////////////////////////////////////////// + + // these buffers are usually provided by the user application + DeviceMem in_n_c_hi_wi_dev_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_dev_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_dev_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_dev_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_dev_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_dev_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + // these are workspace buffers that should be expressed to the user by the corresponding + // workspace API + DeviceMem workspace_buf(4096); + + void* a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf = workspace_buf.GetDeviceBuffer(); + void* b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 1024); + void* c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 2048); + void* c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf = + static_cast(static_cast(workspace_buf.GetDeviceBuffer()) + 3072); + + const std::vector vld = {static_cast(tunable.BlockSize), 1, 1}; + const std::vector vgd1 = {static_cast(tunable.BlockSize), 1, 1}; + const std::vector vgd2 = {static_cast(grid_size * tunable.BlockSize), 1, 1}; + + std::string program_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw.cpp"; + std::string algo_name = "implicit_gemm_conv_fwd_v6r1_nchw"; + + std::string 
param = " -std=c++17 "; + std::string network_config; + + param += get_definition_string_from_types() + + " -DCK_PARAM_HAS_MAIN_KBLOCK_LOOP=" + std::to_string(hasMainKBlockLoop) + + " -DCK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP=" + std::to_string(hasDoubleTailKBlockLoop) + + get_definition_string_from_tunable(tunable); + + network_config = get_network_config_string_from_types() + "_" + + std::to_string(hasDoubleTailKBlockLoop) + "_" + + get_network_config_string_from_tunable(tunable); + + std::vector kernel1_times; + std::vector kernel2_times; + + for(index_t i = 0; i < nrepeat; ++i) + { + KernelTimer timer1, timer2; + std::string kernel_name; + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare"; + auto network_config_1 = network_config + "_1"; + + timer1.Start(); + handle->AddKernel(algo_name, network_config_1, program_name, kernel_name, vld, vgd1, param)( + static_cast(in_n_c_hi_wi_lengths[I0]), + static_cast(in_n_c_hi_wi_lengths[I1]), + static_cast(in_n_c_hi_wi_lengths[I2]), + static_cast(in_n_c_hi_wi_lengths[I3]), + static_cast(wei_k_c_y_x_lengths[I0]), + static_cast(wei_k_c_y_x_lengths[I2]), + static_cast(wei_k_c_y_x_lengths[I3]), + conv_strides[I0], + conv_strides[I1], + conv_dilations[I0], + conv_dilations[I1], + in_left_pads[I0], + in_left_pads[I1], + in_right_pads[I0], + in_right_pads[I1], + a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf, + b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf, + c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf, + c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf); + timer2.End(); + + kernel_name = "dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw"; + auto network_config_2 = network_config + "_2"; + + timer2.Start(); + handle->AddKernel(algo_name, network_config_2, program_name, kernel_name, vld, vgd2, param)( + reinterpret_cast(wei_k_c_y_x_dev_buf.GetDeviceBuffer()), + reinterpret_cast(in_n_c_hi_wi_dev_buf.GetDeviceBuffer()), + reinterpret_cast(out_n_k_ho_wo_dev_buf.GetDeviceBuffer()), + (const 
void*)(a_grid_desc_gk0_gm0_gm10_gm11_gk1_dev_buf), + (const void*)(b_grid_desc_gk0_gn0_gn10_gn11_gk1_dev_buf), + (const void*)(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1_dev_buf), + (const void*)(c_grid_block_cluster_blockid_to_gm10_gn10_dev_buf)); + timer2.End(); + + kernel1_times.push_back(timer1.GetElapsedTime()); + kernel2_times.push_back(timer2.GetElapsedTime()); + } + + { + auto ave_time1 = Driver::get_effective_average(kernel1_times); + auto ave_time2 = Driver::get_effective_average(kernel2_times); + + const auto N = in_n_c_hi_wi_lengths[I0]; + const auto C = in_n_c_hi_wi_lengths[I1]; + + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; + + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / (ave_time1 + ave_time2); + + std::cout << "Average time : " << ave_time1 + ave_time2 << " ms(" << ave_time1 << ", " + << ave_time2 << "), " << perf << " TFlop/s" << std::endl; + }; + + // copy result back to host + out_n_k_ho_wo_dev_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/olc_driver_common.hpp b/host/driver_online/include/online_driver_common.hpp similarity index 100% rename from driver/include/olc_driver_common.hpp rename to host/driver_online/include/online_driver_common.hpp diff --git a/host/host_tensor/CMakeLists.txt b/host/host_tensor/CMakeLists.txt new file mode 100644 index 0000000000..9c30275220 --- /dev/null +++ b/host/host_tensor/CMakeLists.txt @@ -0,0 +1,19 @@ +include_directories(BEFORE + include +) + +set(HOST_TENSOR_SOURCE + src/host_tensor.cpp; + src/device.cpp; +) + +## the library target +add_library(host_tensor SHARED ${HOST_TENSOR_SOURCE}) + +target_link_libraries(host_tensor PRIVATE hip::device) +target_link_libraries(host_tensor INTERFACE hip::host) + +target_compile_features(host_tensor PUBLIC) 
+set_target_properties(host_tensor PROPERTIES POSITION_INDEPENDENT_CODE ON) + +install(TARGETS host_tensor LIBRARY DESTINATION lib) diff --git a/driver/include/conv_common.hpp b/host/host_tensor/include/conv_common.hpp similarity index 100% rename from driver/include/conv_common.hpp rename to host/host_tensor/include/conv_common.hpp diff --git a/driver/include/device.hpp b/host/host_tensor/include/device.hpp similarity index 61% rename from driver/include/device.hpp rename to host/host_tensor/include/device.hpp index 7dac30a45a..2299e14921 100644 --- a/driver/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,7 +2,8 @@ #define DEVICE_HPP #include -#include "config.hpp" +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" struct DeviceMem { @@ -30,7 +31,6 @@ struct KernelTimer std::unique_ptr impl; }; -#if CK_DEVICE_BACKEND_AMD using device_stream_t = hipStream_t; template @@ -83,44 +83,4 @@ float launch_and_time_kernel(F kernel, return timer.GetElapsedTime() / nrepeat; } -#elif CK_DEVICE_BACKEND_NVIDIA -using device_stream_t = cudaStream_t; - -template -void launch_kernel(F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - cudaStream_t stream_id, - Args... args) -{ - const void* f = reinterpret_cast(kernel); - void* p_args[] = {&args...}; - - cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id); -} - -template -float launch_and_time_kernel(F kernel, - dim3 grid_dim, - dim3 block_dim, - std::size_t lds_byte, - cudaStream_t stream_id, - Args... 
args) -{ - KernelTimer timer; - - const void* f = reinterpret_cast(kernel); - void* p_args[] = {&args...}; - - timer.Start(); - - cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id); - - timer.End(); - - return timer.GetElapsedTime(); -} -#endif - #endif diff --git a/driver/include/device_tensor.hpp b/host/host_tensor/include/device_tensor.hpp similarity index 100% rename from driver/include/device_tensor.hpp rename to host/host_tensor/include/device_tensor.hpp diff --git a/driver/include/host_conv.hpp b/host/host_tensor/include/host_conv.hpp similarity index 100% rename from driver/include/host_conv.hpp rename to host/host_tensor/include/host_conv.hpp diff --git a/driver/include/host_conv_bwd_data.hpp b/host/host_tensor/include/host_conv_bwd_data.hpp similarity index 100% rename from driver/include/host_conv_bwd_data.hpp rename to host/host_tensor/include/host_conv_bwd_data.hpp diff --git a/driver/include/host_tensor.hpp b/host/host_tensor/include/host_tensor.hpp similarity index 100% rename from driver/include/host_tensor.hpp rename to host/host_tensor/include/host_tensor.hpp diff --git a/driver/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp similarity index 100% rename from driver/include/host_tensor_generator.hpp rename to host/host_tensor/include/host_tensor_generator.hpp diff --git a/driver/src/device.cpp b/host/host_tensor/src/device.cpp similarity index 50% rename from driver/src/device.cpp rename to host/host_tensor/src/device.cpp index 14f4792d26..d0d74a4c2a 100644 --- a/driver/src/device.cpp +++ b/host/host_tensor/src/device.cpp @@ -1,107 +1,59 @@ -#include "config.hpp" #include "device.hpp" DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size) { -#if CK_DEVICE_BACKEND_AMD hipGetErrorString(hipMalloc(static_cast(&mpDeviceBuf), mMemSize)); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaMalloc(static_cast(&mpDeviceBuf), mMemSize); -#endif } void* DeviceMem::GetDeviceBuffer() { 
return mpDeviceBuf; } void DeviceMem::ToDevice(const void* p) { -#if CK_DEVICE_BACKEND_AMD hipGetErrorString( hipMemcpy(mpDeviceBuf, const_cast(p), mMemSize, hipMemcpyHostToDevice)); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaMemcpy(mpDeviceBuf, const_cast(p), mMemSize, cudaMemcpyHostToDevice); -#endif } void DeviceMem::FromDevice(void* p) { -#if CK_DEVICE_BACKEND_AMD hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost)); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost); -#endif } -DeviceMem::~DeviceMem() -{ -#if CK_DEVICE_BACKEND_AMD - hipGetErrorString(hipFree(mpDeviceBuf)); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaFree(mpDeviceBuf); -#endif -} +DeviceMem::~DeviceMem() { hipGetErrorString(hipFree(mpDeviceBuf)); } struct KernelTimerImpl { KernelTimerImpl() { -#if CK_DEVICE_BACKEND_AMD hipEventCreate(&mStart); hipEventCreate(&mEnd); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaEventCreate(&mStart); - cudaEventCreate(&mEnd); -#endif } ~KernelTimerImpl() { -#if CK_DEVICE_BACKEND_AMD hipEventDestroy(mStart); hipEventDestroy(mEnd); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaEventDestroy(mStart); - cudaEventDestroy(mEnd); -#endif } void Start() { -#if CK_DEVICE_BACKEND_AMD hipDeviceSynchronize(); hipEventRecord(mStart, 0); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaDeviceSynchronize(); - cudaEventRecord(mStart, 0); -#endif } void End() { -#if CK_DEVICE_BACKEND_AMD hipEventRecord(mEnd, 0); hipEventSynchronize(mEnd); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaEventRecord(mEnd, 0); - cudaEventSynchronize(mEnd); -#endif } float GetElapsedTime() const { float time; -#if CK_DEVICE_BACKEND_AMD hipEventElapsedTime(&time, mStart, mEnd); -#elif CK_DEVICE_BACKEND_NVIDIA - cudaEventElapsedTime(&time, mStart, mEnd); -#endif return time; } -#if CK_DEVICE_BACKEND_AMD hipEvent_t mStart, mEnd; -#elif CK_DEVICE_BACKEND_NVIDIA - cudaEvent_t mStart, mEnd; -#endif }; KernelTimer::KernelTimer() : impl(new KernelTimerImpl()) {} diff --git 
a/driver/src/host_tensor.cpp b/host/host_tensor/src/host_tensor.cpp similarity index 100% rename from driver/src/host_tensor.cpp rename to host/host_tensor/src/host_tensor.cpp diff --git a/driver/CMakeLists.txt b/host/online_compilation/CMakeLists.txt similarity index 50% rename from driver/CMakeLists.txt rename to host/online_compilation/CMakeLists.txt index ecc4d7091d..7bbfc65288 100644 --- a/driver/CMakeLists.txt +++ b/host/online_compilation/CMakeLists.txt @@ -1,4 +1,3 @@ - set(CMAKE_CXX_COMPILER /opt/rocm/llvm/bin/clang++) ## for online-compiling of HIP kernels @@ -17,6 +16,7 @@ if(OLC_HIP_COMPILER MATCHES ".*clang\\+\\+$") ${CMAKE_INSTALL_PREFIX}/llvm ) endif() + if(OLC_OFFLOADBUNDLER_BIN) message(STATUS "clang-offload-bundler found: ${OLC_OFFLOADBUNDLER_BIN}") set(OLC_OFFLOADBUNDLER_BIN "${OLC_OFFLOADBUNDLER_BIN}") @@ -67,92 +67,58 @@ else() set(OLC_DEBUG 0) endif() -configure_file("${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h.in" "${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/config.h") +configure_file("${PROJECT_SOURCE_DIR}/host/online_compilation/include/config.h.in" "${PROJECT_BINARY_DIR}/host/online_compilation/include/config.h") + +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/host/online_compilation/include +) message(STATUS "Hip compiler flags: ${HIP_COMPILER_FLAGS}") ## HIP_COMPILER_FLAGS will be used for on-line compiling of the HIP kernels add_definitions("-DHIP_COMPILER_FLAGS=${HIP_COMPILER_FLAGS}") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_3 "${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_4 "${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_5 
"${PROJECT_BINARY_DIR}/composable_kernel/include/utility/*.hpp") -file(GLOB COMPOSABLE_KERNEL_INCLUDE_6 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp") +file(GLOB_RECURSE COMPOSABLE_KERNEL_INCLUDE_1 "${PROJECT_SOURCE_DIR}/composable_kernel/include/*/*.hpp") +file(GLOB COMPOSABLE_KERNEL_INCLUDE_2 "${PROJECT_SOURCE_DIR}/external/rocm/include/bfloat16_dev.hpp") set(MCONV_KERNEL_INCLUDES ${COMPOSABLE_KERNEL_INCLUDE_1} ${COMPOSABLE_KERNEL_INCLUDE_2} - ${COMPOSABLE_KERNEL_INCLUDE_3} - ${COMPOSABLE_KERNEL_INCLUDE_4} - ${COMPOSABLE_KERNEL_INCLUDE_5} - ${COMPOSABLE_KERNEL_INCLUDE_6} ) -set(MCONV_KERNELS - ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp - ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.cpp - ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp - ../composable_kernel/src/kernel_wrapper/dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk.cpp - ) +file(GLOB_RECURSE MCONV_KERNELS "${PROJECT_SOURCE_DIR}/composable_kernel/src/kernel_wrapper/*.cpp") -add_kernels("olCompiling/" "${MCONV_KERNELS}") -add_kernel_includes("olCompiling/" "${MCONV_KERNEL_INCLUDES}") +add_kernels(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNELS}") +add_kernel_includes(${CMAKE_CURRENT_SOURCE_DIR} "${MCONV_KERNEL_INCLUDES}") -set(MCONV_SOURCES - src/host_tensor.cpp; - src/device.cpp; +set(ONLINE_COMPILATION_SOURCE + ${PROJECT_BINARY_DIR}/kernel.cpp + ${PROJECT_BINARY_DIR}/kernel_includes.cpp ) -set(OLC_HIP_UTILITY_HEADERS - olCompiling/include/config.h - olCompiling/include/logger.hpp - olCompiling/include/stringutils.hpp - olCompiling/include/tmp_dir.hpp - olCompiling/include/write_file.hpp - olCompiling/include/env.hpp - olCompiling/include/manage_ptr.hpp - olCompiling/include/md5.hpp - olCompiling/include/simple_hash.hpp - olCompiling/include/exec_utils.hpp - olCompiling/include/hipCheck.hpp 
- olCompiling/include/target_properties.hpp - olCompiling/include/handle.hpp - olCompiling/include/op_kernel_args.hpp - olCompiling/include/kernel.hpp - olCompiling/include/kernel_build_params.hpp - olCompiling/include/hip_build_utils.hpp - olCompiling/include/hipoc_program.hpp - olCompiling/include/hipoc_program_impl.hpp - olCompiling/include/hipoc_kernel.hpp - olCompiling/include/kernel_cache.hpp - olCompiling/include/binary_cache.hpp - ) +include_directories(BEFORE + ${PROJECT_BINARY_DIR}/host/online_compilation/include + include +) set(OLC_HIP_UTILITY_CPPS - olCompiling/hip_utility/logger.cpp - olCompiling/hip_utility/tmp_dir.cpp - olCompiling/hip_utility/md5.cpp - olCompiling/hip_utility/exec_utils.cpp - olCompiling/hip_utility/target_properties.cpp - olCompiling/hip_utility/handlehip.cpp - olCompiling/hip_utility/kernel_build_params.cpp - olCompiling/hip_utility/hip_build_utils.cpp - olCompiling/hip_utility/hipoc_program.cpp - olCompiling/hip_utility/hipoc_kernel.cpp - olCompiling/hip_utility/kernel_cache.cpp - olCompiling/hip_utility/binary_cache.cpp + hip_utility/logger.cpp + hip_utility/tmp_dir.cpp + hip_utility/md5.cpp + hip_utility/exec_utils.cpp + hip_utility/target_properties.cpp + hip_utility/handlehip.cpp + hip_utility/kernel_build_params.cpp + hip_utility/hip_build_utils.cpp + hip_utility/hipoc_program.cpp + hip_utility/hipoc_kernel.cpp + hip_utility/kernel_cache.cpp + hip_utility/binary_cache.cpp ) list(APPEND OLC_SOURCES ${OLC_HIP_UTILITY_CPPS} ${OLC_HIP_UTILITY_HEADERS}) -list(INSERT MCONV_SOURCES 0 - ${PROJECT_BINARY_DIR}/kernel.cpp - ${PROJECT_BINARY_DIR}/kernel_includes.cpp - ) - ## addkernels provide the tool to create inlined kernels in one header -add_subdirectory(olCompiling/addkernels) +add_subdirectory(addkernels) function(inline_kernels_src KERNELS KERNEL_INCLUDES) set(KERNEL_SRC_HPP_FILENAME batch_all.cpp.hpp) @@ -166,7 +132,7 @@ function(inline_kernels_src KERNELS KERNEL_INCLUDES) COMMAND $ -target ${KERNEL_SRC_HPP_PATH} -extern 
-source ${KERNELS} COMMENT "Inlining All kernels" ) - configure_file(olCompiling/kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH}) + configure_file(kernels_batch.cpp.in ${KERNEL_SRC_CPP_PATH}) list(APPEND OLC_SOURCES ${KERNEL_SRC_CPP_PATH} ${KERNEL_SRC_HPP_PATH}) set(OLC_SOURCES ${OLC_SOURCES} PARENT_SCOPE) @@ -174,7 +140,7 @@ endfunction() inline_kernels_src("${MCONV_KERNELS}" "${MCONV_KERNEL_INCLUDES}") -list(APPEND MCONV_SOURCES ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h) +list(APPEND ONLINE_COMPILATION_SOURCE ${OLC_SOURCES} ${PROJECT_BINARY_DIR}/olc_kernel_includes.h) add_custom_command( OUTPUT ${PROJECT_BINARY_DIR}/olc_kernel_includes.h @@ -185,19 +151,17 @@ add_custom_command( ) ## the library target -add_library(modConv SHARED ${MCONV_SOURCES}) +add_library(online_compilation SHARED ${ONLINE_COMPILATION_SOURCE}) -target_include_directories(modConv PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/olCompiling/include/) -target_include_directories(modConv PRIVATE ${PROJECT_BINARY_DIR}) -target_include_directories(modConv PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/) +target_include_directories(online_compilation PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/online_compilation/include/) +target_include_directories(online_compilation PRIVATE ${PROJECT_BINARY_DIR}) +target_include_directories(online_compilation PRIVATE ${PROJECT_SOURCE_DIR}/external/half/include/) -target_link_libraries(modConv PRIVATE hip::device) -target_link_libraries(modConv INTERFACE hip::host) -target_link_libraries(modConv PRIVATE Boost::filesystem) +target_link_libraries(online_compilation PRIVATE hip::device) +target_link_libraries(online_compilation INTERFACE hip::host) +target_link_libraries(online_compilation PRIVATE Boost::filesystem) -target_compile_options(modConv PRIVATE -mfma) +target_compile_features(online_compilation PUBLIC) +set_target_properties(online_compilation PROPERTIES POSITION_INDEPENDENT_CODE ON) -target_compile_features(modConv PUBLIC) -set_target_properties(modConv 
PROPERTIES POSITION_INDEPENDENT_CODE ON) - -install(TARGETS modConv LIBRARY DESTINATION lib) +install(TARGETS online_compilation LIBRARY DESTINATION lib) diff --git a/driver/olCompiling/addkernels/CMakeLists.txt b/host/online_compilation/addkernels/CMakeLists.txt similarity index 100% rename from driver/olCompiling/addkernels/CMakeLists.txt rename to host/online_compilation/addkernels/CMakeLists.txt diff --git a/driver/olCompiling/addkernels/addkernels.cpp b/host/online_compilation/addkernels/addkernels.cpp similarity index 100% rename from driver/olCompiling/addkernels/addkernels.cpp rename to host/online_compilation/addkernels/addkernels.cpp diff --git a/driver/olCompiling/addkernels/include_inliner.cpp b/host/online_compilation/addkernels/include_inliner.cpp similarity index 100% rename from driver/olCompiling/addkernels/include_inliner.cpp rename to host/online_compilation/addkernels/include_inliner.cpp diff --git a/driver/olCompiling/addkernels/include_inliner.hpp b/host/online_compilation/addkernels/include_inliner.hpp similarity index 100% rename from driver/olCompiling/addkernels/include_inliner.hpp rename to host/online_compilation/addkernels/include_inliner.hpp diff --git a/driver/olCompiling/addkernels/source_file_desc.hpp b/host/online_compilation/addkernels/source_file_desc.hpp similarity index 100% rename from driver/olCompiling/addkernels/source_file_desc.hpp rename to host/online_compilation/addkernels/source_file_desc.hpp diff --git a/driver/olCompiling/hip_utility/binary_cache.cpp b/host/online_compilation/hip_utility/binary_cache.cpp similarity index 100% rename from driver/olCompiling/hip_utility/binary_cache.cpp rename to host/online_compilation/hip_utility/binary_cache.cpp diff --git a/driver/olCompiling/hip_utility/exec_utils.cpp b/host/online_compilation/hip_utility/exec_utils.cpp similarity index 100% rename from driver/olCompiling/hip_utility/exec_utils.cpp rename to host/online_compilation/hip_utility/exec_utils.cpp diff --git 
a/driver/olCompiling/hip_utility/handlehip.cpp b/host/online_compilation/hip_utility/handlehip.cpp similarity index 100% rename from driver/olCompiling/hip_utility/handlehip.cpp rename to host/online_compilation/hip_utility/handlehip.cpp diff --git a/driver/olCompiling/hip_utility/hip_build_utils.cpp b/host/online_compilation/hip_utility/hip_build_utils.cpp similarity index 100% rename from driver/olCompiling/hip_utility/hip_build_utils.cpp rename to host/online_compilation/hip_utility/hip_build_utils.cpp diff --git a/driver/olCompiling/hip_utility/hipoc_kernel.cpp b/host/online_compilation/hip_utility/hipoc_kernel.cpp similarity index 100% rename from driver/olCompiling/hip_utility/hipoc_kernel.cpp rename to host/online_compilation/hip_utility/hipoc_kernel.cpp diff --git a/driver/olCompiling/hip_utility/hipoc_program.cpp b/host/online_compilation/hip_utility/hipoc_program.cpp similarity index 100% rename from driver/olCompiling/hip_utility/hipoc_program.cpp rename to host/online_compilation/hip_utility/hipoc_program.cpp diff --git a/driver/olCompiling/hip_utility/kernel_build_params.cpp b/host/online_compilation/hip_utility/kernel_build_params.cpp similarity index 100% rename from driver/olCompiling/hip_utility/kernel_build_params.cpp rename to host/online_compilation/hip_utility/kernel_build_params.cpp diff --git a/driver/olCompiling/hip_utility/kernel_cache.cpp b/host/online_compilation/hip_utility/kernel_cache.cpp similarity index 100% rename from driver/olCompiling/hip_utility/kernel_cache.cpp rename to host/online_compilation/hip_utility/kernel_cache.cpp diff --git a/driver/olCompiling/hip_utility/logger.cpp b/host/online_compilation/hip_utility/logger.cpp similarity index 100% rename from driver/olCompiling/hip_utility/logger.cpp rename to host/online_compilation/hip_utility/logger.cpp diff --git a/driver/olCompiling/hip_utility/md5.cpp b/host/online_compilation/hip_utility/md5.cpp similarity index 100% rename from driver/olCompiling/hip_utility/md5.cpp 
rename to host/online_compilation/hip_utility/md5.cpp diff --git a/driver/olCompiling/hip_utility/target_properties.cpp b/host/online_compilation/hip_utility/target_properties.cpp similarity index 100% rename from driver/olCompiling/hip_utility/target_properties.cpp rename to host/online_compilation/hip_utility/target_properties.cpp diff --git a/driver/olCompiling/hip_utility/tmp_dir.cpp b/host/online_compilation/hip_utility/tmp_dir.cpp similarity index 100% rename from driver/olCompiling/hip_utility/tmp_dir.cpp rename to host/online_compilation/hip_utility/tmp_dir.cpp diff --git a/driver/olCompiling/include/binary_cache.hpp b/host/online_compilation/include/binary_cache.hpp similarity index 100% rename from driver/olCompiling/include/binary_cache.hpp rename to host/online_compilation/include/binary_cache.hpp diff --git a/driver/olCompiling/include/config.h.in b/host/online_compilation/include/config.h.in similarity index 100% rename from driver/olCompiling/include/config.h.in rename to host/online_compilation/include/config.h.in diff --git a/driver/olCompiling/include/env.hpp b/host/online_compilation/include/env.hpp similarity index 100% rename from driver/olCompiling/include/env.hpp rename to host/online_compilation/include/env.hpp diff --git a/driver/olCompiling/include/exec_utils.hpp b/host/online_compilation/include/exec_utils.hpp similarity index 100% rename from driver/olCompiling/include/exec_utils.hpp rename to host/online_compilation/include/exec_utils.hpp diff --git a/driver/olCompiling/include/handle.hpp b/host/online_compilation/include/handle.hpp similarity index 100% rename from driver/olCompiling/include/handle.hpp rename to host/online_compilation/include/handle.hpp diff --git a/driver/olCompiling/include/hipCheck.hpp b/host/online_compilation/include/hipCheck.hpp similarity index 100% rename from driver/olCompiling/include/hipCheck.hpp rename to host/online_compilation/include/hipCheck.hpp diff --git 
a/driver/olCompiling/include/hip_build_utils.hpp b/host/online_compilation/include/hip_build_utils.hpp similarity index 100% rename from driver/olCompiling/include/hip_build_utils.hpp rename to host/online_compilation/include/hip_build_utils.hpp diff --git a/driver/olCompiling/include/hipoc_kernel.hpp b/host/online_compilation/include/hipoc_kernel.hpp similarity index 100% rename from driver/olCompiling/include/hipoc_kernel.hpp rename to host/online_compilation/include/hipoc_kernel.hpp diff --git a/driver/olCompiling/include/hipoc_program.hpp b/host/online_compilation/include/hipoc_program.hpp similarity index 100% rename from driver/olCompiling/include/hipoc_program.hpp rename to host/online_compilation/include/hipoc_program.hpp diff --git a/driver/olCompiling/include/hipoc_program_impl.hpp b/host/online_compilation/include/hipoc_program_impl.hpp similarity index 100% rename from driver/olCompiling/include/hipoc_program_impl.hpp rename to host/online_compilation/include/hipoc_program_impl.hpp diff --git a/driver/olCompiling/include/kernel.hpp b/host/online_compilation/include/kernel.hpp similarity index 100% rename from driver/olCompiling/include/kernel.hpp rename to host/online_compilation/include/kernel.hpp diff --git a/driver/olCompiling/include/kernel_build_params.hpp b/host/online_compilation/include/kernel_build_params.hpp similarity index 100% rename from driver/olCompiling/include/kernel_build_params.hpp rename to host/online_compilation/include/kernel_build_params.hpp diff --git a/driver/olCompiling/include/kernel_cache.hpp b/host/online_compilation/include/kernel_cache.hpp similarity index 100% rename from driver/olCompiling/include/kernel_cache.hpp rename to host/online_compilation/include/kernel_cache.hpp diff --git a/driver/olCompiling/include/logger.hpp b/host/online_compilation/include/logger.hpp similarity index 100% rename from driver/olCompiling/include/logger.hpp rename to host/online_compilation/include/logger.hpp diff --git 
a/driver/olCompiling/include/manage_ptr.hpp b/host/online_compilation/include/manage_ptr.hpp similarity index 100% rename from driver/olCompiling/include/manage_ptr.hpp rename to host/online_compilation/include/manage_ptr.hpp diff --git a/driver/olCompiling/include/md5.hpp b/host/online_compilation/include/md5.hpp similarity index 100% rename from driver/olCompiling/include/md5.hpp rename to host/online_compilation/include/md5.hpp diff --git a/driver/olCompiling/include/op_kernel_args.hpp b/host/online_compilation/include/op_kernel_args.hpp similarity index 100% rename from driver/olCompiling/include/op_kernel_args.hpp rename to host/online_compilation/include/op_kernel_args.hpp diff --git a/driver/olCompiling/include/simple_hash.hpp b/host/online_compilation/include/simple_hash.hpp similarity index 100% rename from driver/olCompiling/include/simple_hash.hpp rename to host/online_compilation/include/simple_hash.hpp diff --git a/driver/olCompiling/include/stringutils.hpp b/host/online_compilation/include/stringutils.hpp similarity index 100% rename from driver/olCompiling/include/stringutils.hpp rename to host/online_compilation/include/stringutils.hpp diff --git a/driver/olCompiling/include/target_properties.hpp b/host/online_compilation/include/target_properties.hpp similarity index 100% rename from driver/olCompiling/include/target_properties.hpp rename to host/online_compilation/include/target_properties.hpp diff --git a/driver/olCompiling/include/tmp_dir.hpp b/host/online_compilation/include/tmp_dir.hpp similarity index 100% rename from driver/olCompiling/include/tmp_dir.hpp rename to host/online_compilation/include/tmp_dir.hpp diff --git a/driver/olCompiling/include/write_file.hpp b/host/online_compilation/include/write_file.hpp similarity index 100% rename from driver/olCompiling/include/write_file.hpp rename to host/online_compilation/include/write_file.hpp diff --git a/driver/olCompiling/kernel.cpp.in b/host/online_compilation/kernel.cpp.in similarity index 
100% rename from driver/olCompiling/kernel.cpp.in rename to host/online_compilation/kernel.cpp.in diff --git a/driver/olCompiling/kernel_includes.cpp.in b/host/online_compilation/kernel_includes.cpp.in similarity index 100% rename from driver/olCompiling/kernel_includes.cpp.in rename to host/online_compilation/kernel_includes.cpp.in diff --git a/driver/olCompiling/kernels_batch.cpp.in b/host/online_compilation/kernels_batch.cpp.in similarity index 100% rename from driver/olCompiling/kernels_batch.cpp.in rename to host/online_compilation/kernels_batch.cpp.in diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index 7a31c69dcb..9a02b68e07 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -9,8 +9,7 @@ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_BUILD_TYPE=Release \ --D DEVICE_BACKEND=AMD \ --D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$CWD" \ +-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$PWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ diff --git a/script/run.sh b/script/run.sh index 1a76adb876..ecb5c85d81 100755 --- a/script/run.sh +++ b/script/run.sh @@ -5,19 +5,19 @@ export GPU_DEVICE_ORDINAL=0 ## Boost -#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH + export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH ## Compiling - export OLC_DEBUG_HIP_VERBOSE=1 - export OLC_DEBUG_HIP_DUMP=1 - export OLC_DEBUG_SAVE_TEMP_DIR=1 +#export OLC_DEBUG_HIP_VERBOSE=1 +#export OLC_DEBUG_HIP_DUMP=1 +#export OLC_DEBUG_SAVE_TEMP_DIR=1 -#make -j conv_driver_v2 -#make -j conv_bwd_data_driver_v2 - make -j conv_driver_v2_olc + make -j conv_fwd_driver_offline + make -j conv_bwd_driver_offline + make -j conv_fwd_driver_online - rm -rf /root/_hip_binary_kernels_/ - rm -rf /tmp/olCompile* +#rm -rf 
/root/_hip_binary_kernels_/ +#rm -rf /tmp/olCompile* LAYOUT=$1 ALGO=$2 @@ -26,21 +26,22 @@ INIT=$4 LOG=$5 REPEAT=$6 -################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads - ./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 2048 3 3 14 14 1 1 1 1 1 1 1 1 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 + ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT 
$ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 -#./conv_bwd_data_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 - ./conv_driver_v2_olc $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1