diff --git a/CMakeLists.txt b/CMakeLists.txt
index a4b016608d..d8e51761bd 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -18,26 +18,32 @@ include_directories(BEFORE ${Boost_INCLUDE_DIRS})
 link_directories(${Boost_LIBRARY_DIRS})
 
 #OpenMP
-if( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
-    find_package(OpenMP REQUIRED)
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+    # workaround: hipcc in ROCm 3.5 cannot find OpenMP
+    set(OpenMP_CXX "${CMAKE_CXX_COMPILER}")
+    set(OpenMP_CXX_FLAGS "-fopenmp=libomp -Wno-unused-command-line-argument")
+    set(OpenMP_CXX_LIB_NAMES "libomp" "libgomp" "libiomp5")
+    set(OpenMP_libomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+    set(OpenMP_libgomp_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+    set(OpenMP_libiomp5_LIBRARY ${OpenMP_CXX_LIB_NAMES})
+else()
+    find_package(OpenMP REQUIRED)
+endif()
 
-    message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
-    message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
-    message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
-    message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
+message("OpenMP_CXX_LIB_NAMES: ${OpenMP_CXX_LIB_NAMES}")
+message("OpenMP_gomp_LIBRARY: ${OpenMP_gomp_LIBRARY}")
+message("OpenMP_pthread_LIBRARY: ${OpenMP_pthread_LIBRARY}")
+message("OpenMP_CXX_FLAGS: ${OpenMP_CXX_FLAGS}")
 
-    set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
-    link_libraries(${OpenMP_gomp_LIBRARY})
-    link_libraries(${OpenMP_pthread_LIBRARY})
-endif( NOT( ${CMAKE_CXX_COMPILER_ID} STREQUAL "AppleClang") )
+set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}")
+link_libraries(${OpenMP_gomp_LIBRARY})
+link_libraries(${OpenMP_pthread_LIBRARY})
 
 #GPU backend
 if(DEVICE_BACKEND STREQUAL "AMD")
-    set(CMAKE_MODULE_PATH "/opt/rocm/hip/cmake" ${CMAKE_MODULE_PATH})
     find_package(HIP REQUIRED)
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    enable_language(CUDA)
-    include_directories(BEFORE ${CUDA_COMMON_INCLUDE_DIR})
+    enable_language(CUDA)
 endif()
 
 #
@@ -47,19 +53,27 @@ include_directories(BEFORE
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
     ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
-    ${PROJECT_SOURCE_DIR}/external/include
+    ${PROJECT_SOURCE_DIR}/external/half/include
     ${PROJECT_SOURCE_DIR}/driver/include
     ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
 )
 
+if(DEVICE_BACKEND STREQUAL "AMD")
+    include_directories(BEFORE
+        ${PROJECT_SOURCE_DIR}/external/rocm/include
+    )
+endif()
+
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
+    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in"
"${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp") configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp") + configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp") endif() add_subdirectory(driver) diff --git a/composable_kernel/include/gridwise_convolution_kernel_wrapper.hpp b/composable_kernel/include/gridwise_convolution_kernel_wrapper.hpp deleted file mode 100644 index 2269e72579..0000000000 --- a/composable_kernel/include/gridwise_convolution_kernel_wrapper.hpp +++ /dev/null @@ -1,12 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER -#define CK_GRIDWISE_CONVOLUTION_KERNEL_WRAPPER - -template -__global__ void run_gridwise_convolution_kernel(const T* const __restrict__ p_in_global, - const T* const __restrict__ p_wei_global, - T* const __restrict__ p_out_global) -{ - GridwiseConvolution{}.Run(p_in_global, p_wei_global, p_out_global); -} - -#endif diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp index 9c99ee3555..746e41ce33 100644 --- a/composable_kernel/include/gridwise_operation_wrapper.hpp +++ b/composable_kernel/include/gridwise_operation_wrapper.hpp @@ -2,7 +2,7 @@ #define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER template -__global__ void run_gridwise_operation(GridwiseOp, Xs... xs) +__global__ void run_gridwise_operation(Xs... xs) { GridwiseOp{}.Run(xs...); } diff --git a/composable_kernel/include/kernel_algorithm/convolution_common.hpp b/composable_kernel/include/kernel_algorithm/convolution_common.hpp deleted file mode 100644 index 4bcb3347ab..0000000000 --- a/composable_kernel/include/kernel_algorithm/convolution_common.hpp +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef CK_CONVOLUTION_COMMON_HPP -#define CK_CONVOLUTION_COMMON_HPP - -namespace ck { - -enum ConvolutionDirection -{ - Forward, - BackwardData, - BackwardWeight -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp deleted file mode 100644 index 74a2b65571..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_col2im_eb_nchw.hpp +++ /dev/null @@ -1,130 +0,0 @@ -#ifndef CK_GRIDWISE_COL2IM_EB_NCHW_HPP -#define CK_GRIDWISE_COL2IM_EB_NCHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -// B = merge(N, Ho, Wo) -template -struct GridwiseCol2Im_eb_nchw -{ - __device__ void Run(const Float* const __restrict__ p_col_global, - Float* const __restrict__ p_img_global) const - { - constexpr auto col_e_b_global_desc = ColGlobalDesc{}; - constexpr auto img_n_c_hi_wi_global_desc = ImgGlobalDesc{}; - - constexpr index_t N = img_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = img_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = img_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = img_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t Ho = OutputSizes{}[0]; - constexpr index_t Wo = OutputSizes{}[1]; - - constexpr index_t Y = FilterSizes{}[0]; - constexpr index_t X = FilterSizes{}[1]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr 
index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t E = C * Y * X; - constexpr index_t B = N * Ho * Wo; - - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || BlockCopyDataPerAccess_B == 1)) && - (X == 1 || ConvDilationW % BlockCopyDataPerAccess_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [E, B] - static_assert(E % EPerBlock == 0 && B % BPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t EBlockWork = E / EPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t e_block_data_on_global = block_work_id[0] * EPerBlock; - const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; - - // construct img_eb_global_desc - constexpr auto img_n_c_hip_wip_global_desc = transform_tensor_descriptor( - img_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, PassThrough{}, Pad, LeftPads, RightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr auto img_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - img_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto img_e_b_global_desc = transform_tensor_descriptor( - img_n_c_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // blockwise atomic accumulation - auto blockwise_copy = BlockwiseGenericTensorSliceCopy_v4, - BlockCopySubLengths_E_B, - BlockCopyClusterLengths_E_B, - BlockCopyThreadClusterArrangeOrder, - BlockCopySrcAccessOrder, - BlockCopyDstAccessOrder, - 1, - 1, - BlockCopyDataPerAccess_B, - BlockCopyDataPerAccess_B, - AddressSpace::Vgpr, - AddressSpace::Vgpr, - AddressSpace::Global, - InMemoryDataOperation::AtomicAdd>( - {e_block_data_on_global, b_block_data_on_global}, - {e_block_data_on_global, b_block_data_on_global}); - - // blockwise copy - blockwise_copy.Run(p_col_global, p_img_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 561969320d..71a9bb6dc0 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -36,8 +36,8 @@ template {}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - // input tensor constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( in_n_c_hi_wi_global_desc, @@ -98,16 +91,15 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, 
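
Both the col2im descriptor above and the backward-data descriptors around it use the Embed transform, which writes an input spatial coordinate as an affine combination of the filter index and the output index (padding is handled by the separate Pad transform). A plain host-side illustration of that arithmetic, with made-up sizes:

#include <cstdio>

int main()
{
    const int Y = 3, Ho = 4;                      // filter / output extents
    const int ConvDilationH = 2, ConvStrideH = 1; // example parameters
    for(int y = 0; y < Y; ++y)
        for(int ho = 0; ho < Ho; ++ho)
            std::printf("y=%d ho=%d -> hi=%d\n", y, ho, y * ConvDilationH + ho * ConvStrideH);
    return 0;
}
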
Sequence<2, 3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; + constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; + constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( in_n_c_hip_wip_global_desc, make_tuple(PassThrough{}, PassThrough{}, - Embed, - Sequence>{}, - Embed, - Sequence>{}), + Embed, Sequence>{}, + Embed, Sequence>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); @@ -117,6 +109,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); + // output tensor + constexpr auto out_gemmk_gemmn_global_desc = + transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // GEMM // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need // atomic, find out all of them @@ -152,8 +151,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw Sequence<0, 1>, Sequence<0, 1>, 1, - GemmABlockCopySrcDataPerRead_GemmN, - GemmABlockCopyDstDataPerWrite_GemmN, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, Sequence<0, 1>, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp index 0fdb15a440..286a1c995b 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -25,13 +25,13 @@ template {}, Number{}); + Number{}, Number{}); const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< BlockSize, decltype(a_k_ec0_block_mtx_desc), decltype(b_k_bn0_block_mtx_desc), decltype(c_e0e1c0_b0b1n0_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB>{}; @@ -371,7 +371,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl // define input tensor descriptor for threadwise copy // thread input tensor, src of threadwise copy constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); + Sequence{}); // global input tensor, dst of threadwise copy constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( @@ -419,10 +419,10 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); const index_t e_thread_data_on_global = - e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThreadSubC; + e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread; const index_t b_thread_data_on_global 
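
The Merge transform that builds the GEMM dimensions above (e.g. GemmN = merge(N, Ho, Wo)) is plain row-major linearization of several indices into one. A compile-time check of the formula, with example extents:

// merge (n, ho, wo) with extents (N, Ho, Wo) into a single index
constexpr int merge3(int n, int ho, int wo, int Ho, int Wo)
{
    return (n * Ho + ho) * Wo + wo;
}
static_assert(merge3(1, 2, 3, 4, 5) == 33, "1*(4*5) + 2*5 + 3");
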
= - b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThreadSubC; + b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread; ThreadwiseGenericTensorSliceCopy_v4r2< decltype(in_e0_e1_c0_b0_b1_n0_thread_desc), diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index d4e9da5e54..1eaf724f0f 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -419,7 +419,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw template __device__ static void Run(Float* __restrict__ p_in_global, const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) + const Float* __restrict__ p_out_global, + Number) { constexpr index_t ConvStrideH = ConvStrides{}[0]; constexpr index_t ConvStrideW = ConvStrides{}[1]; diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index bec97d28c7..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,255 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW -#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "blockwise_4d_tensor_op.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "threadwise_direct_convolution.hpp" - -namespace ck { - -template -struct GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_global_desc = InGlobalDesc{}; - constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{}; - constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_nchw_global_desc.GetLength(I0); - constexpr index_t K = wei_kcyx_global_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_global_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_global_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_global_desc.GetLength(I3); - - constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); // 2d view of wei for blockwise copy - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); // 2d view of wei for blockwise copy - - constexpr auto wei_kcyx_block_desc = - make_ConstantTensorDescriptor(Sequence{}, - Sequence{}); - - // shared mem - constexpr index_t in_block_element_size = - in_nchw_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_element_size = - 
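
The v4r1 backward-data Run above gains a trailing Number<> parameter (its template argument is lost in this rendering). This is compile-time tag dispatch: each integral value stamps out its own kernel instantiation. A reduced sketch, with IntTag standing in for CK's Number:

template <int N>
struct IntTag // stand-in for CK's Number<N>
{
    static constexpr int value = N;
};

template <int N>
void run(IntTag<N>) // each N selects its own specialization
{
    static_assert(N >= 0, "N is usable as a compile-time constant in the body");
}

// run(IntTag<2>{}); // dispatch on the value 2 at compile time
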
wei_kcyx_block_desc.GetElementSpace(Number{}); - - constexpr index_t max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead - ? InBlockCopyDataPerRead - : WeiBlockCopyDataPerRead; - - __shared__ Float - p_in_block[max_align * ((in_block_element_size + max_align - 1) / max_align)]; - __shared__ Float - p_wei_block[max_align * ((wei_block_element_size + max_align - 1) / max_align)]; - - // threadwise tensors - constexpr index_t HiPerThread = HoPerThread + Y - 1; - constexpr index_t WiPerThread = WoPerThread + X - 1; - - constexpr auto in_nchw_thread_block_desc = make_ConstantTensorDescriptor( - Sequence{}, - in_nchw_block_desc.GetStrides()); - - constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor( - Sequence{}, wei_kcyx_block_desc.GetStrides()); - - constexpr auto out_nkhw_thread_desc = - get_convolution_output_default_4d_tensor_descriptor_deprecated( - in_nchw_thread_block_desc, wei_kcyx_thread_block_desc); - - // register - Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; - - // divide block work - constexpr index_t NBlockWork = - (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; - constexpr index_t KBlockWork = - (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = - (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = - (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; - - const index_t block_id = blockIdx.x; - - index_t itmp = block_id; - const index_t n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); - itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); - const index_t k_block_work_id = itmp / (HBlockWork * WBlockWork); - itmp -= k_block_work_id * (HBlockWork * WBlockWork); - const index_t h_block_work_id = itmp / WBlockWork; - const index_t w_block_work_id = itmp - h_block_work_id * WBlockWork; - - const index_t n_block_data_begin = n_block_work_id * NPerBlock; - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; // minus padding - const index_t wi_block_data_begin = wo_block_data_begin; // minus padding - - // divide thread work - constexpr index_t NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; - constexpr index_t KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; - constexpr index_t HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; - constexpr index_t WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; - - const index_t thread_id = get_thread_local_1d_id(); - - itmp = thread_id; - const index_t n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); - itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); - const index_t k_thread_work_id = itmp / (HThreadWork * WThreadWork); - itmp -= k_thread_work_id * (HThreadWork * WThreadWork); - const index_t h_thread_work_id = itmp / WThreadWork; - const index_t w_thread_work_id = itmp - h_thread_work_id * WThreadWork; - - const index_t n_thread_data_begin = n_thread_work_id * NPerThread; - const index_t k_thread_data_begin = k_thread_work_id * KPerThread; - const index_t ho_thread_data_begin = h_thread_work_id * HoPerThread; - const index_t wo_thread_data_begin = w_thread_work_id * WoPerThread; - - const index_t hi_thread_data_begin = ho_thread_data_begin; - const index_t 
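
The deleted kernel above recovers a 4-D work index (n, k, h, w) from the 1-D block id with the usual divide-and-subtract chain. The same computation as standalone host code, for reference:

#include <array>

std::array<int, 4> decompose_id(int id, int KWork, int HWork, int WWork)
{
    std::array<int, 4> idx{};
    idx[0] = id / (KWork * HWork * WWork); // leading index first
    id -= idx[0] * (KWork * HWork * WWork);
    idx[1] = id / (HWork * WWork);
    id -= idx[1] * (HWork * WWork);
    idx[2] = id / WWork;
    idx[3] = id - idx[2] * WWork;
    return idx;
}
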
wi_thread_data_begin = wo_thread_data_begin; - - constexpr auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; - -#if 0 - constexpr auto blockwise_wei_copy = - Blockwise4dTensorCopy1{}; -#elif 1 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3({0, 0}, {0, 0}); -#endif - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock, __syncthreads()) - { - // copy input tensor to LDS - blockwise_in_copy.Run( - p_in_global + - in_nchw_global_desc.GetOffsetFromMultiIndex(n_block_data_begin, - c_block_data_begin, - hi_block_data_begin, - wi_block_data_begin), - p_in_block); - - // copy weight tensor to LDS - blockwise_wei_copy.Run(p_wei_global + - wei_kcyx_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin, c_block_data_begin, 0, 0), - p_wei_block); - - __syncthreads(); - - for(index_t c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) - { -// threadwise convolution -#if 1 - threadwise_direct_convolution_2( - in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), - wei_kcyx_thread_block_desc, - p_wei_block + - wei_kcyx_block_desc.GetOffsetFromMultiIndex( - k_thread_data_begin, c_thread_data, 0, 0), - out_nkhw_thread_desc, - p_out_thread); -#elif 0 - threadwise_direct_convolution_3( - in_nchw_thread_block_desc, - p_in_block + - in_nchw_block_desc.GetOffsetFromMultiIndex(n_thread_data_begin, - c_thread_data, - hi_thread_data_begin, - wi_thread_data_begin), - wei_kcyx_thread_block_desc, - p_wei_block + - wei_kcyx_block_desc.GetOffsetFromMultiIndex( - k_thread_data_begin, c_thread_data, 0, 0), - out_nkhw_thread_desc, - p_out_thread); -#endif - } - } - - // copy output tensor from register to global mem - threadwise_tensor_slice_copy(out_nkhw_thread_desc, - p_out_thread, - out_nkhw_global_desc, - p_out_global + - out_nkhw_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_nkhw_thread_desc.GetLengths(), - Number<1>{}); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp deleted file mode 100644 index d33a4adf96..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,398 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_4d_tensor_op.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "threadwise_4d_tensor_op.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % 
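
The main loop above is the standard LDS tiling pipeline, with the trailing barrier folded into the for-increment via the comma operator. Its shape, reduced to a commented skeleton (the copy and compute bodies are elided):

#include <hip/hip_runtime.h>

__device__ void lds_pipeline_skeleton(int C, int CPerBlock)
{
    for(int c0 = 0; c0 < C; c0 += CPerBlock, __syncthreads())
    {
        // 1. all threads cooperatively copy the next tile into __shared__ buffers
        __syncthreads(); // 2. make the tile visible to the whole workgroup
        // 3. each thread accumulates partial results from the tile
        // 4. the __syncthreads() in the increment guards LDS reuse next iteration
    }
}
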
NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - - constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); - constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); - constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const index_t w_block_work_id = itmp / NBlockWork; - const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const index_t n_block_data_begin = n_block_work_id * NPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // flattend (2d) tensor view of gridwise weight - constexpr auto wei_cyx_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight in LDS - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, - WeiBlockCopyDataPerRead_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - - // blockwise copy - // input: format is [C, Hi, Wi, N] - const auto 
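
max_align above is the least common multiple of every vectorized access width that touches the LDS buffers, so one alignment satisfies all of them. The computation itself, as constexpr host code:

constexpr int gcd_i(int a, int b) { return b == 0 ? a : gcd_i(b, a % b); }
constexpr int lcm_i(int a, int b) { return a / gcd_i(a, b) * b; }
static_assert(lcm_i(lcm_i(4, 2), lcm_i(2, 1)) == 4, "widths 4, 2, 2, 1 -> alignment 4");
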
blockwise_in_copy = -#if 0 - Blockwise4dTensorCopy1{}; -#else - Blockwise4dTensorCopy3{}; -#endif - - // blockwise wei copy - // format is [CPerBlock*Y*X,KPerBlock] - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3{}; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS: be careful of alignment - constexpr index_t in_block_space = - in_c_h_w_n_block_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = - wei_c_y_x_k_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); - - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0), - __syncthreads()) - { -#if 1 - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); -#else - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block); -#endif - - __syncthreads(); - -#pragma unroll - for(index_t y = 0; y < Y; ++y) - { -#pragma unroll - for(index_t x = 0; x < X; ++x) - { -#if 1 - blockwise_batch_gemm.Run -#else - blockwise_batch_gemm.Run_amd_asm -#endif - (p_wei_block + 
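
The blockwise batched GEMM above accumulates C += transpose(A) * B, with the shared C dimension running along the rows of both LDS operands. Its reference semantics in plain host code (row-major, A stored K x M):

void gemm_tn(int M, int N, int K, const float* A, const float* B, float* C)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                C[m * N + n] += A[k * M + m] * B[k * N + n];
}
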
wei_c_y_x_k_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_out_thread); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto f_dummy) { // f_dummy do nothing but - // perfect forwarding. - // Using this trick to - // make this lambda a generic lambda, so it won't be compiled until - // instantiated - static_assert( - (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } -#endif - - threadwise_tensor_slice_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); - }).Else([&](auto f_dummy) { - static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - - for(index_t i = 0; i < 64; ++i) - { - printf("out %f, ", p_out_thread[i]); - } - } -#endif - - threadwise_tensor_slice_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global 
+ - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp deleted file mode 100644 index 6975b1e248..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,435 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "blockwise_3d_tensor_op.hpp" -#include "blockwise_4d_tensor_op.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "threadwise_4d_tensor_op.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - - constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); - constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); - constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - constexpr index_t HiPerBlock = HoPerBlock + Y - 1; - constexpr index_t WiPerBlock = WoPerBlock + X - 1; - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! 
cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock; - constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock; - constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock; - - const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork); - index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork); - const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork); - itmp -= h_block_work_id * (WBlockWork * NBlockWork); - const index_t w_block_work_id = itmp / NBlockWork; - const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork; - - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t ho_block_data_begin = h_block_work_id * HoPerBlock; - const index_t wo_block_data_begin = w_block_work_id * WoPerBlock; - const index_t n_block_data_begin = n_block_work_id * NPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // global tensor view - constexpr auto wei_c_x_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}, Sequence{}); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerRead_N, - WeiBlockCopyDataPerRead_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -// blockwise copy -// input: format is [C, Hi, Wi, N] -#if 1 - const auto blockwise_in_copy = - Blockwise4dTensorCopy1{}; -#else - const auto blockwise_in_copy = - Blockwise4dTensorCopy3{}; -#endif - - // blockwise wei copy - // format is [CPerBlock, X * KPerBlock] - const auto blockwise_wei_copy = - Blockwise3dTensorCopy1{}; - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - 
GemmDataPerReadB>{}; - - // LDS: be careful of alignment - constexpr index_t in_block_space = - in_c_h_w_n_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_space = - wei_c_x_k_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); - print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); - - print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); - print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc"); - - printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); - } -#endif - - // set threadwise output tensor to 0 - threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); - -#if 1 - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - for(index_t y = 0; y < Y; ++y) - { - blockwise_in_copy.Run( - p_in_global_block_offset + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0), - p_in_block); - - blockwise_wei_copy.Run( - p_wei_global_block_offset + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, 0), - p_wei_block); - - __syncthreads(); - - for(index_t x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run( - p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0), - p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0), - p_out_thread); - } - - __syncthreads(); - } - } -#else - // this use much more register, haven't figure out why? 
- for(index_t y = 0; y < Y; ++y) - { - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin + y, wi_block_data_begin, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, 0, k_block_data_begin); - - for(index_t - c_block_data_begin = 0; - c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); - - __syncthreads(); - - for(index_t x = 0; x < X; ++x) - { - blockwise_batch_gemm.Run( - p_wei_block + wei_c_x_k_block_desc.GetOffsetFromMultiIndex(0, x, 0), - p_in_block + in_c_h_w_n_block_desc.GetOffsetFromMultiIndex(0, 0, x, 0), - p_out_thread); - } - - __syncthreads(); - } - } -#endif - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto f_dummy) { // f_dummy do nothing but - // perfect forwarding. - // Using this trick to - // make this lambda a generic lambda, so it won't be compiled until - // instantiated - static_assert( - (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - } -#endif - - threadwise_tensor_slice_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); - }).Else([&](auto f_dummy) { - static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / f_dummy(W2 * 
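
The static_if([&](auto f_dummy) {...}) construct above exploits generic lambdas: routing values through the f_dummy/fwd parameter makes the body dependent, so the compiler only instantiates the branch that is actually selected, and the static_asserts in the other branch never fire. A reduced sketch of such a construct (not CK's exact definition):

struct forwarder // identity; calls through it make expressions dependent
{
    template <typename T>
    constexpr T operator()(T x) const { return x; }
};

template <bool Cond>
struct static_if_sketch
{
    template <typename F>
    static void run(F f) { f(forwarder{}); }
};

template <>
struct static_if_sketch<false>
{
    template <typename F>
    static void run(F) {} // not-taken branch: lambda body never instantiated
};

// static_if_sketch<true>::run([](auto fwd) { static_assert(fwd(4) == 4, ""); });
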
W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc"); - - for(index_t i = 0; i < 64; ++i) - { - printf("out %f, ", p_out_thread[i]); - } - } -#endif - - threadwise_tensor_slice_copy(out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin), - out_10d_thread_desc.GetLengths(), - Number{}); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp deleted file mode 100644 index def4ae086b..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,420 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - - constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); - constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); - constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! 
cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock); - constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock); - constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock); - constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock); - - constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock; - const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock; - const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock; - const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // global tensor view - constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N, - WeiBlockCopyDataPerAccess_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - // blockwise copy - // input: format is [C, Hi, Wi, N] - auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - decltype(in_c_h_w_n_global_desc), - decltype(in_c_h_w_n_block_desc), - decltype(in_c_h_w_n_block_desc.GetLengths()), - InBlockCopySubLengths_CHWN, - InBlockCopyClusterLengths_CHWN, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 3, - 3, - InBlockCopyDataPerAccess_N, - InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0}); - - // blockwise wei copy - // format is [CPerBlock, X * KPerBlock] - const auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v1_deprecated, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - WeiBlockCopyDataPerAccess_K, - WeiBlockCopyDataPerAccess_K>({0, 0}, - {0, 0}); - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - 
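
math::integer_divide_ceil above replaces the hand-written (a + b - 1) / b pattern of the older files; for positive integers both compute ceiling division:

constexpr int integer_divide_ceil_sketch(int a, int b)
{
    return (a + b - 1) / b; // ceiling division, valid for positive a and b
}
static_assert(integer_divide_ceil_sketch(7, 3) == 3, "");
static_assert(integer_divide_ceil_sketch(6, 3) == 2, "");
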
out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS: be careful of alignment - constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(); - constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); - print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); - - print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); - print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc"); - - printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); - } -#endif - - // set threadwise output tensor to 0 - threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += - CPerBlock * in_c_h_w_n_global_desc.GetStride(I0), - p_wei_global_block_offset += - CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - blockwise_in_copy.Run(p_in_global_block_offset, p_in_block); - - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); - - __syncthreads(); - - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - } - } - } - - // output: register to global mem - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto fwd) { - // fwd do nothing but perfect forwarding. 
- // Using this trick to make this lambda a generic lambda, so it won't be compiled until - // being instantiated here - static_assert( - (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc) - .Fold(I3, Number{}, Number{}) - .Fold(I2, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number<1>{}, Number{}) - .Fold(I2, Number{}, Number<1>{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "a: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "a: out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); - } -#endif - - Float* p_out_thread_on_global = p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin); - -#if 1 - ThreadwiseGenericTensorSliceCopy_v1r2_deprecated< - decltype(out_10d_thread_desc), - decltype(out_10d_global_desc), - decltype(out_10d_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>(make_zero_array(), - make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#elif 0 - ThreadwiseGenericTensorSliceCopy_v1r1::type, - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#endif - }).Else([&](auto fwd) { - static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - fwd(out_k_h_w_n_global_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = - fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number<1>{}, Number{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "b: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "b: out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "b: 
out_10d_global_desc"); - } -#endif - - Float* p_out_thread_on_global = p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin); - -#if 1 - ThreadwiseGenericTensorSliceCopy_v1r2_deprecated< - decltype(out_10d_thread_desc), - decltype(out_10d_global_desc), - decltype(out_10d_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>(make_zero_array(), - make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#elif 0 - ThreadwiseGenericTensorSliceCopy_v1r1::type, - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#endif - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp deleted file mode 100644 index 9528d7cb97..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp +++ /dev/null @@ -1,508 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - - constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); - constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); - constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % (2 * CPerBlock) == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! 
cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock); - constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock); - constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock); - constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock); - - constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock; - const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock; - const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock; - const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // global tensor view - constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N, - WeiBlockCopyDataPerAccess_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - // blockwise copy - // input: format is [C, Hi, Wi, N] - auto blockwise_in_copy = -#if 0 - BlockwiseGenericTensorSliceCopy_v1_deprecated -#else - BlockwiseGenericTensorSliceCopy_v2_deprecated -#endif - , - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 3, - 3, - InBlockCopyDataPerAccess_N, - InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0}); - - // blockwise wei copy - // format is [CPerBlock, X * KPerBlock] - const auto blockwise_wei_copy = -#if 0 - BlockwiseGenericTensorSliceCopy_v1_deprecated -#else - BlockwiseGenericTensorSliceCopy_v2_deprecated -#endif - , - Sequence<0, 1>, - Sequence<0, 1>, - 1, - 1, - WeiBlockCopyDataPerAccess_K, - WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0}); - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - 
GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS: be careful of alignment - constexpr index_t in_block_space = in_c_h_w_n_block_desc.GetElementSpace(); - constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(); - - // LDS double buffer - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); - print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); - - print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); - print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc"); - - printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); - } -#endif - - // set threadwise output to 0 - threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - const Float* p_in_global_block_offset = - p_in_global + - in_c_h_w_n_global_desc.GetOffsetFromMultiIndex( - 0, hi_block_data_begin + y, wi_block_data_begin + x, n_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin); - - // LDS double buffer: preload data into LDS - { - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; - c_block_data_begin += 2 * CPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
-                        p_wei_block_double + wei_block_space : p_wei_block_double;
-
-                    Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-                    Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
-
-                    p_in_global_block_offset +=
-                        CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
-                    p_wei_global_block_offset +=
-                        CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
-
-                    __syncthreads();
-
-                    // LDS double buffer: load next data from device mem
-                    blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
-                                                            p_in_register_buffer);
-                    blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
-                                                             p_wei_register_buffer);
-
-                    // LDS double buffer: GEMM on current data
-                    blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
-
-                    // LDS double buffer: store next data to LDS
-                    blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                             p_in_block_next);
-                    blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                              p_wei_block_next);
-                }
-            }
-
-            // LDS double buffer: tail
-            {
-                Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()];
-                Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()];
-
-                // even iteration
-                p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0);
-                p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);
-
-                __syncthreads();
-
-                // LDS double buffer: load next data from device mem
-                blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset,
-                                                        p_in_register_buffer);
-                blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset,
-                                                         p_wei_register_buffer);
-
-                // LDS double buffer: GEMM on current data
-                blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
-
-                // LDS double buffer: store next data to LDS
-                blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer,
-                                                         p_in_block_double + in_block_space);
-                blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer,
-                                                          p_wei_block_double + wei_block_space);
-
-                // odd iteration
-                __syncthreads();
-
-                // LDS double buffer: GEMM on current data
-                blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space,
-                                         p_in_block_double + in_block_space,
-                                         p_out_thread);
-            }
-        }
-    }
-
-    // output: register to global mem
-    const auto c_thread_mtx_begin =
-        blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-    const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
-    const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
-    const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
-    const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
-
-    static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
-        // fwd does nothing but perfect forwarding.
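(Aside: the "main body"/"tail" split above is the classic ping-pong schedule: while the GEMM consumes the tile already resident in one half of LDS, the next tile is fetched into the other half, and the final tile, which has no successor, is handled separately. A minimal single-threaded sketch of the same control flow, with load standing in for the global-to-LDS copy and consume for the blockwise GEMM; the names are illustrative, not ck APIs, and the real kernel unrolls by two and runs a two-step tail, but the dependency structure is identical:

#include <cstdio>
#include <vector>

// stand-in for the global -> LDS copy
static void load(const std::vector<float>& src, int tile, std::vector<float>& buf)
{
    for(std::size_t i = 0; i < buf.size(); ++i)
        buf[i] = src[tile * buf.size() + i];
}

// stand-in for the blockwise batched GEMM on one resident tile
static float consume(const std::vector<float>& buf)
{
    float acc = 0;
    for(float v : buf)
        acc += v;
    return acc;
}

int main()
{
    constexpr int num_tiles = 6; // plays the role of C / CPerBlock
    constexpr int tile_size = 4;

    std::vector<float> src(num_tiles * tile_size, 1.0f);
    std::vector<float> lds[2] = {std::vector<float>(tile_size), std::vector<float>(tile_size)};

    float acc = 0;

    load(src, 0, lds[0]); // "preload data into LDS"

    // "main body": fetch tile t+1 into the other buffer while consuming tile t;
    // in the kernel a __syncthreads() guards each LDS buffer reuse
    for(int t = 0; t + 1 < num_tiles; ++t)
    {
        load(src, t + 1, lds[(t + 1) % 2]);
        acc += consume(lds[t % 2]);
    }

    acc += consume(lds[(num_tiles - 1) % 2]); // "tail": last tile, nothing left to prefetch

    std::printf("acc = %f (expected %d)\n", acc, num_tiles * tile_size);
    return 0;
}
)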
- // Using this trick to make this lambda a generic lambda, so it won't be compiled until - // being instantiated here - static_assert( - (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = fwd(out_k_h_w_n_global_desc) - .Fold(I3, Number{}, Number{}) - .Fold(I2, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number<1>{}, Number{}) - .Fold(I2, Number{}, Number<1>{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "a: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "a: out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); - } -#endif - - Float* p_out_thread_on_global = p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin); - -#if 1 - ThreadwiseGenericTensorSliceCopy_v1r2_deprecated< - decltype(out_10d_thread_desc), - decltype(out_10d_global_desc), - decltype(out_10d_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>(make_zero_array(), - make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#elif 0 - ThreadwiseGenericTensorSliceCopy_v1r1::type, - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#endif - }).Else([&](auto fwd) { - static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - fwd(out_k_h_w_n_global_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = - fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number<1>{}, Number{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "b: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_k_h_w_n_global_desc, - "b: out_k_h_w_n_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "b: 
out_10d_global_desc"); - } -#endif - - Float* p_out_thread_on_global = p_out_global + - out_k_h_w_n_global_desc.GetOffsetFromMultiIndex( - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin); - -#if 1 - ThreadwiseGenericTensorSliceCopy_v1r2_deprecated< - decltype(out_10d_thread_desc), - decltype(out_10d_global_desc), - decltype(out_10d_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>(make_zero_array(), - make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#elif 0 - ThreadwiseGenericTensorSliceCopy_v1r1::type, - arithmetic_sequence_gen<0, 10, 1>::type, - 9, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); -#endif - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp deleted file mode 100644 index 8fad9b864e..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp +++ /dev/null @@ -1,414 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto True = integral_constant{}; - static constexpr auto False = integral_constant{}; - - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto in_c_h_w_n_global_desc_old = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc_old = WeiGlobalDesc{}; - constexpr auto out_k_h_w_n_global_desc_old = OutGlobalDesc{}; - - constexpr auto in_c_h_w_n_global_desc = make_native_tensor_descriptor( - in_c_h_w_n_global_desc_old.GetLengths(), in_c_h_w_n_global_desc_old.GetStrides()); - - constexpr auto wei_c_y_x_k_global_desc = make_native_tensor_descriptor( - wei_c_y_x_k_global_desc_old.GetLengths(), wei_c_y_x_k_global_desc_old.GetStrides()); - - constexpr auto out_k_h_w_n_global_desc = make_native_tensor_descriptor( - out_k_h_w_n_global_desc_old.GetLengths(), out_k_h_w_n_global_desc_old.GetStrides()); - - constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0); - constexpr index_t Hi = 
in_c_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wi = in_c_h_w_n_global_desc.GetLength(I2); - - constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0); - constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1); - constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2); - constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! cannot evenly divide work for workgroup "); - - constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock); - constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock); - constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock); - constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock); - - constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_begin = block_work_multi_id[0] * KPerBlock; - const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock; - const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock; - const index_t n_block_data_begin = block_work_multi_id[3] * NPerBlock; - - const index_t hp_block_data_begin = ho_block_data_begin; - const index_t wp_block_data_begin = wo_block_data_begin; - - // input global tensor view - constexpr auto in_c_hp_wp_n_global_desc = transform_tensor_descriptor( - in_c_h_w_n_global_desc, - make_tuple( - PassThrough{}, Pad, LeftPads, RightPads>{}, PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_N, - WeiBlockCopyDataPerAccess_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc_old = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // hack - constexpr auto in_c_h_w_n_block_desc = make_native_tensor_descriptor( - in_c_h_w_n_block_desc_old.GetLengths(), in_c_h_w_n_block_desc_old.GetStrides()); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_1_1_k_block_desc_old = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_c_1_1_k_block_desc = make_native_tensor_descriptor( - wei_c_1_1_k_block_desc_old.GetLengths(), wei_c_1_1_k_block_desc_old.GetStrides()); - - // LDS: be careful of alignment - constexpr index_t in_block_space = in_c_h_w_n_block_desc_old.GetElementSpace(); - constexpr index_t wei_block_space = wei_c_1_1_k_block_desc_old.GetElementSpace(); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc_old = make_ConstantTensorDescriptor_packed( - Sequence{}); - - constexpr auto out_k_h_w_n_thread_desc = make_native_tensor_descriptor( - out_k_h_w_n_thread_desc_old.GetLengths(), out_k_h_w_n_thread_desc_old.GetStrides()); - - // 
blockwise input copy - // format is [C, Hi, Wi, N] - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v4, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 3, - 3, - InBlockCopyDataPerAccess_N, - InBlockCopyDataPerAccess_N>( - {0, hp_block_data_begin, wp_block_data_begin, n_block_data_begin}, {0, 0, 0, 0}); - - // blockwise wei copy - // format is [CPerBlock, KPerBlock] - using WeiBlockCopySubLengths_CYXK = - Sequence; - using WeiBlockCopyClusterLengths_CYXK = Sequence; - - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v4, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 3, - 3, - WeiBlockCopyDataPerAccess_K, - WeiBlockCopyDataPerAccess_K>( - {0, 0, 0, k_block_data_begin}, {0, 0, 0, 0}); - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc_old.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - - // set threadwise output tensor to 0 - threadwise_matrix_set_zero(c_k_wn_thread_mtx_desc, p_out_thread); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock) - { - blockwise_in_copy.Run(p_in_global, p_in_block); - blockwise_wei_copy.Run(p_wei_global, p_wei_block); - - __syncthreads(); - - blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - - // move along C - blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(CPerBlock, 0, 0, 0), - True); - blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(CPerBlock, 0, 0, 0), - True); - } - - // reset C - blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(C, 0, 0, 0), False); - blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(C, 0, 0, 0), False); - - // move aling X - blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 0, 1, 0), True); - blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 0, 1, 0), True); - } - - // reset X - blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 0, X, 0), False); - blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 0, X, 0), False); - - // move along Y - blockwise_in_copy.MoveSrcSliceWindow(make_multi_index(0, 1, 0, 0), True); - blockwise_wei_copy.MoveSrcSliceWindow(make_multi_index(0, 1, 0, 0), True); - } - - // output: register to global mem - const auto 
c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto fwd) { - // fwd do nothing but perfect forwarding. - // Using this trick to make this lambda a generic lambda, so it won't be compiled until - // being instantiated here - static_assert( - (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc_old = fwd(out_k_h_w_n_global_desc_old) - .Fold(I3, Number{}, Number{}) - .Fold(I2, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_global_desc = make_native_tensor_descriptor( - out_10d_global_desc_old.GetLengths(), out_10d_global_desc_old.GetStrides()); - - constexpr auto out_10d_thread_desc_old = fwd(out_k_h_w_n_thread_desc_old) - .Fold(I3, Number<1>{}, Number{}) - .Fold(I2, Number{}, Number<1>{}) - .Fold(I0, Number<1>{}, Number{}); - - constexpr auto out_10d_thread_desc = make_native_tensor_descriptor( - out_10d_thread_desc_old.GetLengths(), out_10d_thread_desc_old.GetStrides()); - - Float* p_out_thread_on_global = - p_out_global + - out_k_h_w_n_global_desc.CalculateOffset({k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin}); - - ThreadwiseGenericTensorSliceCopy_v4r2::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); - }).Else([&](auto fwd) { - static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc_old = - fwd(out_k_h_w_n_global_desc_old) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_global_desc = make_native_tensor_descriptor( - out_10d_global_desc_old.GetLengths(), out_10d_global_desc_old.GetStrides()); - - constexpr auto out_10d_thread_desc_old = - fwd(out_k_h_w_n_thread_desc_old) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number<1>{}, Number{}) - .Fold(I0, Number<1>{}, Number{}); - - constexpr auto out_10d_thread_desc = make_native_tensor_descriptor( - out_10d_thread_desc_old.GetLengths(0), out_10d_thread_desc_old.GetStrides()); - - Float* p_out_thread_on_global = - p_out_global + - out_k_h_w_n_global_desc.CalculateOffset({k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin 
+ wo_thread_data_begin, - n_block_data_begin + n_thread_data_begin}); - - ThreadwiseGenericTensorSliceCopy_v4r2::type, - 9, - OutThreadCopyDataPerAccess_N, - OutThreadCopyDataPerAccess_N>( - make_zero_array(), make_zero_array()) - .Run(p_out_thread, p_out_thread_on_global); - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp deleted file mode 100644 index a5736272b8..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp +++ /dev/null @@ -1,451 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "blockwise_tensor_slice_copy.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_op.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - - constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0); - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - // divide block work: [N, K, Ho, Wo] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! 
cannot evenly divide work for workgroup "); - - constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock); - constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock); - constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock); - constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock); - - constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; - const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; - const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock; - const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // global tensor view - constexpr auto wei_c_k_global_desc = - make_ConstantTensorDescriptor(Sequence{}, Sequence{}); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N, - WeiBlockCopyDataPerRead_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - // blockwise copy - // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] - constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; - - const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3< - BlockSize, - Float, - decltype(in_n_c_h_w_global_desc), - decltype(in_c_h_w_n_block_desc), - Sequence, - InBlockReorderSrcSubLengths_NCHW, - InBlockReorderSrcClusterLengths_NCHW, - decltype(map_chwn2nchw), - InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW, - InBlockReorderDataPerRead_W, - InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0}); - - // blockwise wei copy - // format is [CPerBlock, KPerBlock] - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3({0, 0}, {0, 0}); - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - 
GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // choose GEMM implementation here - const auto run_blockwise_batch_gemm = [&](auto... Xs) { -#if 1 - return blockwise_batch_gemm.Run(Xs...); -#elif 0 - return blockwise_batch_gemm.Run_amd_asm(Xs...); -#else - return blockwise_batch_gemm.Run_asm_v2(Xs...); -#endif - }; - - // LDS: be careful of alignment - constexpr index_t in_block_space = - in_c_h_w_n_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); - print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); - - print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); - print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc"); - - printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); - } -#endif - - // set threadwise output tensor to 0 - threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); - -#if 0 - const Float* p_in_global_block_offset = - p_in_global + - in_n_c_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - blockwise_in_copy_reorder.Run(p_in_global_block_offset + - in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x), - p_in_block); - - blockwise_wei_copy.Run(p_wei_global_block_offset + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_wei_block); - - __syncthreads(); - - run_blockwise_batch_gemm(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - } - } - } -#else - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - const Float* p_in_global_block_offset = - p_in_global + - in_n_c_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; - c_block_data_begin += CPerBlock, - p_in_global_block_offset += - CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), - p_wei_global_block_offset += - CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - blockwise_in_copy_reorder.Run(p_in_global_block_offset, p_in_block); - - blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block); - - __syncthreads(); - - run_blockwise_batch_gemm(p_wei_block, p_in_block, 
p_out_thread); - - __syncthreads(); - } - } - } -#endif - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto fwd) { - // fwd do nothing but perfect forwarding. - // Using this trick to make this lambda a generic lambda, so it won't be compiled until - // begin instantiated here - static_assert( - (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc) - .Fold(I3, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number<1>{}, Number{}) - .Fold(I2, Number{}, Number<1>{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "a: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, - "a: out_n_k_h_w_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); - } -#endif - - constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{}; - - threadwise_tensor_slice_copy_reorder_given_dst2src_v2( - out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_10d_thread_desc.GetLengths(), - map_out_global2thread); - // Number{}); - }).Else([&](auto fwd) { - static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - fwd(out_n_k_h_w_global_desc) - .Fold(I3, Number{}, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}); - - constexpr auto out_10d_thread_desc = - fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number<1>{}, Number{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "b: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - 
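(Aside: the .Fold(...) calls above split one tensor dimension into several. In spirit, folding a packed length-K dimension of stride s by factors (K1, K2) yields lengths (K/(K1*K2), K1, K2) with strides (K1*K2*s, K2*s, s), so a multi-index in the folded 10d view addresses exactly the same element as the flat index in the original 4d view. A small standalone check of that identity, with arbitrary values not taken from the kernel:

#include <cstdio>

int main()
{
    // fold a dimension of length K and stride s by factors (K1, K2)
    constexpr int K = 16, s = 100;
    constexpr int K1 = 2, K2 = 4;
    constexpr int K0 = K / (K1 * K2);

    const int folded_stride[3] = {K1 * K2 * s, K2 * s, s};

    for(int k0 = 0; k0 < K0; ++k0)
        for(int k1 = 0; k1 < K1; ++k1)
            for(int k2 = 0; k2 < K2; ++k2)
            {
                const int folded =
                    k0 * folded_stride[0] + k1 * folded_stride[1] + k2 * folded_stride[2];
                const int flat = ((k0 * K1 + k1) * K2 + k2) * s; // offset in the unfolded view
                if(folded != flat)
                    std::printf("mismatch at (%d,%d,%d)\n", k0, k1, k2);
            }

    std::printf("folded and flat offsets agree for all %d indices\n", K);
    return 0;
}
)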
- print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, - "b: out_n_k_h_w_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc"); - } -#endif - - constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{}; - -#if 0 - threadwise_tensor_slice_copy_reorder_given_dst2src_v3( - out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_10d_thread_desc.GetLengths(), - map_out_global2thread, - Number{}); -#else - threadwise_generic_tensor_slice_copy_v1( - out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread), - p_out_thread, - make_zero_array(), - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - make_zero_array(), - out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread), - arithmetic_sequence_gen<0, 10, 1>::type{}, - Number<1>{}); -#endif - }); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp deleted file mode 100644 index 8d757056eb..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,499 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "blockwise_tensor_slice_copy.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_op.hpp" -#include "blockwise_batched_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // be careful of this assertion - static_assert( - NPerBlock % NPerThread == 0 && - ((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) || - (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0)), - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - - constexpr index_t N = out_n_k_h_w_global_desc.GetLength(I0); - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr 
index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - // assert for LDS double buffer - static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided"); - - // divide block work: [K, Ho, Wo, N] - static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 && - Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0, - "wrong! cannot evenly divide work for workgroup "); - - constexpr index_t NBlockWork = math::integer_divide_ceil(N, NPerBlock); - constexpr index_t KBlockWork = math::integer_divide_ceil(K, KPerBlock); - constexpr index_t HBlockWork = math::integer_divide_ceil(Ho, HoPerBlock); - constexpr index_t WBlockWork = math::integer_divide_ceil(Wo, WoPerBlock); - - constexpr auto block_work_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t n_block_data_begin = block_work_multi_id[0] * NPerBlock; - const index_t k_block_data_begin = block_work_multi_id[1] * KPerBlock; - const index_t ho_block_data_begin = block_work_multi_id[2] * HoPerBlock; - const index_t wo_block_data_begin = block_work_multi_id[3] * WoPerBlock; - - const index_t hi_block_data_begin = ho_block_data_begin; - const index_t wi_block_data_begin = wo_block_data_begin; - - // global tensor view - constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); - - // LDS tensor view - // be careful of alignment - constexpr index_t max_align = math::lcm(InBlockReorderDataPerWrite_N, - WeiBlockCopyDataPerRead_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment requirements - static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not meet"); - - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // tensor view of threadwise output in register - constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - // blockwise copy - // input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N] - constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{}; - - const auto blockwise_in_copy_reorder = BlockwiseTensorSliceReorderCopy_v3< - BlockSize, - Float, - decltype(in_n_c_h_w_global_desc), - decltype(in_c_h_w_n_block_desc), - Sequence, - InBlockReorderSrcSubLengths_NCHW, - InBlockReorderSrcClusterLengths_NCHW, - decltype(map_chwn2nchw), - InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW, - InBlockReorderDataPerRead_W, - InBlockReorderDataPerWrite_N>({0, 0, 0, 0}, {0, 0, 0, 0}); - - // blockwise wei copy - // format is [CPerBlock, KPerBlock] - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3({0, 0}, {0, 0}); - - // a series of blockwise batched GEMM - // C_matrix += transpose(A_matrix) * B_matrix - // A_matrix and B_matrix saved in LDS, C_matrix saved in register - // A_matrix[C,K] is a sub-matrix of wei_block[C,K] - // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] - // C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N] - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_wn_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - constexpr auto 
c_k_wn_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - const auto blockwise_batch_gemm = - BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_wn_block_mtx_desc), - decltype(c_k_wn_thread_mtx_desc), - 0, - in_c_h_w_n_block_desc.GetStride(I1), - out_k_h_w_n_thread_desc.GetStride(I1), - HoPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - HoPerThread, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // choose GEMM implementation here - const auto run_blockwise_batch_gemm = [&](auto... Xs) { -#if 1 - return blockwise_batch_gemm.Run(Xs...); -#elif 0 - return blockwise_batch_gemm.Run_amd_asm(Xs...); -#else - return blockwise_batch_gemm.Run_asm_v2(Xs...); -#endif - }; - - // LDS: be careful of alignment - constexpr index_t in_block_space = - in_c_h_w_n_block_desc.GetElementSpace(Number{}); - constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); - - // LDS double buffer - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register - // C++ lambda doesn't capture array, use pointer instead - Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()]; - Float* const p_out_thread = p_out_thread_data; - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc"); - print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc"); - - print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc"); - print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc"); - - printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space); - } -#endif - - // set threadwise output tensor to 0 - threadwise_generic_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - const Float* p_in_global_block_offset = - p_in_global + - in_n_c_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, k_block_data_begin); - - // LDS double buffer: preload data into LDS - { - Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; - c_block_data_begin += 2 * CPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? 
p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - Float - p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - p_in_global_block_offset += - CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); - p_wei_global_block_offset += - CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - run_blockwise_batch_gemm(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy_reorder.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_register_buffer[blockwise_in_copy_reorder.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - // even iteration - p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); - p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy_reorder.RunLoadRegisterBuffer(p_in_global_block_offset, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - run_blockwise_batch_gemm(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy_reorder.RunStoreRegisterBuffer( - p_in_register_buffer, p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - run_blockwise_batch_gemm(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = c_thread_mtx_begin.row; - const index_t ho_thread_data_begin = c_thread_mtx_begin.batch; - const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock; - const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock; - - static_if{}([&](auto fwd) { - // fwd do nothing but perfect forwarding. 
- // Using this trick to make this lambda a generic lambda, so it won't be compiled until - // begin instantiated here - static_assert( - (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0), - "wrong!"); - - // output is a 10d tensor - constexpr index_t N2 = GemmNPerThreadSubC; - constexpr index_t N1 = NPerBlock / N2; - - constexpr index_t W2 = - (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC); - constexpr index_t W1 = WoPerBlock / W2; - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = fwd(out_n_k_h_w_global_desc) - .Fold(I3, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto out_10d_thread_desc = fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number<1>{}, Number{}) - .Fold(I2, Number{}, Number<1>{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "a: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, - "a: out_n_k_h_w_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc"); - } -#endif - - constexpr auto map_out_global2thread = Sequence<7, 8, 9, 0, 1, 2, 3, 4, 5, 6>{}; - - threadwise_tensor_slice_copy_reorder_given_dst2src_v2( - out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_10d_thread_desc.GetLengths(), - map_out_global2thread); - // Number{}); - }).Else([&](auto fwd) { - static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock && - GemmNPerThreadSubC % NPerThread == 0, - "wrong!"); - - // output is a 10d tensor - constexpr index_t N1 = NPerBlock; - - constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock; - constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t W1 = WoPerBlock / fwd(W2 * W3); - - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = KPerBlock / KPerThread; - - constexpr auto out_10d_global_desc = - fwd(out_n_k_h_w_global_desc) - .Fold(I3, Number{}, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}); - - constexpr auto out_10d_thread_desc = - fwd(out_k_h_w_n_thread_desc) - .Fold(I3, Number{}) - .Fold(I2, Number{}, Number<1>{}, Number{}) - .Fold(I0, Number<1>{}, Number{}); - -#if 0 - if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) - { - print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc, - "b: out_k_h_w_n_thread_desc"); - print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc"); - - print_ConstantTensorDescriptor(out_n_k_h_w_global_desc, - "b: out_n_k_h_w_global_desc"); - print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc"); - } -#endif - - constexpr auto map_out_global2thread = Sequence<8, 9, 0, 1, 2, 3, 4, 5, 6, 7>{}; - -#if 0 - threadwise_tensor_slice_copy_reorder_given_dst2src_v3( - out_10d_thread_desc, - p_out_thread, - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - 
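// Minimal sketch of the static_if idiom used above, under its assumed
// semantics: the chosen branch is invoked with an identity functor "fwd",
// which makes each branch body a generic lambda, so the not-taken branch is
// never instantiated (and thus never type-checked against invalid constants).
template <bool Pred> struct static_if;

template <> struct static_if<true>
{
    template <class F> const static_if& operator()(F f) const
    {
        f([](auto x) { return x; }); // pass identity "fwd" into the then-branch
        return *this;
    }
    template <class F> void Else(F) const {} // else-branch skipped, never instantiated
};

template <> struct static_if<false>
{
    template <class F> const static_if& operator()(F) const { return *this; }
    template <class F> void Else(F f) const { f([](auto x) { return x; }); }
};

template <int N>
int static_if_demo()
{
    int r = 0;
    static_if<(N > 0)>{}([&](auto fwd) { r = fwd(N); }).Else([&](auto fwd) { r = -fwd(N); });
    return r;
}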
ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - out_10d_thread_desc.GetLengths(), - map_out_global2thread, - Number{}); -#else - threadwise_generic_tensor_slice_copy_v1( - out_10d_thread_desc.ReorderGivenNew2Old(map_out_global2thread), - p_out_thread, - make_zero_array(), - out_10d_global_desc, - p_out_global + - out_n_k_h_w_global_desc.GetOffsetFromMultiIndex( - n_block_data_begin + n_thread_data_begin, - k_block_data_begin + k_thread_data_begin, - ho_block_data_begin + ho_thread_data_begin, - wo_block_data_begin + wo_thread_data_begin), - make_zero_array(), - out_10d_thread_desc.GetLengths().ReorderGivenNew2Old(map_out_global2thread), - arithmetic_sequence_gen<0, 10, 1>::type{}, - Number<1>{}); -#endif - }); - } -}; - -} // namespace -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp deleted file mode 100644 index dc02655f30..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,283 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_4d_tensor_op.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -// define B = flatten(N, Hi, Wi) -template -struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); - constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); - constexpr index_t N = in_chwn_global_desc.GetLength(I3); - - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - - constexpr index_t B = N * Hi * Wi; - constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - - // divide block work by 2d: [K, B] - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const index_t k_block_work_id = get_block_1d_id() / BBlockWork; - const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t b_block_data_begin = b_block_work_id * BPerBlock; - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - // be careful of 
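// Worked example (hypothetical sizes) for the flattening above: the input is
// viewed as a [C, B] matrix with B = N * Hi * Wi, and a block additionally
// reads a ghost margin of BGhostRead points, because the GEMM for filter tap
// (y, x) reads the input tile at an offset of y * Wi + x flattened elements.
constexpr int N_ = 2, Hi_ = 8, Wi_ = 8;
constexpr int Y_ = 3, X_ = 3; // 3x3 filter

constexpr int B_          = N_ * Hi_ * Wi_;            // 128
constexpr int BGhostRead_ = (Y_ - 1) * Wi_ + (X_ - 1); // 2*8 + 2 = 18

static_assert(B_ == 128 && BGhostRead_ == 18, "arithmetic check");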
alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = - Blockwise2dTensorCopy3{}; -#endif - -// blockwise wei copy -// format is [CPerBlock*Y*X,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; - - // LDS: be careful of alignment - constexpr index_t max_align = - math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - - constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = - wei_cyxk_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output to 0 - threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread); - - for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock, - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0), - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0), - __syncthreads()) - { - // load data - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block); - - __syncthreads(); - - // compute on current data - // a series of GEMM - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { -#if 1 - blockwise_gemm.Run -#elif 0 - 
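// Hypothetical sketch of the two-phase copy behind RunLoadRegisterBuffer /
// RunStoreRegisterBuffer used below: each thread stages its slice of the
// block tile through registers, so the global loads can be issued early and
// the LDS stores deferred until after the synchronization point.
__device__ void two_phase_copy(const float* p_global, // source, already block-offset
                               float* p_lds,          // destination tile in LDS
                               int elems_per_thread,  // assumed <= 8 here
                               int thread_stride)
{
    float reg_buf[8]; // register staging buffer

    // phase 1: global memory -> registers
    for(int i = 0; i < elems_per_thread; ++i)
        reg_buf[i] = p_global[threadIdx.x + i * thread_stride];

    // phase 2: registers -> LDS
    for(int i = 0; i < elems_per_thread; ++i)
        p_lds[threadIdx.x + i * thread_stride] = reg_buf[i];
}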
blockwise_gemm.Run_RegisterDoubleBuffer -#elif 1 - blockwise_gemm.Run_amd_asm -#endif - (p_wei_block + wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_in_block + y * Wi + x, - p_out_thread); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - - for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; - index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - index_t h_data = b_data / (Wi * N); - index_t itmp = b_data - h_data * (Wi * N); - index_t w_data = itmp / N; - index_t n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex( - k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)]; - } - } - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp deleted file mode 100644 index 4b3ab8f7cd..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp +++ /dev/null @@ -1,408 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_4d_tensor_op.hpp" -#include "blockwise_2d_tensor_op.hpp" -#include "threadwise_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -// define B = flatten(N, Hi, Wi) -template -struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_chwn_global_desc = InGlobalDesc{}; - constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{}; - constexpr auto out_khwn_global_desc = OutGlobalDesc{}; - - constexpr index_t C = in_chwn_global_desc.GetLength(I0); - constexpr index_t Hi = in_chwn_global_desc.GetLength(I1); - constexpr index_t Wi = in_chwn_global_desc.GetLength(I2); - constexpr index_t N = in_chwn_global_desc.GetLength(I3); - - constexpr index_t K = out_khwn_global_desc.GetLength(I0); - constexpr index_t Ho = out_khwn_global_desc.GetLength(I1); - constexpr index_t Wo = out_khwn_global_desc.GetLength(I2); - - constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1); - constexpr index_t X = wei_cyxk_global_desc.GetLength(I2); - - constexpr index_t B = N * Hi * Wi; - constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - - // assert for LDS double buffer - static_assert(C % (2 * 
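// Worked sketch of the writeback index math in the loop above: a flat b index
// over (H, W, N), with N fastest-varying, is recovered with two div/mod steps;
// the bounds check is needed because b was flattened over the *input* extent
// (Hi, Wi) while only (Ho, Wo) output positions are valid.
struct HWN { int h, w, n; };

HWN decompose_b(int b_data, int Wi, int N)
{
    const int h   = b_data / (Wi * N);
    const int rem = b_data - h * (Wi * N);
    const int w   = rem / N;
    const int n   = rem - w * N;
    return {h, w, n};
}
// e.g. Wi = 8, N = 4: b_data = 37 -> h = 1, w = 1, n = 1 (37 = 1*32 + 1*4 + 1)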
CPerBlock) == 0, "C cannot be evenly divided"); - - // divide block work by 2d: [K, B] - constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock; - constexpr index_t BBlockWork = (B + BPerBlock - 1) / BPerBlock; - - const index_t k_block_work_id = get_block_1d_id() / BBlockWork; - const index_t b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork; - - const index_t k_block_data_begin = k_block_work_id * KPerBlock; - const index_t b_block_data_begin = b_block_work_id * BPerBlock; - - // flattend (2d) tensor view of gridwise input - constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - // tensor view of blockwise input and weight - // be careful of alignment - constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // tensor view of threadwise output in register - constexpr auto out_kb_thread_desc = - make_ConstantTensorDescriptor(Sequence{}); - -// blockwise in copy -// formmat is [CPerBlock,BPerBlock + BGhostRead] -#if 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_in_copy = - Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_in_copy = - Blockwise2dTensorCopy3{}; -#endif - -// blockwise wei copy -// format is [CPerBlock*Y*X,KPerBlock] -#if 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy1{}; -#elif 0 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy2{}; -#elif 1 - const auto blockwise_wei_copy = - Blockwise2dTensorCopy3{}; -#endif - - // a series of blockwise GEMM - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx and b_mtx saved in LDS, c_mtx saved in register - // a_mtx[C,K] is a sub-matrix of wei_block[C,Y,X,K] - // b_mtx[C,B] is a subset of in_block[C,B + BGhostRead] - // c_mtx[K,B] is out_block[K,B] - constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto c_kxb_thread_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, Number{}); - - const auto blockwise_gemm = - BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2{}; - - // LDS: be careful of alignment - constexpr index_t max_align = - math::lcm(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead); - - constexpr index_t in_block_space = in_cb_block_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = - wei_cyxk_block_desc.GetElementSpace(Number{}); - - // LDS double buffer - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - const Float* p_in_global_block_offset = - p_in_global + in_cb_global_desc.GetOffsetFromMultiIndex(0, b_block_data_begin); - - const Float* p_wei_global_block_offset = - p_wei_global + - wei_cyxk_global_desc.GetOffsetFromMultiIndex(0, 0, 0, k_block_data_begin); - - // preload data into LDS - { - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, 
- p_wei_register_buffer); - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_double); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_double); - } - - // register - Float p_out_thread[out_kb_thread_desc.GetElementSpace()]; - - // set threadwise output to 0 - threadwise_matrix_set_zero(c_kxb_thread_mtx_desc, p_out_thread); - - for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; - c_block_data_begin += 2 * CPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - // load next data - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); - - __syncthreads(); - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - // compute on current data - // a series of GEMM - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { -#if 1 - blockwise_gemm.Run -#elif 0 - blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 - blockwise_gemm.Run_amd_asm -#endif - (p_wei_block_now + - wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_in_block_now + y * Wi + x, - p_out_thread); - } - } - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); - } - } - - // tail - { - // even - p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0); - p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0); - - __syncthreads(); - - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global_block_offset, p_in_register_buffer); - - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_global_block_offset, - p_wei_register_buffer); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { -#if 1 - blockwise_gemm.Run -#elif 0 - blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 - blockwise_gemm.Run_amd_asm -#endif - (p_wei_block_double + - wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - p_in_block_double + y * Wi + x, - p_out_thread); - } - } - - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double + wei_block_space); - - // odd - __syncthreads(); - - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { -#if 1 - blockwise_gemm.Run -#elif 0 - blockwise_gemm.Run_RegisterDoubleBuffer -#elif 0 - blockwise_gemm.Run_amd_asm -#endif - (p_wei_block_double + wei_block_space + - wei_cyxk_block_desc.GetOffsetFromMultiIndex(0, y, x, 0), - 
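// Minimal self-contained sketch (hypothetical kernel, simplified shapes) of
// the LDS double-buffering pipeline used in these kernels: two LDS buffers
// alternate roles each iteration, so the copy into the "next" buffer overlaps
// the compute on the "now" buffer, with one __syncthreads() per stage.
template <int TileSize>
__global__ void double_buffered_pipeline(const float* src, float* dst, int num_tiles)
{
    __shared__ float lds[2 * TileSize];

    // preload tile 0 into buffer 0
    for(int i = threadIdx.x; i < TileSize; i += blockDim.x)
        lds[i] = src[i];

    for(int t = 0; t < num_tiles; ++t)
    {
        float* now  = &lds[(t % 2) * TileSize];
        float* next = &lds[((t + 1) % 2) * TileSize];

        __syncthreads(); // the now-buffer is fully written before anyone reads it

        // overlap: fetch tile t+1 while tile t is being consumed
        if(t + 1 < num_tiles)
            for(int i = threadIdx.x; i < TileSize; i += blockDim.x)
                next[i] = src[(t + 1) * TileSize + i];

        // stand-in for the blockwise GEMM on the current tile
        for(int i = threadIdx.x; i < TileSize; i += blockDim.x)
            dst[t * TileSize + i] = 2.0f * now[i];
    }
}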
p_in_block_double + in_block_space + y * Wi + x, - p_out_thread); - } - } - } - - // output: register to global mem, - const auto c_thread_mtx_begin = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_begin = k_block_data_begin + c_thread_mtx_begin.row; - const index_t b_thread_data_begin = b_block_data_begin + c_thread_mtx_begin.col; - - if(Y == 1 && X == 1) - { // pure 1x1 conv (non padding, 1x1 stride) - constexpr index_t K2_ = GemmMPerThreadSubC; - constexpr index_t K1_ = KPerBlock / KPerThread; - constexpr index_t B2_ = GemmNPerThreadSubC; - constexpr index_t B1_ = BPerBlock / BPerThread; - - constexpr auto out_6d_global_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_6d_thread_desc = make_ConstantTensorDescriptor( - Sequence{}); - - constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence{}); - - threadwise_6d_tensor_copy(out_6d_thread_desc, - p_out_thread, - out_6d_global_desc, - p_out_global + - out_kb_global_desc.GetOffsetFromMultiIndex( - k_thread_data_begin, b_thread_data_begin), - out_6d_thread_desc.GetLengths(), - Number{}); - } - else - { - for(index_t k = 0; k < out_kb_thread_desc.GetLength(I0); ++k) - { - for(index_t b = 0; b < out_kb_thread_desc.GetLength(I1); ++b) - { - const auto c_thread_mtx_distance = - blockwise_gemm.GetDistanceFromBeginOfThreadMatrixC(k, b); - - index_t k_data = k_thread_data_begin + c_thread_mtx_distance.row; - index_t b_data = b_thread_data_begin + c_thread_mtx_distance.col; - - index_t h_data = b_data / (Wi * N); - index_t itmp = b_data - h_data * (Wi * N); - index_t w_data = itmp / N; - index_t n_data = itmp - w_data * N; - - if(n_data < N && h_data < Ho && w_data < Wo) - { - p_out_global[out_khwn_global_desc.GetOffsetFromMultiIndex( - k_data, h_data, w_data, n_data)] = - p_out_thread[out_kb_thread_desc.GetOffsetFromMultiIndex(k, b)]; - } - } - } - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp deleted file mode 100644 index 5ae7dc87d3..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp +++ /dev/null @@ -1,376 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -// define B = merge(N0, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - 
constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2); - constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3); - - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % CPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // memory layout descriptor in device memory [N0, N1, N2, C, H, W] - constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc = - in_n_c_h_w_global_desc.Fold(I0, Number{}, Number{}); - - // merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy - constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number{}).Slice(I5, Number{}), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(in_c_n1_b_n2_global_merged_desc), - decltype(in_c_n1_b_n2_block_mem_desc), - decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()), - InBlockCopySubLengths_C_N1_B_N2, - InBlockCopyClusterLengths_C_N1_B_N2, - Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B] - Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B] - Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2] - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto 
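// Hypothetical arithmetic sketch of the batch split above: N is factored as
// N0 * N1 * N2 (N1, N2 being per-thread GEMM factors), and the GEMM "B"
// dimension merges (N0, Ho, Wo); a flat b index is recovered as follows.
struct N0HoWo { int n0, ho, wo; };

N0HoWo decompose_merged_b(int b, int Ho, int Wo)
{
    const int n0  = b / (Ho * Wo);
    const int rem = b - n0 * (Ho * Wo);
    return {n0, rem / Wo, rem % Wo};
}
// e.g. N = 16 with N1 = 2, N2 = 4 gives N0 = 2; Ho = Wo = 28 gives B = 2*28*28 = 1568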
wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already has blockwise offset built-in - auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(wei_c_k_global_desc), - decltype(wei_c_k_block_desc), - decltype(wei_c_k_block_desc.GetLengths()), - WeiBlockCopySubLengths_C_K, - WeiBlockCopyClusterLengths_C_K, - Sequence<0, 1>, // thread_arrange_order [C, K] - Sequence<0, 1>, // src_access_order [C, K] - Sequence<0, 1>, // dst_access_order [C, K] - WeiBlockCopyDataPerAccess_K, - WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[CPerBlock, KPerBlock] is in LDS - // b_mtx[CPerBlock, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO: more elegant way of defining c_thread_mtx - constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_n1bn2_block_mtx_desc), - decltype(c_k0k2_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // choose GEMM implementation here - const auto run_blockwise_gemm = [&](auto... 
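// Worked example (hypothetical numbers) of the sanity check above: one pass
// of the thread clusters covers GemmMPerThreadSubC * GemmMLevel0Cluster *
// GemmMLevel1Cluster rows of the block tile, so KPerBlock must be a multiple
// of that product, and GemmMRepeat counts how many passes each thread makes.
constexpr int KPerBlock_          = 128;
constexpr int GemmMPerThreadSubC_ = 4;
constexpr int GemmMLevel0Cluster_ = 4;
constexpr int GemmMLevel1Cluster_ = 4;

constexpr int rows_per_pass_ = GemmMPerThreadSubC_ * GemmMLevel0Cluster_ * GemmMLevel1Cluster_; // 64
constexpr int GemmMRepeat_   = KPerBlock_ / rows_per_pass_;                                     // 2

static_assert(KPerBlock_ % rows_per_pass_ == 0, "wrong!");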
Xs) { -#if 1 - return blockwise_gemm.Run(Xs...); -#else - return blockwise_gemm.Run_amd_asm(Xs...); -#endif - }; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDataPerAccess_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - in_c_n1_b_n2_block_mem_desc.GetElementSpace(Number{}); - - constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number{}); - - __shared__ Float p_in_block[in_block_space]; - __shared__ Float p_wei_block[wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread); - -#if 0 - // do work - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - // calculate origin of block input and weight tensor on global memory - const Float* p_in_block_on_global = - p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x); - - const Float* p_wei_block_on_global = - p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0); - - for(index_t - c_block_data_on_global = 0; - c_block_data_on_global < C; - c_block_data_on_global += CPerBlock, - p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1), - p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0)) - { - blockwise_in_copy.Run(p_in_block_on_global, p_in_block); - blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block); - - __syncthreads(); - - run_blockwise_gemm(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - } - } - } -#else - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - // calculate origin of block input and weight tensor on global memory - const Float* p_in_block_on_global = - p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x); - - const Float* p_wei_block_on_global = - p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0); - - for(index_t c_block_data_on_global = 0; c_block_data_on_global < C; - c_block_data_on_global += CPerBlock) - { - blockwise_in_copy.Run(p_in_block_on_global, p_in_block); - blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block); - - __syncthreads(); - - blockwise_gemm.Run(p_wei_block, p_in_block, p_out_thread); - - __syncthreads(); - - blockwise_in_copy.MoveSlicingWindowOnSourceTensor( - I0, Number{}, True); - - blockwise_wei_copy.MoveSlicingWindowOnSourceTensor( - I0, Number{}, True); - } - - // reset C - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, False); - - blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, False); - } - } -#endif - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t K0 = K / (K1 * K2); - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc = - out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<4, 3, 7, 0, 1, 2, 5, 6>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto 
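// Sketch (assumed semantics) of the max_align computation above: the LDS
// buffer sizes are rounded up to the least common multiple of every vector
// access width that touches them, so each vectorized load/store stays aligned.
#include <numeric> // std::lcm, C++17

constexpr int in_write_width_   = 4; // InBlockCopyDstDataPerWrite_N2 (hypothetical value)
constexpr int wei_access_width_ = 2; // WeiBlockCopyDataPerAccess_K (hypothetical value)
constexpr int gemm_read_a_      = 4; // GemmDataPerReadA
constexpr int gemm_read_b_      = 4; // GemmDataPerReadB

constexpr int max_align_ =
    std::lcm(std::lcm(in_write_width_, wei_access_width_), std::lcm(gemm_read_a_, gemm_read_b_)); // 4

constexpr int integer_least_multiple_(int n, int align) { return ((n + align - 1) / align) * align; }
static_assert(integer_least_multiple_(1001, max_align_) == 1004, "pad 1001 elements up to 1004");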
out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, b_thread_data_on_global, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp deleted file mode 100644 index 2a08be3249..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,394 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -// define B = merge(N0, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - 
constexpr index_t Hi = in_n_c_h_w_global_desc.GetLength(I2); - constexpr index_t Wi = in_n_c_h_w_global_desc.GetLength(I3); - - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1); - constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2); - - static_assert(N % (N1 * N2) == 0, "wrong! cannot divide N evenly among thread"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && C % (2 * CPerBlock) == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // memory layout descriptor in device memory [N0, N1, N2, C, H, W] - constexpr auto in_n0_n1_n2_c_h_w_global_mem_desc = - in_n_c_h_w_global_desc.Fold(I0, Number{}, Number{}); - - // merged tensor descriptor in device memory [C, N1, B, N2], src of blockwise copy - constexpr auto in_c_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - in_n0_n1_n2_c_h_w_global_mem_desc.Slice(I4, Number{}).Slice(I5, Number{}), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // memory layout descriptor in LDS [C, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_c_n1_b_n2_block_mem_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with alignment - static_assert(in_c_n1_b_n2_block_mem_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(in_c_n1_b_n2_global_merged_desc), - decltype(in_c_n1_b_n2_block_mem_desc), - decltype(in_c_n1_b_n2_block_mem_desc.GetLengths()), - InBlockCopySubLengths_C_N1_B_N2, - InBlockCopyClusterLengths_C_N1_B_N2, - Sequence<0, 1, 3, 2>, // thread_arrange_order [C, N1, N2, B] - Sequence<1, 3, 0, 2>, // src_access_order [N1, N2, C, B] - Sequence<0, 1, 2, 3>, // dst_access_order [C, N1, B, N2] - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already has blockwise offset built-in - const auto blockwise_wei_copy = 
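// Hypothetical example of the SubLengths / ClusterLengths factorization used
// by the blockwise copies above: per dimension of the [C, N1, B, N2] tile,
// SubLength (elements per thread) times ClusterLength (threads) must equal
// the tile length, and the cluster must tile the whole workgroup. All numbers
// below are made up for illustration.
constexpr int BlockSize_ = 128;

constexpr int tile_lengths_[4]    = {8, 2, 16, 4}; // [C, N1, B, N2], assumed
constexpr int sub_lengths_[4]     = {2, 1, 1, 4};  // data per thread
constexpr int cluster_lengths_[4] = {4, 2, 16, 1}; // threads per dimension

static_assert(tile_lengths_[0] == sub_lengths_[0] * cluster_lengths_[0], "C");
static_assert(tile_lengths_[1] == sub_lengths_[1] * cluster_lengths_[1], "N1");
static_assert(tile_lengths_[2] == sub_lengths_[2] * cluster_lengths_[2], "B");
static_assert(tile_lengths_[3] == sub_lengths_[3] * cluster_lengths_[3], "N2");
static_assert(cluster_lengths_[0] * cluster_lengths_[1] * cluster_lengths_[2] * cluster_lengths_[3] ==
                  BlockSize_,
              "cluster lengths must tile the workgroup");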
BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(wei_c_k_global_desc), - decltype(wei_c_k_block_desc), - decltype(wei_c_k_block_desc.GetLengths()), - WeiBlockCopySubLengths_C_K, - WeiBlockCopyClusterLengths_C_K, - Sequence<0, 1>, // thread_arrange_order [C, K] - Sequence<0, 1>, // src_access_order [C, K] - Sequence<0, 1>, // dst_access_order [C, K] - WeiBlockCopyDataPerAccess_K, - WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[CPerBlock, KPerBlock] is in LDS - // b_mtx[CPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_c_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(Number{}, - Number{}, - Number{}); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_c_k_block_mtx_desc), - decltype(b_c_n1bn2_block_mtx_desc), - decltype(c_k0k2_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDataPerAccess_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_c_n1_b_n2_block_mem_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_c_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread); - - // do work - for(index_t y = 0; y < Y; ++y) - { - for(index_t x = 0; x < X; ++x) - { - // calculate origin of block input and weight tensor on global memory - const Float* p_in_block_on_global = - p_in_global + in_n_c_h_w_global_desc.GetOffsetFromMultiIndex(0, 0, y, x); - - const Float* p_wei_block_on_global = - p_wei_global + wei_c_y_x_k_global_desc.GetOffsetFromMultiIndex(0, y, x, 0); - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_block_on_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_block_on_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C; - c_block_data_begin += 2 * CPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? 
p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); - p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, - p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_next); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - // even iteration - p_in_block_on_global += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1); - p_wei_block_on_global += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_block_on_global, - p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, - p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - } - } - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t K0 = K / (K1 * K2); - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_thread_desc = - out_k0_k1_k2_n1_n0_h_w_n2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<4, 3, 7, 0, 1, 2, 5, 6>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - 
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<1>{}, - Sequence<0, 4, 5>{}, - Sequence<2>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, b_thread_data_on_global, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_h_w_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 8, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namesspace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index a462c6b560..37be0c60c2 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -8,53 +8,9 @@ #include "blockwise_generic_tensor_slice_copy.hpp" #include "threadwise_generic_tensor_slice_copy.hpp" #include "blockwise_gemm.hpp" -#include "convolution_common.hpp" namespace ck { -template -struct make_wei_e_k_global_desc_v4r1; - -template <> -struct make_wei_e_k_global_desc_v4r1 -{ - template - __device__ constexpr auto operator()(WeiDesc) const - { - constexpr auto I1 = Number<1>{}; - constexpr auto I3 = Number<3>{}; - - return reorder_tensor_descriptor_given_upper2lower( - unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{}); - } -}; - -template <> -struct make_wei_e_k_global_desc_v4r1 -{ - template - __device__ constexpr auto operator()(WeiDesc) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto wei_k_c_y_x_global_desc = WeiDesc{}; - - constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0); - constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1); - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); - constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); - - return transform_tensor_descriptor( - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3), - make_tuple(Merge>{}, PassThrough{}), - make_tuple(Sequence<1, 2>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } -}; - template {}; - static_assert(ConvDirection == ConvolutionDirection::Forward || - ConvDirection == ConvolutionDirection::BackwardWeight, - "wrong! 
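// Hypothetical sketch of the E = C * Y * X merge performed by the
// transform_tensor_descriptor(Merge<...>, PassThrough<...>) hunk above: a
// K x C x Y x X weight is addressed as an E x K matrix, where e merges
// (c, y, x) and the element offset follows the packed KCYX strides.
struct EK { int e, k; };

int wei_e_k_offset(EK idx, int C, int Y, int X)
{
    const int c = idx.e / (Y * X);
    const int r = idx.e - c * (Y * X);
    const int y = r / X;
    const int x = r - y * X;
    // packed KCYX layout: offset(k, c, y, x) = ((k*C + c)*Y + y)*X + x
    return ((idx.k * C + c) * Y + y) * X + x;
}
// e.g. C = 8, Y = X = 3: E = 72, and e = 13 -> (c, y, x) = (1, 1, 1)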
this kernel only support convolution forward and backward-weight"); - // this is a mess // TODO: find more elegent way of specifying (or calculating) performance parameters constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N2 = GemmNPerThread; - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); + static_assert( + (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, + "wrong!"); constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; @@ -240,7 +190,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer // It is constructed differently, depending on whether forward or backward weight // convolution constexpr auto wei_e_k_global_desc = - make_wei_e_k_global_desc_v4r1{}(wei_k_c_y_x_global_desc); + transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3), + make_tuple(Merge>{}, PassThrough{}), + make_tuple(Sequence<1, 2>{}, Sequence<0>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); // block tensor in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -290,30 +243,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer in_e_n1_b_n2_block_desc.GetStride(I0)); // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, + static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0, "wrong!"); constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); + KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster); // c_thread_mtx definition: this is a mess // TODO:: more elegent way of defining c_thread_mtx constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); + Number{}, Number{}); const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< BlockSize, decltype(a_e_k_block_mtx_desc), decltype(b_e_n1bn2_block_mtx_desc), decltype(c_k0k1_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB>{}; @@ -432,13 +384,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer // copy output: register to global memory { - constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; + constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; constexpr index_t K0 = K / K1; // define output tensor descriptor for threadwise copy // thread output tensor, src of threadwise copy constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); + Sequence{}); // global output tensor constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor( diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp deleted file mode 100644 index 133a4635f0..0000000000 --- 
a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp +++ /dev/null @@ -1,460 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy_deprecated.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy_deprecated.hpp" -#include "convolution_common.hpp" - -namespace ck { - -template -struct make_wei_e_k_global_desc_v4r1_deprecated; - -template <> -struct make_wei_e_k_global_desc_v4r1_deprecated -{ - template - __device__ constexpr auto operator()(WeiDesc) const - { - constexpr auto I1 = Number<1>{}; - constexpr auto I3 = Number<3>{}; - - return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); - } -}; - -template <> -struct make_wei_e_k_global_desc_v4r1_deprecated -{ - template - __device__ constexpr auto operator()(WeiDesc) const - { - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - return make_ConstantMergedTensorDescriptor( - WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{}); - } -}; - -// define B = merge(N0, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto generic_address_space = - integral_constant{}; - constexpr auto global_address_space = - integral_constant{}; - - static_assert(ConvDirection == ConvolutionDirection::Forward || - ConvDirection == ConvolutionDirection::BackwardWeight, - "wrong! this kernel only support convolution forward and backward-weight"); - - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThreadSubC; - - static_assert((N1 * N2 * BPerBlock) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0); - constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1); - - constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1); - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2); - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3); - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2); - constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3); - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - static_assert(N % (N1 * N2) == 0, "wrong! 
cannot divide N evenly among threads"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - constexpr index_t E = C * Y * X; - - // sanity-check for vectorized memory load - static_assert( - (Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) && - (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), - "wrong! alignment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, - "wrong! cannot divide work evenly among blocks"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N0, N1, N2, Ho, Wo] - constexpr auto in_n0_n1_n2_h_w_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Fold(I0, Number{}, Number{}) - .Extract(Sequence<0, 1, 2, 4, 5>{}); - - // batch descriptor for device memory - constexpr auto in_c_y_x_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy - constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor( - in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc), - Sequence<0, 1, 2>{}, - Sequence<4>{}, - Sequence<3, 6, 7>{}, - Sequence<5>{}); - - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated< - BlockSize, - decltype(in_e_n1_b_n2_global_merged_desc), - decltype(in_e_n1_b_n2_block_desc), - decltype(in_e_n1_b_n2_block_desc.GetLengths()), - InBlockCopySubLengths_E_N1_B_N2, - InBlockCopyClusterLengths_E_N1_B_N2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - 2, - 3, - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); - - // weight tensor - // Tensor descriptor in device memory, src of blockwise copy - // It is constructed differently, depending on whether forward or backward weight - // convolution - constexpr auto wei_e_k_global_desc = - make_wei_e_k_global_desc_v4r1_deprecated{}(wei_k_c_y_x_global_desc); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // 
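For reference, the block-work split performed above (E = C * Y * X, B = N0 * Ho * Wo, flat block id decomposed over [K, B]) can be checked with a small host-side sketch. All names and sizes below are illustrative stand-ins, not the kernel's actual template parameters:

    #include <cassert>
    #include <cstdio>

    int main()
    {
        // illustrative sizes only; the real kernel receives them as template parameters
        const int N = 16, C = 8, Y = 3, X = 3, Ho = 28, Wo = 28, K = 128;
        const int N1 = 2, N2 = 4;             // stand-ins for GemmNRepeat, GemmNPerThreadSubC
        const int KPerBlock = 32, BPerBlock = 16;

        const int N0 = N / (N1 * N2);         // N is viewed as N0 x N1 x N2
        const int E  = C * Y * X;             // GEMM K-dimension: merge(C, Y, X)
        const int B  = N0 * Ho * Wo;          // GEMM N-dimension: merge(N0, Ho, Wo)

        assert(K % KPerBlock == 0 && B % BPerBlock == 0);

        const int BBlockWork = B / BPerBlock; // grid extent along B; K is analogous

        // decompose a flat block id over [KBlockWork, BBlockWork], row-major,
        // mirroring block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id())
        const int bid                    = 5;
        const int k_block_data_on_global = (bid / BBlockWork) * KPerBlock;
        const int b_block_data_on_global = (bid % BBlockWork) * BPerBlock;

        std::printf("E=%d B=%d block=%d -> k=%d b=%d\n",
                    E, B, bid, k_block_data_on_global, b_block_data_on_global);
        return 0;
    }

// 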
operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v2_deprecated( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - - constexpr auto b_e_n1bn2_block_mtx_desc = - make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3)); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n1bn2_block_mtx_desc), - decltype(c_k0k1_n1n2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread); - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run( - p_in_global, p_in_block_double, global_address_space, generic_address_space); - blockwise_wei_copy.Run( - p_wei_global, p_wei_block_double, global_address_space, generic_address_space); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadThreadBuffer( - p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); - blockwise_wei_copy.RunLoadThreadBuffer( - p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0); - - if(has_two_iteration_left) // if has 2 iteration left - { - // even iteration - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadThreadBuffer( - p_in_global, p_in_thread_buffer, global_address_space, generic_address_space); - blockwise_wei_copy.RunLoadThreadBuffer( - p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - } - } - - // copy output: register to global memory - { - constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register, src of threadwise copy - constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output memory layout descriptor in device memory - constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I1, Number{}).Fold(I0, Number{}, Number{}); - - // output merged global tensor descriptor, dst of threadwise copy - constexpr auto out_k0_k1_n1_b_n2_global_merged_desc = - make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc, - Sequence<3>{}, - Sequence<4>{}, - Sequence<1>{}, - Sequence<0, 5, 6>{}, - Sequence<2>{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - 
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - ThreadwiseGenericTensorSliceCopy_v2r1_deprecated< - decltype(out_k0_k1_n1_b_n2_thread_mem_desc), - decltype(out_k0_k1_n1_b_n2_global_merged_desc), - decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()), - arithmetic_sequence_gen<0, 5, 1>::type, - arithmetic_sequence_gen<0, 5, 1>::type, - 3, - 3, - 1, - 1>({0, 0, 0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - 0, - b_thread_data_on_global, - 0}) - .Run(p_out_thread, p_out_global, generic_address_space, global_address_space); - } - } -}; - -} // namespace ck -#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index 3fe68ca3a9..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,432 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I7 = Number<7>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; - - constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr 
index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t E = C * Y * X; - - constexpr index_t B = N1 * Ho1 * Wo1; - - static_assert(N % (N1 * N2) == 0 && Ho % (Ho1 * Ho2) == 0 && Wo % (Wo1 * Wo2) == 0, - "wrong!"); - - constexpr index_t N0 = N / (N1 * N2); - constexpr index_t Ho0 = Ho / (Ho1 * Ho2); - constexpr index_t Wo0 = Wo / (Wo1 * Wo2); - - static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2] - constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc = - in_n_c_h_w_global_desc.Extract(I0, I2, I3) - .StridedSlice(I1, Number{}, Number{}) - .StridedSlice(I2, Number{}, Number{}) - .Fold(I2, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc = - in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old( - Sequence<0, 3, 6, 1, 4, 7, 2, 5, 8>{}); - - // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy - constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc = - make_ConstantMergedTensorDescriptor( - in_c_y_x_global_desc.Embed(in_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_global_desc), - Sequence<0, 1, 2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}, - Sequence<6, 7, 8>{}, - Sequence<9>{}, - Sequence<10>{}, - Sequence<11>{}); - - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc), - decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc), - decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetLengths()), - InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2, - InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopyDataPerAccess_W2, - InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0}, - {0, 0, 0, 0, 0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_e_k_global_desc = - wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 
0>{}); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v1_deprecated( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB == - 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - constexpr auto b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc = - make_ConstantMatrixDescriptor(in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.Unfold(I1, I7)); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc = - make_ConstantMatrixDescriptor_packed(Number{}, - Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n0ho0wo0bn2ho2wo2_block_mtx_desc), - decltype(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = math::integer_least_multiple( - in_e_n0_ho0_wo0_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n0ho0wo0n2ho2wo2_thread_mtx_desc, p_out_thread); - - const Float* p_wei_block_on_global = p_wei_global; - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? 
p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, - p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - // even iteration - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc = - out_k0_k1_k2_n0_ho0_wo0_n1_ho1_wo1_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<3, 6, 9, 0, 1, 2, 4, 7, 10, 5, 8, 11>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I3, Sequence{}) - .Fold(I2, Sequence{}) - .Fold(I1, Sequence{}) - .Fold(I0, Sequence{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - 
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2); - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc = - make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<0>{}, - Sequence<4>{}, - Sequence<7>{}, - Sequence<1, 5, 8>{}, - Sequence<2>{}, - Sequence<6>{}, - Sequence<9>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 12, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index bc50bf19ca..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,457 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -template -struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - static_assert(N2 * Ho2 * Wo2 == GemmNPerThreadSubC, "wrong!"); - static_assert((N1 * Ho1 * Wo1 * BPerBlock * N2 * Ho2 * Wo2) % - (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == - 0, - "wrong!"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I7 = Number<7>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; - - 
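The v4r3 kernel deleted here reuses the LDS double-buffering schedule common to all of these kernels: while the blockwise GEMM consumes one half of the LDS allocation, the next EPerBlock tile is prefetched into the other half. A minimal host-side sketch of just that schedule (hypothetical helper names, no real loads or GEMM):

    #include <cstdio>

    // stand-ins for the two LDS halves and the pipeline stages
    static void load_tile(int buf, int tile) { std::printf("load tile %d into half %d\n", tile, buf); }
    static void gemm_on(int buf) { std::printf("gemm on half %d\n", buf); }

    int main()
    {
        const int num_tiles = 6; // stand-in for E / EPerBlock

        load_tile(0, 0);                   // preload the first tile
        for(int t = 0; t + 1 < num_tiles; ++t)
        {
            load_tile((t + 1) % 2, t + 1); // fetch the next tile into the idle half
            gemm_on(t % 2);                // compute on the half filled last iteration
            // the real kernels separate these stages with __syncthreads() and stage
            // the load through registers (RunLoadRegisterBuffer / RunStoreRegisterBuffer)
        }
        gemm_on((num_tiles - 1) % 2);      // tail: the last tile has no successor
        return 0;
    }
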
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t E = C * Y * X; - - constexpr index_t B = N0 * Ho0 * Wo0; - - static_assert(N == N0 * N1 * N2 && Ho == Ho0 * Ho1 * Ho2 && Wo == Wo0 * Wo1 * Wo2, - "wrong!"); - - static_assert((X == 1 || ConvDilationW % InBlockCopyDataPerAccess_W2 == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N0, N1, N2, Ho0, Ho1, Ho2, Wo0, Wo1, Wo2] - constexpr auto in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc = - in_n_c_h_w_global_desc.Extract(I0, I2, I3) - .StridedSlice(I1, Number{}, Number{}) - .StridedSlice(I2, Number{}, Number{}) - .Fold(I2, Number{}, Number{}) - .Fold(I1, Number{}, Number{}) - .Fold(I0, Number{}, Number{}); - - constexpr auto in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc = - in_n0_n1_n2_ho0_ho1_ho2_wo0_wo1_wo2_global_desc.ReorderGivenNew2Old( - Sequence<1, 4, 7, 0, 3, 6, 2, 5, 8>{}); - - // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy - constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc = - make_ConstantMergedTensorDescriptor( - in_c_y_x_global_desc.Embed(in_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_global_desc), - Sequence<0, 1, 2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}, - Sequence<6, 7, 8>{}, - Sequence<9>{}, - Sequence<10>{}, - Sequence<11>{}); - - // memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated< - BlockSize, - Float, - decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc), - decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc), - decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetLengths()), - InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2, - InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2, - InBlockCopyThreadClusterArrangeOrder, - 
InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopyDataPerAccess_W2, - InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0}, - {0, 0, 0, 0, 0, 0, 0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_e_k_global_desc = - wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // operator for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v1_deprecated( - {0, k_block_data_on_global}, {0, 0}); - -#if 0 - if(get_block_1d_id() == 0) - { - printf("id (%d %d), in offset: %d %d, wei offset %d %d\n", - get_block_1d_id(), - get_thread_local_1d_id(), - blockwise_in_copy.mThreadSrcOffset, - blockwise_in_copy.mThreadDstOffset, - blockwise_wei_copy.mThreadSrcOffset, - blockwise_wei_copy.mThreadDstOffset); - } -#endif - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetStrides()[3] % GemmDataPerReadB == - 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - constexpr auto b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc = - make_ConstantMatrixDescriptor(in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.Unfold(I1, I7)); - - // sanity check - static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == - 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc = - make_ConstantMatrixDescriptor_packed(Number{}, - Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n1ho1wo1bn2ho2wo2_block_mtx_desc), - decltype(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_W2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = math::integer_least_multiple( - in_e_n1_ho1_wo1_b_n2_ho2_wo2_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float 
p_out_thread[c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k2_n1ho1wo1n2ho2wo2_thread_mtx_desc, p_out_thread); - - const Float* p_wei_block_on_global = p_wei_global; - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, - p_wei_register_buffer); - -#if 0 - if(get_block_1d_id() == 0) - { - printf("tid (%d %d), %f %f %f %f\n", - get_block_1d_id(), - get_thread_local_1d_id(), - p_wei_register_buffer[0], - p_wei_register_buffer[1], - p_wei_register_buffer[2], - p_wei_register_buffer[3]); - } -#endif - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; - Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; - - // even iteration - blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); - p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.RunLoadRegisterBuffer(p_in_global, p_in_register_buffer); - blockwise_wei_copy.RunLoadRegisterBuffer(p_wei_block_on_global, p_wei_register_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreRegisterBuffer(p_in_register_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreRegisterBuffer(p_wei_register_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - - // copy output: register to global memory - { - constexpr index_t K2 = GemmMPerThreadSubC; - constexpr index_t K1 = GemmMLevel0Cluster * 
GemmMLevel1Cluster; - - // define tensor descriptor for threadwise copy - // output memory layout descriptor in register - constexpr auto out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc = - make_ConstantTensorDescriptor_packed( - Sequence{}); - - // output tensor descriptor in register, src of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc = - out_k0_k1_k2_n1_ho1_wo1_n0_ho0_wo0_n2_ho2_wo2_thread_mem_desc.ReorderGivenNew2Old( - Sequence<6, 3, 9, 0, 1, 2, 7, 4, 10, 8, 5, 11>{}); - - // output memory layout descriptor in device memory, dst of threadwise copy - constexpr auto out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc = - out_n_k_h_w_global_desc.Fold(I3, Sequence{}) - .Fold(I2, Sequence{}) - .Fold(I1, Sequence{}) - .Fold(I0, Sequence{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / (N2 * Ho2 * Wo2); - - // output merged global tensor descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc = - make_ConstantMergedTensorDescriptor( - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc.Unfold(I3, I5), - Sequence<3>{}, - Sequence<1>{}, - Sequence<5>{}, - Sequence<8>{}, - Sequence<0, 4, 7>{}, - Sequence<2>{}, - Sequence<6>{}, - Sequence<9>{}); - - // origin of dst in device memory - Float* p_out_thread_on_global = - p_out_global + - out_k_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc.GetOffsetFromMultiIndex( - k_thread_data_on_global, 0, 0, 0, b_thread_data_on_global, 0, 0, 0); - - threadwise_generic_tensor_slice_copy_v1( - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc, - p_out_thread, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_global_mem_desc, - p_out_thread_on_global, - {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - out_n0_n1_n2_k0_k1_k2_ho0_ho1_ho2_wo0_wo1_wo2_thread_desc.GetLengths(), - arithmetic_sequence_gen<0, 12, 1>::type{}, - Number<1>{}); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index ced7d1ea1c..ae68f4486e 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -75,6 +75,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw constexpr index_t ConvDilationH = ConvDilations{}[0]; constexpr index_t ConvDilationW = ConvDilations{}[1]; +#if 0 // sanity-check for vectorized memory load static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) && (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) && @@ -82,9 +83,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0, "wrong! 
alignment requirement for vectorized global load of input tensor will " "be violated"); +#endif // weight tensor - constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower( + constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower( unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); // input tensor @@ -108,14 +110,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - constexpr auto in_e_b_global_desc = transform_tensor_descriptor( + constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( in_n_c_y_ho_x_wo_global_desc, make_tuple(Merge>{}, Merge>{}), make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - constexpr auto out_k_b_global_desc = + constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), make_tuple(PassThrough{}, Merge>{}), make_tuple(Sequence<1>{}, Sequence<0, 2>{}), @@ -127,9 +129,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw BlockSize, Float, AccFloat, - decltype(wei_e_k_global_desc), - decltype(in_e_b_global_desc), - decltype(out_k_b_global_desc), + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmm_gemmn_global_desc), + decltype(out_gemmk_gemmn_global_desc), InMemoryDataOperation::Set, GemmMPerBlock, GemmNPerBlock, @@ -157,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw 1, GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<0, 1, 2, 3>, + Sequence<2, 3, 0, 1>, 3, GemmCThreadCopyDstDataPerWrite_GemmN1>{}; diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp deleted file mode 100644 index c6e36d5973..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp +++ /dev/null @@ -1,404 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATED_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy_deprecated.hpp" -#include "blockwise_gemm.hpp" -#include "threadwise_generic_tensor_slice_copy_deprecated.hpp" - -namespace ck { - -// B = merge(N, Ho, Wo) -template -struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I5 = Number<5>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto 
out_n_k_h_w_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1]; - - constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t E = C * Y * X; - constexpr index_t B = N * Ho * Wo; - - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) && - (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - const auto block_work_multi_id = - block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock; - - // input tensor - // tensor descriptor in device memory [N, Ho, Wo] - constexpr auto in_n_ho_wo_global_desc = - in_n_c_h_w_global_desc.Extract(I0, I2, I3) - .StridedSlice(I1, Number{}, Number{}) - .StridedSlice(I2, Number{}, Number{}); - - // batch descritpor for device memory - constexpr auto in_c_y_x_global_desc = - in_n_c_h_w_global_desc.StridedSlice(I2, Number{}, Number{}) - .StridedSlice(I3, Number{}, Number{}) - .Extract(Sequence<1, 2, 3>{}); - - // merged tensor descriptor in device memory [E, B], src of blockwise copy - constexpr auto in_e_b_global_desc = - make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc), - Sequence<0, 1, 2>{}, - Sequence<3, 4, 5>{}); - - // memory layout descriptor in LDS [E, B], dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_b_block_desc = - make_ConstantTensorDescriptor_packed(Sequence{}); - - // input blockwise copy - // slice a merged tensor, reorder and copy to a normal tensor - // this copy operator already has blockwise offset built-in - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v2_deprecated( - {0, b_block_data_on_global}, {0, 0}); - - // weight tensor - // tensor descriptor in device memory, src of blockwise copy - constexpr auto wei_e_k_global_desc = - wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{}); - - // tensor descriptor in LDS, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0, - "GemmDataPerReadA alignment requirement is not satisfied"); - - // operator 
for blockwise copy of weight into LDS - // slice a tensor, and copy it into another tensor - // this copy operator already have blockwise offset built-in - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v2_deprecated( - {0, k_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlocl, BPerBlock] is in LDS - // c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in - // register - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - - constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc); - - // sanity check - static_assert( - KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 && - BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster); - - constexpr index_t GemmNRepeat = - BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_b_block_mtx_desc), - decltype(c_k0k1_b0b1_thread_mtx_desc), - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread); - - const Float* p_wei_block_on_global = p_wei_global; - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.template Run(p_in_global, - p_in_block_double); - blockwise_wei_copy.template Run(p_wei_global, - p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.template RunLoadThreadBuffer( - p_in_global, p_in_thread_buffer); - blockwise_wei_copy.template RunLoadThreadBuffer( - p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - // even iteration - blockwise_in_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_in_copy.template RunLoadThreadBuffer( - p_in_global, p_in_thread_buffer); - blockwise_wei_copy.template RunLoadThreadBuffer( - p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); - - // odd iteration - __syncthreads(); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - - // copy output: register to global memory - { - constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster; - - // define tensor descriptor for threadwise copy - // output global descriptor, for calculating origin of thread tensor - // in global memory - constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor( - out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{}); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col; - - // This is a hack, because slicing a merged dimension is not supported yet. 
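The register-to-global output copy in these kernels splits the flat GEMM-M index with / K1 and % K1, as in the destination origin {k / K1, k % K1, b} used just below; that split is the exact mixed-radix fold k = k0 * K1 + k1. A tiny self-contained check, with made-up cluster sizes:

    #include <cassert>

    int main()
    {
        // made-up stand-in for GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster
        const int K1 = 4 * 4 * 4;
        const int K  = 256;

        for(int k = 0; k < K; ++k)
        {
            const int k0 = k / K1;     // slow-changing fold index
            const int k1 = k % K1;     // fast-changing fold index
            assert(k0 * K1 + k1 == k); // the (K0, K1) fold loses nothing
        }
        return 0;
    }
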
- // This should be replaced with logic above, once slicing a merged dimension support - // become available - // dst descriptor - constexpr auto out_k0_k1_b_global_desc = - make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number{}), - Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3, 4>{}); - - // src descriptor - constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed( - Sequence{}); - - using OutThreadCopySliceLengths = - Sequence; - - auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated< - decltype(out_k0_k1_b_thread_desc), - decltype(out_k0_k1_b_global_desc), - OutThreadCopySliceLengths, - arithmetic_sequence_gen<0, 3, 1>::type, - arithmetic_sequence_gen<0, 3, 1>::type, - 2, - 2, - OutThreadCopyDataPerAccess_B, - OutThreadCopyDataPerAccess_B>({0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - b_thread_data_on_global}); - - for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat) - { - threadwise_out_copy - .template Run(p_out_thread, - p_out_global); - - threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True); - threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True); - } - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp index 0ebd9dc4a1..48f1b718f1 100644 --- a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp +++ b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp @@ -2,7 +2,6 @@ #define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP #include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" #include "tensor_descriptor.hpp" namespace ck { @@ -58,18 +57,6 @@ __host__ __device__ constexpr auto return ConstantMatrixDescriptor{}; } -template -__host__ __device__ constexpr auto - make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated) -{ - using TDesc = ConstantTensorDescriptor_deprecated; - static_assert(TDesc::GetNumOfDimension() == 2, "wrong"); - static_assert(TDesc::GetStrides()[1] == 1, "wrong"); - return ConstantMatrixDescriptor{}; -} - template __host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor) { diff --git a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor_deprecated.hpp b/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor_deprecated.hpp deleted file mode 100644 index 814e47d1c1..0000000000 --- a/composable_kernel/include/tensor_description/ConstantMergedTensorDescriptor_deprecated.hpp +++ /dev/null @@ -1,210 +0,0 @@ -#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP -#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" - -namespace ck { - -// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...> -// it's the tensor whose dimensions are to be merged -// OriginalDimMergeSeqs : Sequence<...>... 
-// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged -template -struct ConstantMergedTensorDescriptor_deprecated -{ - using Type = ConstantMergedTensorDescriptor_deprecated; - - static constexpr auto mOriginalDimMergeSeqs = std::tuple{}; - - static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs); - static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension(); - - __host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated() - { - static_assert(nDim <= nOriginalDim, "wrong!"); - - // TODO: check each of OriginalDimMergeSeqs contains at least 1, and at most - // OriginalTensorDesc::nDim number of dimensions - - // TODO: check OriginalDimMergeSeqs contains all original dimensions - - // TODO: check there is no duplication in OriginalDimMergeSeqs - } - - __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() - { - return OriginalTensorDesc{}; - } - - __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; } - - template - __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number) - { - return std::get(mOriginalDimMergeSeqs); - } - - template - __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(Number) - { - return (std::get(mOriginalDimMergeSeqs).GetSize() > 1); - } - - template - __host__ __device__ static constexpr auto GetLength(Number) - { - constexpr auto original_dims_partial = std::get(mOriginalDimMergeSeqs); - - return OriginalTensorDesc::Extract(original_dims_partial).GetElementSize(); - } - - template - __host__ __device__ static constexpr auto GetStride(Number) - { - static_assert(!ContainMultipleOriginalDimensions(Number{}), - "wrong! stride of a merged dimension is undefined"); - - constexpr auto idim_original = std::get(mOriginalDimMergeSeqs).Back(); - - return OriginalTensorDesc::GetStride(Number{}); - } - - // this is a hack to return the stride of the last original dimension of a merged dimension - // TODO: refactor this once the concept of "dimension" is used - template - __host__ __device__ static constexpr auto GetLastOriginalDimensionStride(Number) - { - constexpr auto idim_last_original = std::get(mOriginalDimMergeSeqs).Back(); - - return OriginalTensorDesc::GetStride(Number{}); - } - - __host__ __device__ static constexpr auto GetLengths() - { - return Sequence{}; - } - - __host__ __device__ static constexpr auto GetElementSize() - { - return OriginalTensorDesc::GetElementSize(); - } - - template - struct lambda_1_GetOriginalMultiIndexFromMultiIndex - { - const Array& original_multi_id_partial; - Array& original_multi_id; - - __host__ __device__ constexpr lambda_1_GetOriginalMultiIndexFromMultiIndex( - const Array& original_multi_id_partial_, - Array& original_multi_id_) - : original_multi_id_partial(original_multi_id_partial_), - original_multi_id(original_multi_id_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - constexpr index_t idim_original = OriginalDimsPartial::Get(Number{}); - - index_t itmp = original_multi_id_partial[I]; - - original_multi_id(idim_original) = itmp; - } - }; - - struct lambda_0_GetOriginalMultiIndexFromMultiIndex - { - const Array& multi_id; - Array& original_multi_id; - - __host__ __device__ constexpr lambda_0_GetOriginalMultiIndexFromMultiIndex( - const Array& multi_id_, Array& original_multi_id_) - : multi_id(multi_id_), original_multi_id(original_multi_id_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - 
constexpr auto original_dims_partial = std::get(Type::mOriginalDimMergeSeqs); - - // get partial original-multi-id corresponding to this merged dimension - const auto original_multi_id_partial = - OriginalTensorDesc::Extract(original_dims_partial) - .GetMultiIndexFrom1dIndex(multi_id[IDim]); - - static_for<0, original_dims_partial.GetSize(), 1>{}( - lambda_1_GetOriginalMultiIndexFromMultiIndex( - original_multi_id_partial, original_multi_id)); - } - }; - - // return type is Array<...> - __host__ __device__ static constexpr auto - GetOriginalMultiIndexFromMultiIndex(Array multi_id) - { - Array original_multi_id; - - static_for<0, nDim, 1>{}( - lambda_0_GetOriginalMultiIndexFromMultiIndex(multi_id, original_multi_id)); - - return original_multi_id; - } - - template - __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Sequence) - { - constexpr auto multi_id = sequence2array(Sequence{}); - - constexpr auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id); - - return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id); - } - - __host__ __device__ static constexpr index_t - GetOffsetFromMultiIndex(Array multi_id) - { - auto original_multi_id = GetOriginalMultiIndexFromMultiIndex(multi_id); - - return OriginalTensorDesc::GetOffsetFromMultiIndex(original_multi_id); - } - - template - __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is) - { - return GetOffsetFromMultiIndex(Array{is...}); - } - - __host__ __device__ static constexpr Array GetMultiIndexFrom1dIndex(index_t id) - { - constexpr auto packed_desc = make_ConstantTensorDescriptor_packed(GetLengths()); - - return packed_desc.GetMultiIndexFrom1dIndex(id); - } - - __host__ __device__ static constexpr auto Pack() - { - constexpr auto lengths = GetLengths(); - constexpr auto strides = calculate_tensor_strides_packed(lengths); - return ConstantTensorDescriptor_deprecated{}; - } -}; - -template -__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc, - OriginalDimMergeSeqs...) 
-{ - return ConstantMergedTensorDescriptor_deprecated{}; -} - -template -__host__ __device__ void print_ConstantMergedTensorDescriptor(const char* s, TDesc) -{ - print_ConstantTensorDescriptor(s, TDesc::GetOriginalTensorDescriptor()); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/ConstantTensorDescriptor_deprecated.hpp b/composable_kernel/include/tensor_description/ConstantTensorDescriptor_deprecated.hpp deleted file mode 100644 index d745f69f80..0000000000 --- a/composable_kernel/include/tensor_description/ConstantTensorDescriptor_deprecated.hpp +++ /dev/null @@ -1,612 +0,0 @@ -#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP -#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP - -#include "common_header.hpp" - -namespace ck { - -template -__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths) -{ - return reverse_inclusive_scan_sequence( - Lengths{}.PopFront(), math::multiplies{}, Number<1>{}) - .PushBack(Number<1>{}); -} - -template -__host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths, Number) -{ - constexpr index_t L_back_align = - Align * math::integer_divide_ceiler{}(Lengths{}.Back(), Align); - - return calculate_tensor_strides_packed_deprecated( - Lengths{}.Modify(Number{}, Number{})); -} - -template -struct ConstantTensorDescriptor_deprecated -{ - using Type = ConstantTensorDescriptor_deprecated; - - static constexpr index_t nDim = Lengths::GetSize(); - - __host__ __device__ constexpr ConstantTensorDescriptor_deprecated() - { - static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent"); - } - - __host__ __device__ static constexpr auto GetOriginalTensorDescriptor() { return Type{}; } - - template - __host__ __device__ static constexpr auto GetContainedOriginalDimensions(Number) - { - return Sequence{}; - } - - __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetLengths() { return Lengths{}; } - - __host__ __device__ static constexpr auto GetStrides() { return Strides{}; } - - __host__ __device__ static constexpr auto GetLength(index_t IDim) { return Lengths{}[IDim]; } - - __host__ __device__ static constexpr auto GetStride(index_t IDim) { return Strides{}[IDim]; } - - struct lambda_AreDimensionsContinuous - { - bool& is_continuous; - - __host__ __device__ constexpr lambda_AreDimensionsContinuous(bool& is_continuous_) - : is_continuous(is_continuous_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - constexpr auto IDim = Number{}; - constexpr auto IDim_p1 = Number{}; - - is_continuous = - is_continuous && (GetStride(IDim) >= GetStride(IDim_p1) && - GetStride(IDim) == GetStride(IDim_p1) * GetLength(IDim_p1)); - } - }; - - __host__ __device__ static constexpr bool AreDimensionsContinuous() - { - bool is_continuous = true; - - static_for<0, nDim - 1, 1>{}(lambda_AreDimensionsContinuous(is_continuous)); - - return is_continuous; - } - - __host__ __device__ static constexpr bool IsPackedTensor() - { - return AreDimensionsContinuous() && GetStride(Number{}) == 1; - } - - template - __host__ __device__ static constexpr bool ContainMultipleOriginalDimensions(T) - { - return false; - } - - __host__ __device__ static constexpr auto GetElementSize() - { - return Number{}, Number<1>{})>{}; - } - - __host__ __device__ static constexpr auto GetElementSpace() - { - constexpr index_t element_space_unaligned = reduce_on_sequence( - (GetLengths() - 
Number<1>{}) * GetStrides(), math::plus{}, Number<1>{}); - - return Number{}; - } - - // emulate constexpr lambda - template - struct lambda_GetOffsetFromMultiIndex - { - Array& multi_id; - index_t& offset; - - __host__ - __device__ constexpr lambda_GetOffsetFromMultiIndex(Array& multi_id_, - index_t& offset_) - : multi_id(multi_id_), offset(offset_) - { - } - - template - __host__ __device__ constexpr void operator()(X IDim) const - { - offset += multi_id[IDim] * Type::GetStride(IDim); - } - }; - - template - __host__ __device__ static constexpr index_t - GetOffsetFromMultiIndex(Array multi_id) - { - static_assert(NSize == nDim, "wrong! Dimension not consistent"); - - index_t offset = 0; - - static_for<0, nDim, 1>{}(lambda_GetOffsetFromMultiIndex(multi_id, offset)); - - return offset; - } - - template - __host__ __device__ static constexpr index_t GetOffsetFromMultiIndex(Is... is) - { - return GetOffsetFromMultiIndex(Array{is...}); - } - - template - __host__ __device__ static constexpr auto GetOffsetFromMultiIndex(Sequence) - { - static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent"); - - constexpr auto multi_id = Sequence{}; - - return Number{}, Number<0>{})>{}; - } - - // emulate constexpr lambda - template - struct lambda_GetMultiIndexFrom1dIndex - { - index_t& id; - Array& multi_id; - - __host__ - __device__ constexpr lambda_GetMultiIndexFrom1dIndex(index_t& id_, - Array& multi_id_) - : id(id_), multi_id(multi_id_) - { - } - - template - __host__ __device__ constexpr void operator()(IDim_) const - { - constexpr auto IDim = IDim_{}; - constexpr index_t stride = PackedStrides::Get(IDim); - multi_id(IDim) = id / stride; - id -= multi_id[IDim] * stride; - } - }; - - __host__ __device__ static constexpr Array GetMultiIndexFrom1dIndex(index_t id) - { - Array multi_id; - - using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths())); - - // calculate index in each of the dimensions in the order of their dimension - static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex(id, multi_id)); - - multi_id(Number{}) = id / PackedStrides::Get(Number{}); - - return multi_id; - } - - __host__ __device__ static constexpr auto - GetOriginalMultiIndexFromMultiIndex(Array multi_id) - { - return multi_id; - } - - // This function doesn't do carry check on the highest dimension for positive stepping (or - // borrow check on the highest dimension for negative stepping) , for performance reason. 
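The two index mappings implemented above reduce to a few lines of ordinary C++: the offset of a multi-index is the stride-weighted sum of its components, and for a packed layout the inverse peels dimensions off with div/mod by the packed strides. A standalone sketch for a hypothetical fixed rank-3 case (not the library's API):

#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 3> lengths{4, 3, 5};
    const std::array<int, 3> strides{15, 5, 1}; // packed: 3*5, 5, 1

    // offset = sum_d idx[d] * stride[d]
    auto offset_of = [&](std::array<int, 3> idx) {
        int off = 0;
        for(int d = 0; d < 3; ++d)
            off += idx[d] * strides[d];
        return off;
    };

    // inverse, valid for a packed layout: divide out one dimension at a time
    auto index_of = [&](int off) {
        std::array<int, 3> idx{};
        for(int d = 0; d < 3; ++d)
        {
            idx[d] = off / strides[d];
            off -= idx[d] * strides[d];
        }
        return idx;
    };

    const int off  = offset_of({2, 1, 3}); // 2*15 + 1*5 + 3 = 38
    const auto idx = index_of(off);
    std::printf("offset=%d -> {%d,%d,%d}\n", off, idx[0], idx[1], idx[2]);
}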
It is - // the user's responsibility to make sure the result "new_mutli_id" is not out-of-bound on the - // highest dimension for positive stepping (or on the lowest dimension for negative stepping) - template - __host__ __device__ static Array - UpdateMultiIndexGivenStepSizeOf1dIndex(Array old_multi_id, - index_t step_size_of_1d_index, - integral_constant) - { - Array new_multi_id; - - const auto step_sizes = GetMultiIndexFrom1dIndex(step_size_of_1d_index); - - static_if{}([&](auto) { - new_multi_id = old_multi_id + step_sizes; - - bool carry = false; - - // do carry check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, nDim, 1>{}([&](auto IDimReverse) { - constexpr index_t idim = nDim - 1 - IDimReverse; - constexpr auto IDim = Number{}; - - if(carry) - { - ++new_multi_id(idim); - } - - carry = false; - - if(new_multi_id[idim] >= GetLength(IDim)) - { - new_multi_id(idim) -= GetLength(IDim); - carry = true; - } - }); - }).Else([&](auto) { - // shift up multi-id to avoid unsigned integer underflow during intermediate - // calculations. After the shift, should have new_multi_id[...] >= 1 - new_multi_id = old_multi_id + (GetLengths() - step_sizes); - - bool borrow = false; - - // do borrow check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, nDim, 1>{}([&](auto IDimReverse) { - constexpr index_t idim = nDim - 1 - IDimReverse; - constexpr auto IDim = Number{}; - - if(borrow) - { - --new_multi_id(idim); - } - - borrow = false; - - if(new_multi_id[idim] < GetLength(IDim)) - { - new_multi_id(idim) += GetLength(IDim); - borrow = true; - } - }); - - // shift back down multi-id - // here, should have new_multi_id[...] >= GetLengths() - new_multi_id = new_multi_id - GetLengths(); - }); - - return new_multi_id; - } - - template - __host__ __device__ static constexpr auto Extract(Number... extract_dims) - { - static_assert(sizeof...(IDims) <= GetNumOfDimension(), - "wrong! 
too many number of dimensions to be extracted"); - - using extract_lengths = decltype(Lengths::Extract(extract_dims...)); - using extract_strides = decltype(Strides::Extract(extract_dims...)); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto Extract(Sequence) - { - return Extract(Number{}...); - } - - template - __host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated) - { - using leaf_tensor = ConstantTensorDescriptor_deprecated; - - return ConstantTensorDescriptor_deprecated< - decltype(GetLengths().PushBack(leaf_tensor::GetLengths())), - decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{}; - } - - template - struct lambda_IsVectorizationAllowed - { - bool& is_allowed; - - __host__ __device__ constexpr lambda_IsVectorizationAllowed(bool& is_allowed_) - : is_allowed(is_allowed_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - constexpr auto IDim = Number{}; - - if(IDimVector != IDim && Strides::Get(IDim) % DataPerVector != 0) - { - is_allowed = false; - } - } - }; - - template - __host__ __device__ static constexpr bool IsVectorizationAllowed(Number, - Number) - { - bool is_allowed = (Strides{}[IDimVector] == 1 || DataPerVector == 1) && - Lengths{}[IDimVector] % DataPerVector == 0; - - static_for<0, nDim, 1>{}( - lambda_IsVectorizationAllowed{is_allowed}); - - return is_allowed; - } - - template - __host__ __device__ static constexpr auto Vectorize(Number, Number) - { - constexpr auto idim = Number{}; - constexpr auto data_per_vector = Number{}; - - static_assert(IsVectorizationAllowed(idim, data_per_vector), "wrong!"); - - using vectorized_lengths = - decltype(Lengths::Modify(Number{}, Number{})); - using vectorized_strides = - decltype((Strides{} / Number{}).Modify(Number{}, Number<1>{})); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto Slice(Number, Number) - { - using slice_lengths = decltype(Lengths::Modify(Number{}, Number{})); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto Slice(Sequence slice_lengths) - { - static_assert(slice_lengths.GetSize() == nDim, "wrong!"); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto - StridedSlice(Number, Number, Number) - { - constexpr index_t new_stride = Strides::Get(Number{}) * SliceStride; - - using new_lengths = decltype(Lengths::Modify(Number{}, Number{})); - using new_strides = decltype(Strides::Modify(Number{}, Number{})); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto Fold(Number, Number...) - { - constexpr auto fold_intervals = Sequence{}; - - constexpr index_t fold_intervals_product = - reduce_on_sequence(fold_intervals, math::multiplies{}, Number<1>{}); - - constexpr auto unfold_length = GetLength(Number{}); - constexpr auto unfold_stride = GetStride(Number{}); - - // length of the dimension to be folded needs to be dividable by fold_interval_product, - // otherwise, folding is invalid - static_assert(unfold_length % fold_intervals_product == 0, - "wrong! 
length on the dimension to be folded cannot be evenly divided!"); - - // folded lengths - constexpr auto fold_lengths = - Sequence{}.PushBack(fold_intervals); - - // folded strides - constexpr auto fold_strides = - Number{} * - reverse_inclusive_scan_sequence( - fold_intervals.PushBack(Number<1>{}), math::multiplies{}, Number<1>{}); - - // left and right - constexpr auto left = typename arithmetic_sequence_gen<0, IDim, 1>::type{}; - constexpr auto right = - typename arithmetic_sequence_gen::type{}; - - constexpr auto new_lengths = - GetLengths().Extract(left).PushBack(fold_lengths).PushBack(GetLengths().Extract(right)); - constexpr auto new_strides = - GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right)); - - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto Fold(Number, Sequence) - { - return Fold(Number{}, Number{}...); - } - - // this function unfold dimension [FirstUnfoldDim, ..., LastUnfoldDim] into 1 dimension - template - __host__ __device__ static constexpr auto Unfold(Number, Number) - { - static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && - FirstUnfoldDim <= LastUnfoldDim, - "wrong! should have FirstUnfoldDim <= LastUnfoldDim!"); - - // left and right - constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{}; - constexpr auto middle = - typename arithmetic_sequence_gen::type{}; - constexpr auto right = - typename arithmetic_sequence_gen::type{}; - - // dimensions to be unfolded need to be continuous - static_assert(Type::Extract(middle).AreDimensionsContinuous(), "wrong! not unfoldable"); - - // unfolded length, stride - constexpr index_t unfold_length = reduce_on_sequence( - GetLengths().Extract(middle), math::multiplies{}, Number<1>{}); - - constexpr index_t unfold_stride = GetStride(Number{}); - - // new lengths, strides - constexpr auto new_lengths = GetLengths() - .Extract(left) - .PushBack(Number{}) - .PushBack(GetLengths().Extract(right)); - - constexpr auto new_strides = GetStrides() - .Extract(left) - .PushBack(Number{}) - .PushBack(GetStrides().Extract(right)); - - return ConstantTensorDescriptor_deprecated{}; - } - - __host__ __device__ static constexpr auto Pack() - { - using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{})); - return ConstantTensorDescriptor_deprecated{}; - } - - template - __host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old) - { - return ConstantTensorDescriptor_deprecated< - decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})), - decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{}; - } - - template - __host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New) - { - return ConstantTensorDescriptor_deprecated< - decltype(Lengths::ReorderGivenOld2New(MapOld2New{})), - decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{}; - } -}; - -template -__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths) -{ - using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{})); - return ConstantTensorDescriptor_deprecated{}; -} - -template -__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides) -{ - return ConstantTensorDescriptor_deprecated{}; -} - -template -__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number) -{ - using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number{})); - return ConstantTensorDescriptor_deprecated{}; -} - 
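The Fold arithmetic above can be worked through concretely: folding one dimension of length L and stride s by intervals (f1, ..., fk) produces sub-dimensions with lengths (L / (f1*...*fk), f1, ..., fk), and strides given by a reverse inclusive product scan of the intervals scaled by s. A small sketch under those assumptions (toy numbers, not the library's compile-time Sequence machinery):

#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    const int L = 12, s = 7;               // dimension to fold
    const std::vector<int> intervals{3, 2};

    int prod = 1;
    for(int f : intervals) prod *= f;       // 6; L must be divisible by it

    std::vector<int> lengths{L / prod};     // {2}
    lengths.insert(lengths.end(), intervals.begin(), intervals.end()); // {2,3,2}

    // strides: each sub-dimension's stride is s times the product of the
    // lengths of all faster-varying sub-dimensions -> {6s, 2s, s}
    std::vector<int> strides(lengths.size());
    int acc = 1;
    for(int d = static_cast<int>(lengths.size()) - 1; d >= 0; --d)
    {
        strides[d] = s * acc;
        acc *= lengths[d];
    }

    for(std::size_t d = 0; d < lengths.size(); ++d)
        std::printf("dim %zu: length %d stride %d\n", d, lengths[d], strides[d]);
}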
-template -__host__ __device__ void print_ConstantTensorDescriptor( - const char* s, ConstantTensorDescriptor_deprecated, Sequence>) -{ - constexpr index_t ndim = sizeof...(Lengths); - - static_assert(ndim > 0 && ndim <= 12, "wrong!"); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u}, strides {%u}\n", s, ndim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, ndim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf( - "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, ndim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u " - "%u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u " - "%u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u " - "%u %u " - "%u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u " - "%u %u %u %u " - "%u %u %u}\n", - s, - ndim, - Lengths..., - Strides...); - }); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp b/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp deleted file mode 100644 index 69659445a0..0000000000 --- a/composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp +++ /dev/null @@ -1,348 +0,0 @@ -#ifndef CK_TENSOR_COORDINATE_DEPRECATED_HPP -#define CK_TENSOR_COORDINATE_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" - -namespace ck { - -// TensorDesc is ConstantTensorDescriptor_deprecated -template -struct NormalTensorCoordinate_deprecated -{ - using type = NormalTensorCoordinate_deprecated; - using tensor_desc_type = TensorDesc; - - static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); - - __host__ - __device__ constexpr NormalTensorCoordinate_deprecated(Array tensor_index) - : mOffset{tensor_desc_type::GetOffsetFromMultiIndex(tensor_index)} - { - } - - template - __host__ __device__ constexpr NormalTensorCoordinate_deprecated(Xs... 
xs) - : NormalTensorCoordinate_deprecated(Array{xs...}) - { - } - - template - __host__ __device__ constexpr NormalTensorCoordinate_deprecated(Sequence) - : NormalTensorCoordinate_deprecated(Array{Xs...}) - { - } - - __host__ __device__ constexpr index_t GetOffset() const { return mOffset; } - - // T is Array or Sequence - template - __host__ __device__ type operator+=(T step_sizes) - { - static_assert(is_same{} && T::GetSize() == nDim, "wrong!"); - - mOffset += tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); - - return *this; - } - - template - __host__ __device__ type operator-=(T step_sizes) - { - static_assert(is_same{} && T::GetSize() == nDim, "wrong!"); - - mOffset -= tensor_desc_type::GetOffsetFromMultiIndex(step_sizes); - - return *this; - } - - template - __host__ __device__ constexpr type operator+(T step_sizes) const - { - type coord = *this; - coord += step_sizes; - return coord; - } - - template - __host__ __device__ constexpr type operator-(T step_sizes) const - { - type coord = *this; - coord -= step_sizes; - return coord; - } - - // reposition point of origin, and return compensated offset. - // This is a hack to reduce index calculation during looping over - // a tensor whose origin is this TensorCoordinate. It does so, by spitting - // out the run-time offset to the pointer (to the tensor data) held by this - // TensorCoordiante, so the caller can add the offset into the run-time pointer of - // the data, so only 1 run-time variable (update pointer) is needed, instead - // of 2 run-time variables (old pointer and this offset) - // TODO: after introducing the concept of "run-time tensor view", which contains the - // run-time pointer to the data, always keep track of the pointer, instead of both - // offset and the pointer. This also bring additional benefit that we don't need to - // worry the offset might underflow (because offset is unsigned integer) when updating it. - __host__ __device__ constexpr index_t RepositionOrigin() - { - index_t offset_diff = mOffset; - mOffset = 0; - return offset_diff; - } - - private: - index_t mOffset; -}; - -// TensorDesc is ConstantMergedTensorDescriptor_deprecated -template -struct MergedTensorCoordinate_deprecated -{ - using type = MergedTensorCoordinate_deprecated; - using tensor_desc_type = TensorDesc; - - static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); - static constexpr index_t nOriginalDim = - tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension(); - - __host__ - __device__ constexpr MergedTensorCoordinate_deprecated(Array tensor_index) - : mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)} - { - // partial offset on each dimension - static_for<0, nDim, 1>{}([&](auto idim) { - constexpr auto partial_original_dims = - tensor_desc_type::GetContainedOriginalDimensions(idim); - - constexpr auto partial_original_desc = - tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims); - - mPartialOffsets(idim) = partial_original_desc.GetOffsetFromMultiIndex( - extract_array(mOriginalIndex, partial_original_dims)); - }); - - // complete offset - mOffset = - accumulate_on_array(mPartialOffsets, math::plus{}, static_cast(0)); - } - - template - __host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... 
xs) - : MergedTensorCoordinate_deprecated(Array{xs...}) - { - } - - __host__ __device__ constexpr index_t GetOffset() const { return mOffset; } - - template - __host__ __device__ void - MoveOnDimension(IDim idim_, T step_size, integral_constant) - { - constexpr auto idim = idim_; - - // if step_size is known at compile time - static_if::value>{}( - [&](auto) { static_if{}([&](auto) { return; }); }); - - // update original index - static_if{}([&](auto) { - constexpr auto partial_original_dims = - tensor_desc_type::GetContainedOriginalDimensions(idim); - - constexpr index_t ndim_partial_original = partial_original_dims.GetSize(); - - constexpr auto partial_original_desc = - tensor_desc_type::GetOriginalTensorDescriptor().Extract(partial_original_dims); - - const auto partial_original_step_sizes = - partial_original_desc.GetMultiIndexFrom1dIndex(step_size); - - // update partial original multi-id - auto partial_original_id = extract_array(mOriginalIndex, partial_original_dims); - - static_if{}([&](auto) { - partial_original_id += partial_original_step_sizes; - - bool carry = false; - - // do carry check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) { - constexpr index_t i = ndim_partial_original - 1 - IReverse; - - if(carry) - { - ++partial_original_id(i); - } - - carry = false; - - if(partial_original_id[i] >= partial_original_desc.GetLength(i)) - { - partial_original_id(i) -= partial_original_desc.GetLength(i); - carry = true; - } - }); - - // highest dimension - if(carry) - { - ++partial_original_id(0); - } - }).Else([&](auto) { - // shift up multi-id to avoid unsigned integer underflow during intermediate - // calculations. After the shift, should have new_multi_id[...] >= 1 - partial_original_id += - partial_original_desc.GetLengths() - partial_original_step_sizes; - - bool borrow = false; - - // do borrow check in reversed order, starting from lowest dimension - // don't check the highest dimension - static_for<0, ndim_partial_original - 1, 1>{}([&](auto IReverse) { - constexpr index_t i = ndim_partial_original - 1 - IReverse; - - if(borrow) - { - --partial_original_id(i); - } - - borrow = false; - - if(partial_original_id[i] < partial_original_desc.GetLength(i)) - { - partial_original_id(i) += partial_original_desc.GetLength(i); - borrow = true; - } - }); - - // highest dimension - if(borrow) - { - --partial_original_id(0); - } - - // shift back down multi-id - // here, should have new_multi_id[...] 
>= GetLengths() - partial_original_id = partial_original_id - partial_original_desc.GetLengths(); - }); - - // update "mOriginalIndex" - static_for<0, ndim_partial_original, 1>{}([&](auto I) { - constexpr auto idim_original = partial_original_dims[I]; - - mOriginalIndex(idim_original) = partial_original_id[I]; - }); - - // calculate new partial offset on this merged dimension - const index_t old_partial_offset = mPartialOffsets[idim]; - - mPartialOffsets(idim) = - partial_original_desc.GetOffsetFromMultiIndex(partial_original_id); - - // update "mThreadSrcOffset", do "+" before "-" to avoid underflow - mOffset = (mOffset + mPartialOffsets[idim]) - old_partial_offset; - }).Else([&](auto fwd) { - static_if{}([&](auto) { - mOffset += step_size * fwd(tensor_desc_type{}).GetStride(idim); - }).Else([&](auto) { mOffset -= step_size * fwd(tensor_desc_type{}).GetStride(idim); }); - }); - } - - // T is Array or Sequence - template - __host__ __device__ type operator+=(T step_sizes) - { - static_assert(is_same{} && T::GetSize() == nDim, "wrong!"); - - static_for<0, nDim, 1>{}([&](auto idim) { - // compiler should remove dead code path, because step_sizes is known at - // compile time - if(step_sizes[idim] != 0) - { - this->MoveOnDimension(idim, step_sizes[idim], integral_constant{}); - } - }); - - return *this; - } - - template - __host__ __device__ type operator-=(T step_sizes) - { - static_assert(is_same{} && T::GetSize() == nDim, "wrong!"); - - static_for<0, nDim, 1>{}([&](auto idim) { - // compiler should remove dead code path, because step_sizes is known at - // compile time - if(step_sizes[idim] != 0) - { - this->MoveOnDimension(idim, step_sizes[idim], integral_constant{}); - } - }); - - return *this; - } - - template - __host__ __device__ constexpr type operator+(T step_sizes) const - { - type coord = *this; - coord += step_sizes; - return coord; - } - - template - __host__ __device__ constexpr type operator-(T step_sizes) const - { - type coord = *this; - coord -= step_sizes; - return coord; - } - - __host__ __device__ static constexpr index_t RepositionOrigin() { return 0; } - - private: - // Allocate register memory for all merged dimensions and normal dimensions. - // However, only those merged dimensions, whose index will be involved in arithmetic - // after the construction of this TensorCoordinate (e.g. when user move a slicing - // window on the merged dimension), will use these register memory. - // Let's hope compiler will optimize away those register memory allocated for normal - // dimensions, and those merged dimensions, that would never be involved in index - // arithmetic after construction of TensorCoordinate. 
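The partial-offset bookkeeping above has a simple invariant: the total offset is the sum of per-dimension partial offsets, and moving along one merged dimension replaces only that dimension's contribution. Because offsets are unsigned, the update adds the new partial offset before subtracting the old one, so no intermediate value dips below zero. A toy illustration of that update order (hypothetical values, not the class itself):

#include <cstdint>
#include <cstdio>

int main()
{
    using index_t = std::uint32_t;

    index_t partial[2] = {40, 2};      // per-dimension partial offsets
    index_t offset     = partial[0] + partial[1]; // 42

    // move on dimension 0: its partial offset shrinks from 40 to 5
    const index_t new_partial = 5;
    const index_t old_partial = partial[0];
    partial[0] = new_partial;

    // "+" before "-": (42 + 5) - 40 stays non-negative throughout,
    // even though new_partial < old_partial
    offset = (offset + new_partial) - old_partial; // 7

    std::printf("offset=%u (expected %u)\n", offset, partial[0] + partial[1]);
}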
- // TODO: refactor TensorCoordinate, after introducing the concept of "dimensions" - // and simplify implementation of ConstantMergedTensorDescriptor_deprecated, so we don't need to - // count on compiler to optimize away those register memory for us - Array mOriginalIndex; - Array mPartialOffsets; - - // complete offset - index_t mOffset; -}; - -template -struct TensorCoordinate_deprecated -{ - private: - template - __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated) - { - return NormalTensorCoordinate_deprecated>(); - } - - template - __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated) - { - return MergedTensorCoordinate_deprecated< - ConstantMergedTensorDescriptor_deprecated>(); - } - - public: - using type = decltype(MakeDummyTensorCoordinate(TensorDesc{})); -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp b/composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp deleted file mode 100644 index 2cacb329cb..0000000000 --- a/composable_kernel/include/tensor_description/tensor_coordinate_helper.hpp +++ /dev/null @@ -1,16 +0,0 @@ -#ifndef CK_TENSOR_COORDINATE_HELPER_HPP -#define CK_TENSOR_COORDINATE_HELPER_HPP - -#include "tensor_coordiante_hpp" - -namespace ck { - -template -__host__ __device__ constexpr auto -make_tensor_coordinate(TensorDesc, MultiIndex idx) -{ - return typename TensorCoordinate::type(idx); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp index 6106581896..2e21f7141b 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp @@ -18,11 +18,11 @@ template struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index 97e6acc4ee..39c1fb86fa 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -15,6 +15,8 @@ namespace ck { // The dimension access order can be different for src and dst. // Will do valid mapping check on src data: Read 0 if src data has a invalid mapping // Will do valid mapping check on dst data: No write if dst data has a invalid mapping +// BlockSize can be equal or larger than ThreadCluster size, which means some threads may not do +// threadwise copy template + InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set, + index_t SrcDataStride = 1, + index_t DstDataStride = 1> struct BlockwiseGenericTensorSliceCopy_v4 { static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension(); @@ -52,23 +56,23 @@ struct BlockwiseGenericTensorSliceCopy_v4 is_same{}, "wrong! threads should be mapped to cover entire slicing window"); - // map threads to cluster - constexpr auto thread_cluster_desc = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(), + "wrong! BlockSize too small"); - static_assert(BlockSize == thread_cluster_desc.GetElementSize(), - "wrong! 
BlockSize not consistent with ThreadClusterLengths"); + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) + { + const auto thread_cluster_id = + mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id()); - const auto thread_cluster_id = - thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id()); + const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; - const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; + mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); + mThreadwiseLoad.SetDstSliceOrigin(make_zero_array()); - mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); - mThreadwiseLoad.SetDstSliceOrigin(make_zero_array()); - - mThreadwiseStore.SetSrcSliceOrigin(make_zero_array()); - mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); + mThreadwiseStore.SetSrcSliceOrigin(make_zero_array()); + mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); + } } __device__ static constexpr index_t GetThreadBufferSize() @@ -83,14 +87,18 @@ struct BlockwiseGenericTensorSliceCopy_v4 constexpr bool has_optimized_address_calculation = decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); - // TODO: threadwise copy is still being tweaked - if(has_optimized_address_calculation) + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) { - mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer); - } - else - { - mThreadwiseLoad.Run(p_block_src, p_thread_buffer); + // TODO: threadwise copy is still being tweaked + if(has_optimized_address_calculation) + { + mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer); + } + else + { + mThreadwiseLoad.Run(p_block_src, p_thread_buffer); + } } } @@ -101,14 +109,19 @@ struct BlockwiseGenericTensorSliceCopy_v4 constexpr bool has_optimized_address_calculation = decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation(); - // TODO: threadwise copy is still being tweaked - if(has_optimized_address_calculation) + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) { - mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, p_block_dst); - } - else - { - mThreadwiseStore.Run(p_thread_buffer, p_block_dst); + // TODO: threadwise copy is still being tweaked + if(has_optimized_address_calculation) + { + mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer, + p_block_dst); + } + else + { + mThreadwiseStore.Run(p_thread_buffer, p_block_dst); + } } } @@ -123,10 +136,14 @@ struct BlockwiseGenericTensorSliceCopy_v4 BlockSrcData p_thread_buffer[GetThreadBufferSize()]; - RunLoadThreadBuffer(p_block_src, p_thread_buffer); + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) + { + RunLoadThreadBuffer(p_block_src, p_thread_buffer); - // if there is type conversion, it's done during store - RunStoreThreadBuffer(p_thread_buffer, p_block_dst); + // if there is type conversion, it's done during store + RunStoreThreadBuffer(p_thread_buffer, p_block_dst); + } } template @@ -134,7 +151,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 MoveSrcSliceWindow(const T& step_sizes, integral_constant positive_direction) { - 
mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) + { + mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); + } } template @@ -142,7 +163,11 @@ struct BlockwiseGenericTensorSliceCopy_v4 MoveDstSliceWindow(const T& step_sizes, integral_constant positive_direction) { - mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); + if(BlockSize == mThreadClusterDesc.GetElementSize() or + get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) + { + mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); + } } private: @@ -157,7 +182,9 @@ struct BlockwiseGenericTensorSliceCopy_v4 1, SrcAddressSpace, ThreadBufferAddressSpace, - InMemoryDataOperation::Set>; + InMemoryDataOperation::Set, + SrcDataStride, + 1>; using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2; + DstInMemOp, + 1, + DstDataStride>; + + static constexpr auto mThreadClusterDesc = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); ThreadwiseLoad mThreadwiseLoad; ThreadwiseStore mThreadwiseStore; diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp deleted file mode 100644 index 784b7548c5..0000000000 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy_deprecated.hpp +++ /dev/null @@ -1,613 +0,0 @@ -#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP -#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "tensor_coordinate_deprecated.hpp" -#include "threadwise_generic_tensor_slice_copy_deprecated.hpp" - -namespace ck { - -// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor -// memory layout (ordering of dimensions) can be different between src and dst. -// This functions assume each thread is reading and writing a normal (not merged) tensor, -// to simplify index calculations. To satisfy this assumption, the user need to make sure -// that, on a merged dimension that constains multiple original dimensions, the length of -// the last original dimension need to be evenly dividable by its sub-lengths. Also, the -// repeat-length on the merged dimension need to be 1. These sanity checks are performed -// in constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated -template -struct BlockwiseGenericTensorSliceCopy_v1_deprecated -{ - static constexpr index_t nDim = SrcDesc::GetNumOfDimension(); - - static constexpr index_t nOriginalDimSrc = - SrcDesc::GetOriginalTensorDescriptor().GetNumOfDimension(); - static constexpr index_t nOriginalDimDst = - DstDesc::GetOriginalTensorDescriptor().GetNumOfDimension(); - - // per-thread offset - index_t mThreadSrcOffset; - index_t mThreadDstOffset; - - // "mThreadSrcOriginalMultiId", "mThreadSrcPartialOffsets, "mThreadDstOriginalMultiId", - // "mThreadDstPartialOffsets" are always calculated inside constructor, and would be - // updated if slicing-window is moved. However, they will not be used if you always move - // the slicing-window along a non-merged dimension. In that case, compiler should be - // able to remove these calculation. 
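Both the new v4 copy and the deprecated v1 copy rest on the same thread-to-data mapping: a linear thread id is decomposed into a cluster multi-index over ThreadClusterLengths, and each thread's slice origin is that index scaled by the per-thread SubLengths. The v4 change above additionally lets BlockSize exceed the cluster size, with surplus threads idling. A sketch of that mapping (row-major decomposition, ignoring the ThreadClusterArrangeOrder reordering; toy sizes):

#include <array>
#include <cstdio>

int main()
{
    constexpr int block_size = 8;
    const std::array<int, 2> cluster{2, 3};  // 6 working threads
    const std::array<int, 2> sub_len{4, 2};  // data elements per thread

    const int cluster_size = cluster[0] * cluster[1];

    for(int tid = 0; tid < block_size; ++tid)
    {
        // threads beyond the cluster do no copy work (the v4 guard)
        if(tid >= cluster_size)
        {
            std::printf("thread %d: idle\n", tid);
            continue;
        }
        // decompose linear id, then scale by per-thread slice lengths
        const std::array<int, 2> cid{tid / cluster[1], tid % cluster[1]};
        std::printf("thread %d: data origin {%d,%d}\n",
                    tid, cid[0] * sub_len[0], cid[1] * sub_len[1]);
    }
}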
- // TODO: make sure compiler would actually remove them in that case - - // partial offset in each (merged) dimension - Array mThreadSrcPartialOffsets; - Array mThreadDstPartialOffsets; - - // multi-id of original tensor - Array mThreadSrcOriginalMultiId; - Array mThreadDstOriginalMultiId; - - __device__ - BlockwiseGenericTensorSliceCopy_v1_deprecated(Array src_block_data_id_begin, - Array dst_block_data_id_begin) - { - // check NDim consistency - static_assert( - nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() && - nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() && - nDim == ThreadClusterLengths::GetSize() && - nDim == ThreadClusterArrangeOrder::GetSize() && - nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(), - "wrong"); - - // check thread arrange order and read/write access order are valid - static_assert(is_valid_sequence_map::value && - is_valid_sequence_map::value && - is_valid_sequence_map::value, - "wrong!"); - - // thread cluster - constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed( - ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{})); - - // BlockSize - static_assert(BlockSize == thread_cluster_desc.GetElementSize(), "wrong! BlockSize"); - - // divide work - constexpr auto data_per_cluster_per_dims = SubLengths{} * ThreadClusterLengths{}; - - static_for<0, nDim, 1>{}([&](auto IDim) { - static_assert(SliceLengths::Get(IDim) % data_per_cluster_per_dims.Get(IDim) == 0, - "wrong! cannot evenly divide sliced tensor into cluster"); - }); - - constexpr auto repeat_lengths = SliceLengths{} / data_per_cluster_per_dims; - - // additional check for merged dimension - static_for<0, nDim, 1>{}([&](auto IDim_) { - // src - static_if{}([&](auto) { - constexpr auto IDim = decltype(IDim_){}; - - // on a merged dimension that constains multiple original dimensions, - // the length of the last original dimension need to evenly dividable by its - // sub-length, - // so each thread is effectively reading a normal (not merged) tensor - constexpr auto idim_last_original_src = - SrcDesc::GetContainedOriginalDimensions(IDim).Back(); - static_assert( - SrcDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_src) % - SubLengths::Get(IDim) == - 0, - "wrong!"); - - // merged dimension should have repeat_lengths = 1 - static_assert(repeat_lengths[IDim] == 1, - "wrong! repeat_lengths shoud be 1 on merged dimension"); - }); - - // dst - static_if{}([&](auto) { - constexpr auto IDim = decltype(IDim_){}; - - // on a merged dimension that constains multiple original dimensions, - // the length of the last original dimension need to evenly dividable by its - // sub-length, - // so each thread is effectively reading a normal (not merged) tensor - constexpr auto idim_last_original_dst = - DstDesc::GetContainedOriginalDimensions(IDim).Back(); - static_assert( - DstDesc::GetOriginalTensorDescriptor().GetLength(idim_last_original_dst) % - SubLengths::Get(IDim) == - 0, - "wrong!"); - - // merged dimension should have repeat_lengths = 1 - static_assert(repeat_lengths[IDim] == 1, - "wrong! 
repeat_lengths shoud be 1 on merged dimension"); - }); - }); - - // calculate mThreadSrcOffset, mThreadDstOffset - const auto thread_cluster_id = - thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id()); - - const auto data_cluster_id = - reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{}); - - const auto thread_data_id_begin = data_cluster_id * SubLengths{}; - - // original multi-id - mThreadSrcOriginalMultiId = SrcDesc::GetOriginalMultiIndexFromMultiIndex( - src_block_data_id_begin + thread_data_id_begin); - - mThreadDstOriginalMultiId = DstDesc::GetOriginalMultiIndexFromMultiIndex( - dst_block_data_id_begin + thread_data_id_begin); - - // partial offset on each dimension - static_for<0, nDim, 1>{}([&](auto IDim) { - constexpr auto src_partial_original_dims = - SrcDesc::GetContainedOriginalDimensions(IDim); - - constexpr auto src_partial_original_desc = - SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims); - - mThreadSrcPartialOffsets(IDim) = src_partial_original_desc.GetOffsetFromMultiIndex( - extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims)); - }); - - static_for<0, nDim, 1>{}([&](auto IDim) { - constexpr auto dst_partial_original_dims = - DstDesc::GetContainedOriginalDimensions(IDim); - - constexpr auto dst_partial_original_desc = - DstDesc::GetOriginalTensorDescriptor().Extract(dst_partial_original_dims); - - mThreadDstPartialOffsets(IDim) = dst_partial_original_desc.GetOffsetFromMultiIndex( - extract_array(mThreadDstOriginalMultiId, dst_partial_original_dims)); - }); - - // complete offset - mThreadSrcOffset = accumulate_on_array( - mThreadSrcPartialOffsets, math::plus{}, static_cast(0)); - - mThreadDstOffset = accumulate_on_array( - mThreadDstPartialOffsets, math::plus{}, static_cast(0)); - } - - __device__ static constexpr auto GetRegisterBufferDescriptor() - { - constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{}); - - return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths); - } - - __device__ static constexpr index_t GetThreadBufferSize() - { - return GetRegisterBufferDescriptor().GetElementSpace(); - } - - template - __device__ void RunLoadThreadBuffer(const TData* __restrict__ p_src, - TData* __restrict__ p_buffer) const - { - constexpr auto thread_sub_tensor_lengths = SubLengths{}; - - constexpr auto data_per_cluster_per_dims = - thread_sub_tensor_lengths * ThreadClusterLengths{}; - - constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{}); - - constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor(); - -#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 - static_ford{}([&](auto repeat_id) { - constexpr auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims; - - constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths; - - constexpr index_t src_offset = - SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin); - - constexpr index_t buffer_offset = - thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin); -#else - ford{}([&](auto repeat_id) { - const auto src_thread_data_id_begin = repeat_id * data_per_cluster_per_dims; - - const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths; - - const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(src_thread_data_id_begin); - - const index_t buffer_offset = - thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin); -#endif - - // By position the origin of the 
per-thread window at the point where the multi-index
-        // of the SrcDesc (which might be a merged tensor) is all-zero, this threadwise slice
-        // copy can assume each thread is copying a normal (not merged) tensor.
-        // To satisfy this assumption, the user needs to make sure that, on a merged dimension
-        // that contains multiple original dimensions, the length of the last original
-        // dimension is evenly divisible by its sub-lengths. Also, the repeat-length on the
-        // merged dimension needs to be 1. These sanity checks are performed in the
-        // constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated
-            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(make_zero_array(),
-                                                             make_zero_array())
-                .Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
-        });
-    }
-
-    template
-    __device__ void RunStoreThreadBuffer(const TData* __restrict__ p_buffer,
-                                         TData* __restrict__ p_dst) const
-    {
-        constexpr auto thread_sub_tensor_lengths = SubLengths{};
-
-        constexpr auto data_per_cluster_per_dims =
-            thread_sub_tensor_lengths * ThreadClusterLengths{};
-
-        constexpr auto repeat_lengths = SliceLengths{} / (SubLengths{} * ThreadClusterLengths{});
-
-        constexpr auto thread_buffer_desc = GetRegisterBufferDescriptor();
-
-#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
-        static_ford{}([&](auto repeat_id) {
-            constexpr auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
-
-            constexpr auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
-
-            constexpr index_t buffer_offset =
-                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
-
-            constexpr index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
-#else
-        ford{}([&](auto repeat_id) {
-            const auto buffer_data_id_begin = repeat_id * thread_sub_tensor_lengths;
-
-            const auto dst_data_id_begin = repeat_id * data_per_cluster_per_dims;
-
-            const index_t buffer_offset =
-                thread_buffer_desc.GetOffsetFromMultiIndex(buffer_data_id_begin);
-
-            const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(dst_data_id_begin);
-#endif
-
-            // By positioning the origin of the per-thread window at the point where the
-            // multi-index of the SrcDesc (which might be a merged tensor) is all-zero, this
-            // threadwise slice copy can assume each thread is copying a normal (not merged)
-            // tensor. To satisfy this assumption, the user needs to make sure that, on a
-            // merged dimension that contains multiple original dimensions, the length of the
-            // last original dimension is evenly divisible by its sub-lengths. Also, the
-            // repeat-length on the merged dimension needs to be 1. These sanity checks are
-            // performed in the constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated
-            ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(
-                make_zero_array(), make_zero_array())
-                .Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
-        });
-    }
-
-    template
-    __device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
-    {
-        TData p_buffer[GetThreadBufferSize()];
-
-        RunLoadThreadBuffer(p_src, p_buffer);
-        RunStoreThreadBuffer(p_buffer, p_dst);
-    }
-
-    // When moving the slicing windows along a merged dimension, if the strides of the
-    // original dimensions contained by the merged dimension are not in descending order,
-    // then there is no guarantee that the new offset will be larger than the old offset
-    // for movement in the positive direction (and vice versa for movement in the negative
-    // direction).
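The Run() above makes the two-phase structure explicit: each thread stages its slice through a small register-sized buffer, so a copy decomposes into a pure load followed by a pure store, and any type conversion happens during the store. A self-contained sketch of that pattern (hypothetical helper, not the library's class):

#include <algorithm>
#include <cstdio>

// Stage N elements through a per-thread buffer: load phase, then store
// phase with optional element type conversion.
template <typename SrcT, typename DstT, int N>
void two_phase_copy(const SrcT* src, DstT* dst)
{
    SrcT buffer[N]; // staging buffer ("register" memory in the kernel)

    std::copy(src, src + N, buffer);           // load (cf. RunLoadThreadBuffer)
    for(int i = 0; i < N; ++i)
        dst[i] = static_cast<DstT>(buffer[i]); // store + conversion
}

int main()
{
    const float src[4] = {1.5f, 2.5f, 3.5f, 4.5f};
    double dst[4];
    two_phase_copy<float, double, 4>(src, dst);
    std::printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
}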
- // As a result, there is the possibility that the offset calculation may result in
-    // unsigned integer underflow (due to the "-" operation). However, this hazard should not
-    // happen, as long as the user makes sure the slicing window is not moved out of
-    // the boundary of the tensor being sliced. This function doesn't do a runtime sanity
-    // check on an out-of-bound slicing window, for performance reasons
-    template
-    __device__ void MoveSlicingWindowOnSourceTensor(
-        Number, Number, integral_constant direction)
-    {
-        constexpr auto IDim = Number{};
-
-        static_if{}([&](auto) {
-            // logic for a merged dimension; it also works for a non-merged dimension, but
-            // may be unnecessarily complicated for the compiler to remove calculations that
-            // are useless for a non-merged dimension
-
-            // extract partial original dimensions
-            constexpr auto src_partial_original_dims =
-                SrcDesc::GetContainedOriginalDimensions(IDim);
-
-            constexpr auto src_partial_original_desc =
-                SrcDesc::GetOriginalTensorDescriptor().Extract(src_partial_original_dims);
-
-            // calculate new partial original multi-id
-            auto old_src_partial_original_id =
-                extract_array(mThreadSrcOriginalMultiId, src_partial_original_dims);
-
-            auto new_src_partial_original_id =
-                src_partial_original_desc.UpdateMultiIndexGivenStepSizeOf1dIndex(
-                    old_src_partial_original_id, StepSize, direction);
-
-            // update "mThreadSrcOriginalMultiId"
-            static_for<0, decltype(src_partial_original_dims)::GetSize(), 1>{}([&](auto I) {
-                constexpr auto IDimOriginal = src_partial_original_dims[I];
-
-                mThreadSrcOriginalMultiId(IDimOriginal) = new_src_partial_original_id[I];
-            });
-
-            // calculate new partial offset on this merged dimension
-            const index_t old_src_partial_offset = mThreadSrcPartialOffsets[IDim];
-
-            const index_t new_src_partial_offset =
-                src_partial_original_desc.GetOffsetFromMultiIndex(new_src_partial_original_id);
-
-            // update "mThreadSrcPartialOffsets"
-            mThreadSrcPartialOffsets(IDim) = new_src_partial_offset;
-
-            // update "mThreadSrcOffset", do "+" before "-" to avoid underflow
-            mThreadSrcOffset = (mThreadSrcOffset + new_src_partial_offset) - old_src_partial_offset;
-        }).Else([&](auto) {
-            // Logic for a non-merged dimension. If you are never going to move the slicing
-            // window on a merged dimension, then "mThreadSrcOriginalMultiId" and
-            // "mThreadSrcPartialOffsets", which are being calculated here, will never be used
-            // later. In this case, the compiler should be able to remove these calculations.
-            // TODO: make sure the compiler would actually remove them in this case.
-
-            // It is the user's responsibility to make sure the slicing window will not be
-            // moved out of the boundary of the tensor being sliced. Otherwise, there might be
-            // hazards like unsigned integer underflow.
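The merged-dimension branch above advances the window through UpdateMultiIndexGivenStepSizeOf1dIndex, whose core is digit-wise addition with carry propagation: add the decomposed per-dimension step, then sweep from the fastest-varying dimension upward, wrapping any digit that reaches its length and carrying into the next higher one. The highest dimension is deliberately unchecked, matching the contract that the caller keeps the result in bounds. A minimal sketch with toy values:

#include <array>
#include <cstdio>

int main()
{
    const std::array<int, 3> lengths{4, 3, 5};
    std::array<int, 3> idx{1, 2, 4};        // current multi-index
    const std::array<int, 3> step{0, 1, 3}; // decomposed 1d step of size 8

    for(int d = 0; d < 3; ++d)
        idx[d] += step[d];                  // {1, 3, 7}

    // carry check in reversed order, starting from the lowest dimension;
    // one subtraction per digit suffices because each step digit < length
    for(int d = 2; d > 0; --d)
        if(idx[d] >= lengths[d])
        {
            idx[d] -= lengths[d];
            ++idx[d - 1];
        }

    // 1d check: 29 + 8 = 37 = 2*15 + 1*5 + 2
    std::printf("{%d,%d,%d}\n", idx[0], idx[1], idx[2]); // {2,1,2}
}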
 -
 -// This version uses TensorCoordinate
 -// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor.
 -// The memory layout (ordering of dimensions) can be different between src and dst.
 -template
 -struct BlockwiseGenericTensorSliceCopy_v2_deprecated
 -{
 - static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
 -
 - using Index = MultiIndex;
 -
 - __device__ constexpr BlockwiseGenericTensorSliceCopy_v2_deprecated(
 - const Index& src_block_slice_origin, const Index& dst_block_slice_origin)
 - {
 - static_assert(
 - nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
 - nDim == SliceLengths::GetSize() && nDim == SubLengths::GetSize() &&
 - nDim == ThreadClusterLengths::GetSize() &&
 - nDim == ThreadClusterArrangeOrder::GetSize() &&
 - nDim == SrcDimAccessOrder::GetSize() && nDim == DstDimAccessOrder::GetSize(),
 - "wrong! nDim not consistent");
 -
 - static_assert(is_same{},
 - "wrong! threads should be mapped to cover entire slicing window");
 -
 - constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor_packed(
 - ThreadClusterLengths::ReorderGivenNew2Old(ThreadClusterArrangeOrder{}));
 -
 - static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
 - "wrong!
BlockSize not consistent with ThreadClusterLengths"); - - const auto thread_cluster_id = - thread_cluster_desc.GetMultiIndexFrom1dIndex(get_thread_local_1d_id()); - - const auto data_cluster_id = - reorder_array_given_old2new(thread_cluster_id, ThreadClusterArrangeOrder{}); - - const auto thread_data_id_begin = data_cluster_id * SubLengths{}; - - mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); - mThreadwiseLoad.SetDstSliceOrigin(make_zero_array()); - - mThreadwiseStore.SetSrcSliceOrigin(make_zero_array()); - mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); - } - - __device__ static constexpr index_t GetThreadBufferSize() - { - return ThreadBufferDesc::GetElementSpace(); - } - - template - __device__ void - RunLoadThreadBuffer(const BlockSrcData* p_block_src, - ThreadBufferData* p_thread_buffer, - integral_constant, - integral_constant) const - { - constexpr auto block_src_address_space = - integral_constant{}; - constexpr auto thread_buffer_address_space = - integral_constant{}; - - mThreadwiseLoad.Run( - p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space); - } - - template - __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src, - ThreadBufferData* p_thread_buffer) const - { - constexpr auto generic_address_space = - integral_constant{}; - - RunLoadThreadBuffer( - p_block_src, p_thread_buffer, generic_address_space, generic_address_space); - } - - template - __device__ void - RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer, - BlockDstData* p_block_dst, - integral_constant, - integral_constant) const - { - constexpr auto thread_buffer_address_space = - integral_constant{}; - constexpr auto block_dst_address_space = - integral_constant{}; - - mThreadwiseStore.Run( - p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space); - } - - template - __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer, - BlockDstData* p_block_dst) const - { - constexpr auto generic_address_space = - integral_constant{}; - - RunStoreThreadBuffer( - p_thread_buffer, p_block_dst, generic_address_space, generic_address_space); - } - - template - __device__ void - Run(const BlockSrcData* p_block_src, - BlockDstData* p_block_dst, - integral_constant block_src_address_space, - integral_constant block_dst_address_space) const - { - BlockSrcData p_thread_buffer[GetThreadBufferSize()]; - - constexpr auto generic_address_space = - integral_constant{}; - - RunLoadThreadBuffer( - p_block_src, p_thread_buffer, block_src_address_space, generic_address_space); - - // if there is type conversion, it's done during store - RunStoreThreadBuffer( - p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space); - } - - template - __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const - { - constexpr auto generic_address_space = - integral_constant{}; - - Run(p_block_src, p_block_dst, generic_address_space, generic_address_space); - } - - template - __device__ void - MoveSrcSliceWindow(T step_sizes, integral_constant positive_direction) - { - mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); - } - - template - __device__ void - MoveDstSliceWindow(T step_sizes, integral_constant positive_direction) - { - mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); - } - - private: - using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{})); - - using ThreadwiseLoad 
= ThreadwiseGenericTensorSliceCopy_v2r1_deprecated;
 -
 - using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated;
 -
 - ThreadwiseLoad mThreadwiseLoad;
 - ThreadwiseStore mThreadwiseStore;
 -};
 -
 -} // namespace ck
 -
 -#endif
diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp
index e5c8e37495..d4cbee1ced 100644
--- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp
+++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp
@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
 "wrong!");
 constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-
 constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
 // c_thread_mtx definition: this is a mess
@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
 decltype(c_m0m1_n0n1_thread_mtx_desc),
 MPerThread,
 NPerThread,
+ KPerThread,
 MLevel0Cluster,
 NLevel0Cluster,
 MLevel1Cluster,
 NLevel1Cluster,
- KPerThread,
 ThreadGemmAThreadCopySrcDataPerRead_M,
 ThreadGemmBThreadCopySrcDataPerRead_N>{};
diff --git a/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp b/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp
new file mode 100644
index 0000000000..3a3960863f
--- /dev/null
+++ b/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp
@@ -0,0 +1,330 @@
+#ifndef CK_GRIDWISE_TENSOR_CONTRACTION_HPP
+#define CK_GRIDWISE_TENSOR_CONTRACTION_HPP
+
+#include "common_header.hpp"
+#include "tensor_descriptor.hpp"
+#include "tensor_descriptor_helper.hpp"
+#include "ConstantMatrixDescriptor.hpp"
+#include "blockwise_generic_tensor_slice_copy.hpp"
+#include "threadwise_generic_tensor_slice_copy.hpp"
+#include "blockwise_gemm.hpp"
+
+namespace ck {
+
+template
+struct GridwiseTensorContraction_v1
+{
+ __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() {}
+
+ __device__ void Run(const Float* __restrict__ p_a_global,
+ const Float* __restrict__ p_b_global,
+ Float* __restrict__ p_c_global,
+ Float* __restrict__ p_shared_block) const
+ {
+ /// \todo sanity-check on AGlobalDesc, BGlobalDesc, CGlobalDesc length consistency
+ /// \todo sanity-check on CBlockLengths
+
+ constexpr auto True = integral_constant{};
+
+ constexpr auto a_global_desc = AGlobalDesc{};
+ constexpr auto b_global_desc = BGlobalDesc{};
+ constexpr auto c_global_desc = CGlobalDesc{};
+
+ constexpr auto K = a_global_desc.GetLengths()[0];
+
+ // don't do anything if K == 0
+ if(K == 0)
+ {
+ return;
+ }
+
+ // LDS max alignment
+ constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M,
+ BBlockCopyDstDataPerWrite_N,
+ ThreadGemmAThreadCopySrcDataPerRead_M,
+ ThreadGemmBThreadCopySrcDataPerRead_N);
+
+ // divide block work by [M, N]
+ static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
+ "wrong!
cannot divide work evenly among block");
+
+ constexpr index_t MBlockWork = M / MPerBlock;
+ constexpr index_t NBlockWork = N / NPerBlock;
+
+ constexpr auto block_work_desc =
+ make_cluster_descriptor(Sequence{});
+
+ const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
+
+ const index_t m_block_data_on_global = block_work_id[0] * MPerBlock;
+ const index_t n_block_data_on_global = block_work_id[1] * NPerBlock;
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned(
+ Sequence{}, Number{});
+
+ // A matrix blockwise copy
+ auto a_blockwise_copy =
+ BlockwiseGenericTensorSliceCopy_v4,
+ ABlockCopySrcVectorReadDim,
+ 1,
+ ABlockCopySrcDataPerRead,
+ ABlockCopyDstDataPerWrite_M,
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
+ {0, m_block_data_on_global}, {0, 0});
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned(
+ Sequence{}, Number{});
+
+ // B matrix blockwise copy
+ auto b_blockwise_copy =
+ BlockwiseGenericTensorSliceCopy_v4,
+ BBlockCopySrcVectorReadDim,
+ 1,
+ BBlockCopySrcDataPerRead,
+ BBlockCopyDstDataPerWrite_N,
+ AddressSpace::Global,
+ AddressSpace::Vgpr,
+ AddressSpace::Lds,
+ InMemoryDataOperation::Set>(
+ {0, n_block_data_on_global}, {0, 0});
+
+ // GEMM definition
+ // c_mtx += transpose(a_mtx) * b_mtx
+ // a_mtx[KPerBlock, MPerBlock] is in LDS
+ // b_mtx[KPerBlock, NPerBlock] is in LDS
+ // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+ // register
+ constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc);
+ constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc);
+
+ // sanity check
+ static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&
+ NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0,
+ "wrong!");
+
+ constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
+
+ constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
+
+ // c_thread_mtx definition: this is a mess
+ // TODO: more elegant way of defining c_thread_mtx
+ constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
+ Number{}, Number{});
+
+ const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
+ BlockSize,
+ decltype(a_k_m_block_mtx_desc),
+ decltype(b_k_n_block_mtx_desc),
+ decltype(c_m0m1_n0n1_thread_mtx_desc),
+ MPerThread,
+ NPerThread,
+ MLevel0Cluster,
+ NLevel0Cluster,
+ MLevel1Cluster,
+ NLevel1Cluster,
+ KPerThread,
+ ThreadGemmAThreadCopySrcDataPerRead_M,
+ ThreadGemmBThreadCopySrcDataPerRead_N>{};
+
+ // LDS allocation for A and B: be careful of alignment
+ constexpr index_t a_block_space =
+ math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
+
+ constexpr index_t b_block_space =
+ math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
+
+ Float* p_a_block_double = p_shared_block;
+ Float* p_b_block_double = p_shared_block + 2 * a_block_space;
+
+ // register allocation for output
+ AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+
+ // zero out threadwise output
+ threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+
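
The preload plus main-body structure below implements a classic ping-pong (double-buffer) pipeline: while the blockwise GEMM consumes one LDS tile, the blockwise copy fetches the next tile into the other, and the even/odd iterations swap roles. A minimal host-side model of that schedule, with load_tile and gemm_tile as illustrative stand-ins for the blockwise copy and blockwise GEMM (not the library's API):

    #include <cstdio>

    static void load_tile(int k, float* dst) { dst[0] = float(k); }    // stand-in for the copy
    static void gemm_tile(const float* buf, float* acc) { *acc += buf[0]; } // stand-in for GEMM

    int main()
    {
        const int K = 8, KPerBlock = 1;
        float buffer[2][1]; // the two LDS tiles
        float acc = 0.f;

        load_tile(0, buffer[0]); // preload the first tile

        int k = 0;
        for(; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
            for(int iloop = 0; iloop < 2; ++iloop)
            {
                const int now = iloop % 2, next = 1 - now;
                load_tile(k + iloop + 1, buffer[next]); // fetch the next tile
                gemm_tile(buffer[now], &acc);           // compute on the current tile
            }

        // tail: either two tiles or one tile remain
        if(K % (2 * KPerBlock) == 0)
        {
            load_tile(K - 1, buffer[1]);
            gemm_tile(buffer[0], &acc); // second-to-last tile
            gemm_tile(buffer[1], &acc); // last tile
        }
        else
        {
            gemm_tile(buffer[0], &acc); // last tile
        }

        std::printf("acc = %f (expect sum 0..7 = 28)\n", acc);
    }
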
+ // LDS double buffer: preload data into LDS
+ {
+ a_blockwise_copy.Run(p_a_global, p_a_block_double);
+ b_blockwise_copy.Run(p_b_global, p_b_block_double);
+ }
+
+ constexpr auto a_block_slice_copy_steps = Sequence{};
+ constexpr auto b_block_slice_copy_steps = Sequence{};
+
+ // LDS double buffer: main body
+ for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
+ k_block_data_begin += 2 * KPerBlock)
+ {
+#pragma unroll
+ for(index_t iloop = 0; iloop < 2; ++iloop)
+ {
+ const bool even_loop = (iloop % 2 == 0);
+
+ Float* p_a_block_now =
+ even_loop ? p_a_block_double : p_a_block_double + a_block_space;
+ Float* p_b_block_now =
+ even_loop ? p_b_block_double : p_b_block_double + b_block_space;
+
+ Float* p_a_block_next =
+ even_loop ? p_a_block_double + a_block_space : p_a_block_double;
+ Float* p_b_block_next =
+ even_loop ? p_b_block_double + b_block_space : p_b_block_double;
+
+ Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+ Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+ a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+ b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+
+ __syncthreads();
+
+ // LDS double buffer: load next data from device mem
+ a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+ b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+ // LDS double buffer: GEMM on current data
+ blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread);
+
+ // LDS double buffer: store next data to LDS
+ a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
+ b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
+ }
+ }
+
+ // LDS double buffer: tail
+ {
+ constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
+
+ if(has_two_iteration_left) // 2 iterations are left
+ {
+ Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+ Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+ a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True);
+ b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True);
+
+ __syncthreads();
+
+ // LDS double buffer: load last data from device mem
+ a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+ b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+ // LDS double buffer: GEMM on 2nd-last data
+ blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+
+ // LDS double buffer: store last data to LDS
+ a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
+ p_a_block_double + a_block_space);
+ b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
+ p_b_block_double + b_block_space);
+
+ __syncthreads();
+
+ // LDS double buffer: GEMM on last data
+ blockwise_gemm.Run(
+ p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread);
+ }
+ else // 1 iteration is left
+ {
+ __syncthreads();
+
+ // LDS double buffer: GEMM on last data
+ blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+ }
+ }
+
+ // output: register to global memory
+ {
+ constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster;
+ constexpr index_t M0 = M / M1;
+
+ constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster;
+ constexpr index_t N0 = N / N1;
+
+ // define input tensor descriptor for threadwise copy
+ // thread input tensor, src of threadwise copy
+ constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed(
+ Sequence{});
+
+ constexpr auto
c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor( + c_m_n_global_desc, + make_tuple(UnMerge>{}, UnMerge>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // calculate origin of thread input tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t m_thread_data_on_global = + m_block_data_on_global + c_thread_mtx_on_block.row; + + const index_t n_thread_data_on_global = + n_block_data_on_global + c_thread_mtx_on_block.col; + + ThreadwiseGenericTensorSliceCopy_v4r2( + {0, 0, 0, 0}, + {m_thread_data_on_global / M1, + m_thread_data_on_global % M1, + n_thread_data_on_global / N1, + n_thread_data_on_global % N1}) + .Run(p_c_thread, p_c_global); + } + } + + __device__ void Run(const Float* __restrict__ p_a_global, + const Float* __restrict__ p_b_global, + Float* __restrict__ p_c_global) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); + + __shared__ Float p_shared_block[shared_block_size]; + + Run(p_a_global, p_b_global, p_c_global, p_shared_block); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index 1538623e41..2dd4a79912 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -23,7 +23,9 @@ template + InMemoryDataOperation DstInMemOp = InMemoryDataOperation::Set, + index_t SrcDataStride = 1, + index_t DstDataStride = 1> struct ThreadwiseGenericTensorSliceCopy_v4r2 { static constexpr index_t nDim = SliceLengths::Size(); @@ -116,7 +118,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 SrcDataPerRead, SrcAddressSpace, AddressSpace::Vgpr, - InMemoryDataOperation::Set>( + InMemoryDataOperation::Set, + SrcDataStride, + 1>( p_src, src_coord.GetOffset(), p_src_long_vector, buffer_offset); } } @@ -148,7 +152,9 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 DstDataPerWrite, AddressSpace::Vgpr, DstAddressSpace, - DstInMemOp>( + DstInMemOp, + 1, + DstDataStride>( p_dst_long_vector, buffer_offset, p_dst, dst_coord.GetOffset()); } } diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp deleted file mode 100644 index 71460f33d2..0000000000 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_deprecated.hpp +++ /dev/null @@ -1,495 +0,0 @@ -#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP -#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" -#include "tensor_coordinate_deprecated.hpp" - -namespace ck { - -// This threadwise copy allow vector access of src and dst. -// It allows the vector size to be different on src and dst. -// The dimensions of vector access should be the same on src and dst. -// The dimension access order should be the same on src and dst. 
-// It is designed for cases, where one of src and dst is register, and -// the other is device memory or LDS -template -struct ThreadwiseGenericTensorSliceCopy_v1r2_deprecated -{ - static constexpr index_t nDim = SliceLengths::GetSize(); - - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated( - Array src_slice_origin, Array dst_slice_origin) - : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin) - { - static_assert(nDim == SrcDesc::GetNumOfDimension() && - nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() && - nDim == DimAccessOrder::GetSize(), - "wrong! # of dimensions not the same"); - - static_assert(is_valid_sequence_map::value, "wrong! map is not valid"); - - static_assert( - SliceLengths{}[VectorAccessDim] % math::lcm(SrcDataPerAccess, DstDataPerAccess) == 0, - "wrong! cannot evenly divide"); - - // check vectorized memory access - constexpr auto vector_access_dim = Number{}; - - static_if{}([&](auto fwd) { - static_assert( - (fwd(SrcDesc{}).GetStride(vector_access_dim) == 1 || SrcDataPerAccess == 1), - "wrong! vectorized access is allowed only if stride == 1"); - }).Else([&](auto fwd) { - static_assert((fwd(SrcDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 || - SrcDataPerAccess == 1), - "wrong! vectorized access is allowed only if stride == 1"); - }); - - static_if{}([&](auto fwd) { - static_assert( - (fwd(DstDesc{}).GetStride(vector_access_dim) == 1 || DstDataPerAccess == 1), - "wrong! vectorized access is allowed only if stride == 1"); - }).Else([&](auto fwd) { - static_assert((fwd(DstDesc{}).GetLastOriginalDimensionStride(vector_access_dim) == 1 || - DstDataPerAccess == 1), - "wrong! vectorized access is allowed only if stride == 1"); - }); - } - - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated() - : ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(make_zero_array(), - make_zero_array()) - { - } - - __device__ void SetSrcSliceOrigin(Array src_slice_origin) - { - mSrcSliceOrigin = src_slice_origin; - } - - __device__ void SetDstSliceOrigin(Array dst_slice_origin) - { - mDstSliceOrigin = dst_slice_origin; - } - - template - __device__ void Run(const SrcData* p_src, DstData* p_dst) const - { - using src_vector_t = typename vector_type::MemoryType; - using dst_vector_t = typename vector_type::MemoryType; - - constexpr auto vector_access_dim = Number{}; - - constexpr auto src_data_per_access = Number{}; - constexpr auto dst_data_per_access = Number{}; - - constexpr auto long_vector_size = Number{}; - - constexpr auto long_vector_access_lengths = SliceLengths::Modify( - vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); - - ford{}( - [&](auto long_vector_access_id) { - - // data id w.r.t slicing-window - auto long_vector_data_begin_id = long_vector_access_id; - long_vector_data_begin_id(vector_access_dim) = - long_vector_size * long_vector_access_id[vector_access_dim]; - - // buffer to hold a long-vector - SrcData p_src_long_vector[long_vector_size]; - DstData p_dst_long_vector[long_vector_size]; - - // load data from src to the long-vector buffer - for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) - { - auto scalar_id = make_zero_array(); - scalar_id(vector_access_dim) = i * src_data_per_access; - - const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex( - mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id)); - - const index_t buffer_offset = i * src_data_per_access; - - *reinterpret_cast(&p_src_long_vector[buffer_offset]) = - 
*reinterpret_cast(&p_src[src_offset]);
 - }
 -
 - // type conversion
 - for(index_t i = 0; i < long_vector_size; ++i)
 - {
 - p_dst_long_vector[i] = type_convert{}(p_src_long_vector[i]);
 - }
 -
 - // store data from the long-vector buffer to dst
 - for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
 - {
 - auto scalar_id = make_zero_array();
 - scalar_id(vector_access_dim) = i * dst_data_per_access;
 -
 - const index_t buffer_offset = i * dst_data_per_access;
 -
 - const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
 - mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));
 -
 - *reinterpret_cast(&p_dst[dst_offset]) =
 - *reinterpret_cast(&p_dst_long_vector[buffer_offset]);
 - }
 - });
 - }
 -
 - private:
 - Array mSrcSliceOrigin;
 - Array mDstSliceOrigin;
 -};
 -
 -// This version uses TensorCoordinate_deprecated
 -// This threadwise copy allows vector access of src and dst.
 -// It allows the dimensions of vector access to be different on src and dst.
 -// It also allows the vector size to be different on src and dst.
 -// It also allows the order of access to be different on src and dst.
 -// It uses registers as a buffer to hold all data moving from src to dst.
 -// It is designed for copying small amounts of data, and src and dst are
 -// device memory or LDS.
 -// When copying large amounts of data, let's hope the compiler will reduce the registers
 -// used for the buffer.
 -template
 -struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
 -{
 - static constexpr index_t nDim = SliceLengths::GetSize();
 -
 - using Index = MultiIndex;
 -
 - using SrcCoordinate = typename TensorCoordinate_deprecated::type;
 - using DstCoordinate = typename TensorCoordinate_deprecated::type;
 -
 - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(
 - const Index& src_slice_origin, const Index& dst_slice_origin)
 - : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
 - {
 - static_assert(nDim == SrcDesc::GetNumOfDimension() &&
 - nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
 - nDim == SrcDimAccessOrder::GetSize() &&
 - nDim == DstDimAccessOrder::GetSize(),
 - "wrong! # of dimensions not the same");
 -
 - static_assert(is_valid_sequence_map::value &&
 - is_valid_sequence_map::value,
 - "wrong! map is not valid");
 -
 - static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
 - SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
 - "wrong! cannot evenly divide");
 -
 - // check vectorized memory access
 - constexpr auto src_vector_access_dim = Number{};
 - constexpr auto dst_vector_access_dim = Number{};
 -
 - static_if{}(
 - [&](auto fwd) {
 - static_assert(
 - (fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
 - "wrong! vectorized access is allowed only if stride == 1");
 - })
 - .Else([&](auto fwd) {
 - static_assert(
 - (fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
 - SrcDataPerAccess == 1),
 - "wrong! vectorized access is allowed only if stride == 1");
 - });
 -
 - static_if{}(
 - [&](auto fwd) {
 - static_assert(
 - (fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
 - "wrong! vectorized access is allowed only if stride == 1");
 - })
 - .Else([&](auto fwd) {
 - static_assert(
 - (fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
 - DstDataPerAccess == 1),
 - "wrong!
vectorized access is allowed only if stride == 1"); - }); - } - - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated() - : ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(make_zero_array(), - make_zero_array()) - { - } - - __device__ void SetSrcSliceOrigin(SrcCoordinate src_slice_origin) - { - mSrcSliceOrigin = src_slice_origin; - } - - __device__ void SetDstSliceOrigin(DstCoordinate dst_slice_origin) - { - mDstSliceOrigin = dst_slice_origin; - } - - template - struct IsolateMergedDimLengths - { - template - __device__ constexpr index_t operator()(IDim idim) const - { - return TDesc::ContainMultipleOriginalDimensions(idim) ? Lengths{}[idim] : 1; - } - }; - - template - __device__ void Run(const SrcData* p_src, - DstData* p_dst, - integral_constant, - integral_constant) const - { - constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{}); - - SrcData p_src_buffer_[buffer_desc.GetElementSpace()]; - SrcData* p_src_buffer = p_src_buffer_; - - // copy data from src into buffer - { - using src_vector_t = typename vector_type::MemoryType; - - constexpr auto src_vector_access_dim = Number{}; - constexpr auto src_data_per_access = Number{}; - - constexpr auto src_access_lengths = SliceLengths::Modify( - src_vector_access_dim, - SliceLengths::Get(src_vector_access_dim) / src_data_per_access); - - // Offset w.r.t merged dimensions need to be calculated at run-time. Offset w.r.t - // normal dimensions is known at compile time. - // Below is a hack to isolate merged dimension id from normal dimension id, so the - // corresponding offset can be calculated seperately at run-time and compile-time. - // src_merged_dim_access_lengths has the same value as src_access_lengths on src's - // merged dimensions, and has value = 1 on normal dimensions; - // src_merged_dim_access_lengths has the same value as src_access_lengths on src's - // normal dimensions, and has value = 1 on merged dimensions; - constexpr auto src_merged_dim_access_lengths = typename sequence_gen< - nDim, - IsolateMergedDimLengths>::type{}; - - constexpr auto src_normal_dim_access_lengths = - src_access_lengths + Number<1>{} - src_merged_dim_access_lengths; - - ford{}( - [&](auto src_merged_dim_access_id) { - - auto src_merged_dim_data_id = src_merged_dim_access_id; - src_merged_dim_data_id(src_vector_access_dim) = - src_merged_dim_access_id[src_vector_access_dim] * src_data_per_access; - - // offset w.r.t. merged dimension need be computed at run-time, - const index_t src_merged_offset = - (mSrcSliceOrigin + src_merged_dim_data_id).GetOffset(); - - ford{}([&]( - auto src_normal_dim_access_id) { - - auto src_normal_dim_data_id = src_normal_dim_access_id; - src_normal_dim_data_id(src_vector_access_dim) = - src_normal_dim_access_id[src_vector_access_dim] * src_data_per_access; - - // offset w.r.t. normal dimension is known at compile-time - const index_t src_normal_offset = - SrcDesc::GetOffsetFromMultiIndex(src_normal_dim_data_id); - - src_vector_t vector_data; - - // Read vector from src. - // 1. Source code version can take src of all kinds of memory-space - // 2. Intrinsic version using buffer_load can only take - // src from global-memory - // - // Commemt for loading from global-memory: - // When: - // 1) using source code, in order for compiler to emit optimal - // load instruction, or - // 2) using buffer_load intrinsic, in order for ISA to be valid, - // following assumptions need to be satisfied: - // 1. p_src need to be block-invariant (assumption) - // 2. 
src_normal_offset must be calculatd at compile time (guaranteed by - // algorithm) - // 3. src_merged_offset can be runtime value (no assumption imposed) - static_if{}([&](auto fwd) { -#if CK_USE_AMD_BUFFER_ADDRESSING - vector_data = amd_intrinsic_buffer_load( - fwd(p_src), src_merged_offset, src_normal_offset); -#else - vector_data = *reinterpret_cast( - &p_src[src_normal_offset + src_merged_offset]); -#endif - }).Else([&](auto) { - // src can be all kinds of memory-space. - vector_data = *reinterpret_cast( - &p_src[src_normal_offset + src_merged_offset]); - }); - - // unpack vector into buffer - for(index_t i = 0; i < SrcDataPerAccess; ++i) - { - auto scalar_id = make_zero_array(); - scalar_id(src_vector_access_dim) = i; - - const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex( - src_merged_dim_data_id + src_normal_dim_data_id + scalar_id); - - p_src_buffer[buffer_offset] = - reinterpret_cast(&vector_data)[i]; - } - }); - }); - } - - // type conversion - // TODO: would compiler do a good job reusing register for buffer? - DstData p_dst_buffer_[buffer_desc.GetElementSpace()]; - DstData* p_dst_buffer = p_dst_buffer_; - - ford{}([&](auto idx) { - p_dst_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)] = - type_convert{}(p_src_buffer[buffer_desc.GetOffsetFromMultiIndex(idx)]); - }); - - // copy data from buffer into dst - { - using dst_vector_t = typename vector_type::MemoryType; - - constexpr auto dst_vector_access_dim = Number{}; - constexpr auto dst_data_per_access = Number{}; - - constexpr auto dst_access_lengths = SliceLengths::Modify( - dst_vector_access_dim, - SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access); - - constexpr auto dst_merged_dim_access_lengths = typename sequence_gen< - nDim, - IsolateMergedDimLengths>::type{}; - - constexpr auto dst_normal_dim_access_lengths = - dst_access_lengths + Number<1>{} - dst_merged_dim_access_lengths; - - ford{}([&]( - auto dst_merged_dim_access_id) { - - auto dst_merged_dim_data_id = dst_merged_dim_access_id; - dst_merged_dim_data_id(dst_vector_access_dim) = - dst_merged_dim_access_id[dst_vector_access_dim] * dst_data_per_access; - - // offset w.r.t. merged dimension need be computed at run-time, - const index_t dst_merged_offset = - (mDstSliceOrigin + dst_merged_dim_data_id).GetOffset(); - - ford{}([&]( - auto dst_normal_dim_access_id) { - - auto dst_normal_dim_data_id = dst_normal_dim_access_id; - dst_normal_dim_data_id(dst_vector_access_dim) = - dst_normal_dim_access_id[dst_vector_access_dim] * dst_data_per_access; - - dst_vector_t vector_data; - - // pack vector from buffer - for(index_t i = 0; i < DstDataPerAccess; ++i) - { - auto scalar_id = make_zero_array(); - scalar_id(dst_vector_access_dim) = i; - - const index_t buffer_offset = buffer_desc.GetOffsetFromMultiIndex( - dst_merged_dim_data_id + dst_normal_dim_data_id + scalar_id); - - reinterpret_cast(&vector_data)[i] = p_dst_buffer[buffer_offset]; - } - - // offset w.r.t. normal dimension is known at compile-time - const index_t dst_normal_offset = - DstDesc::GetOffsetFromMultiIndex(dst_normal_dim_data_id); - - // Write vector into dst. - // 1. Source code version can take dst of all kinds of memory-space - // 2. Intrinsic version using buffer_store can only take - // dst from global-memory - // - // Commemt for storing into global-memory: - // When: - // 1) using source code, in order for compiler to emit optimal - // store instruction, or - // 2) using buffer_store, intrinsic in order ISA to be valid - // following assumptions need to be satisfied: - // 1. 
p_dst need to be block-invariant (assumption) - // 2. dst_normal_offset must be calculatd at compile time (guaranteed by - // algorithm) - // 3. dst_merged_offset can be runtime value (no assumption imposed) - static_if{}([&](auto fwd) { -#if CK_USE_AMD_BUFFER_ADDRESSING - amd_intrinsic_buffer_store( - vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset); -#else - *reinterpret_cast( - &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data; -#endif - }).Else([&](auto) { - // dst can be all kinds of memory-space - *reinterpret_cast( - &p_dst[dst_normal_offset + dst_merged_offset]) = vector_data; - }); - }); - }); - } - } - - template - __device__ void Run(const SrcData* p_src, DstData* p_dst) const - { - constexpr auto generic_address_space = - integral_constant{}; - - Run(p_src, p_dst, generic_address_space, generic_address_space); - } - - // T can be Sequence or Array - template - __device__ void MoveSrcSliceWindow(T step_sizes, integral_constant) - { - static_if{}([&](auto) { - mSrcSliceOrigin += step_sizes; - }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; }); - } - - template - __device__ void MoveDstSliceWindow(T step_sizes, integral_constant) - { - static_if{}([&](auto) { - mDstSliceOrigin += step_sizes; - }).Else([&](auto) { mDstSliceOrigin -= step_sizes; }); - } - - private: - SrcCoordinate mSrcSliceOrigin; - DstCoordinate mDstSliceOrigin; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index f3a2661849..a308e710f9 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -8,65 +8,149 @@ namespace ck { // For 128bit SGPRs in buffer_load and buffer_store instructions // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions template -union BufferLoadStoreDwordConfig +union BufferAddressConfig { int32x4_t data; T* address[2]; int32_t range[4]; }; -__device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.f32"); +__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.f32"); -__device__ float2_t __llvm_amdgcn_buffer_loadx2(int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v2f32"); - -__device__ float4_t __llvm_amdgcn_buffer_loadx4(int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v4f32"); - -__device__ void __llvm_amdgcn_buffer_store(float vdata, - int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.f32"); - -__device__ void __llvm_amdgcn_buffer_storex2(float2_t vdata, - int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v2f32"); - -__device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata, - int32x4_t rsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); - -__device__ void -__llvm_amdgcn_buffer_atomic_add(float vdata, - int32x4_t rsrc, +__device__ float2_t +__llvm_amdgcn_buffer_load_f32x2(int32x4_t rsrc, index_t vindex, index_t offset, - bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32"); + bool glc, + bool slc) 
__asm("llvm.amdgcn.buffer.load.v2f32"); + +__device__ float4_t +__llvm_amdgcn_buffer_load_f32x4(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.v4f32"); + +__device__ half_t __llvm_amdgcn_buffer_load_f16(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.f16"); + +__device__ half2_t __llvm_amdgcn_buffer_load_f16x2(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.v2f16"); + +__device__ half4_t __llvm_amdgcn_buffer_load_f16x4(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.v4f16"); + +__device__ ushort __llvm_amdgcn_buffer_load_bf16(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.bf16"); + +__device__ ushort2_t +__llvm_amdgcn_buffer_load_bf16x2(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.v2bf16"); + +__device__ ushort4_t +__llvm_amdgcn_buffer_load_bf16x4(int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.load.v4bf16"); + +__device__ void __llvm_amdgcn_buffer_store_f32(float vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.f32"); + +__device__ void __llvm_amdgcn_buffer_store_f32x2(float2_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v2f32"); + +__device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); + +__device__ void __llvm_amdgcn_buffer_store_f16(half_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.f16"); + +__device__ void __llvm_amdgcn_buffer_store_f16x2(half2_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v2f16"); + +__device__ void __llvm_amdgcn_buffer_store_f16x4(half4_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v4f16"); + +__device__ void __llvm_amdgcn_buffer_store_bf16(ushort vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.bf16"); + +__device__ void +__llvm_amdgcn_buffer_store_bf16x2(ushort2_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v2bf16"); + +__device__ void +__llvm_amdgcn_buffer_store_bf16x4(ushort4_t vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool glc, + bool slc) __asm("llvm.amdgcn.buffer.store.v4bf16"); + +__device__ void +__llvm_amdgcn_buffer_atomic_add_f32(float vdata, + int32x4_t rsrc, + index_t vindex, + index_t offset, + bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32"); // buffer_load requires: // 1) p_src must be in global memory space, d_dst must be vgpr // 2) p_src to be a block-invariant pointer. // It is user's responsibility to make sure that is true. 
template -__device__ typename vector_type::MemoryType amd_intrinsic_buffer_load( +__device__ typename vector_type::MemoryType amd_buffer_load( const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset); // buffer_store requires: @@ -74,30 +158,23 @@ __device__ typename vector_type::MemoryType amd_intrinsic_buffer_ // 2) p_dst to be a block-invariant pointer. // It is user's responsibility to make sure that is true. template -__device__ void -amd_intrinsic_buffer_store(const typename vector_type::MemoryType& src, - T* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset); +__device__ void amd_buffer_store(const T* p_src, + T* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset); template -__device__ void -amd_intrinsic_buffer_atomic_add(const typename vector_type::MemoryType& src, - T* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset); +__device__ void amd_buffer_atomic_add(const T* p_src, + T* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset); template <> -__device__ float amd_intrinsic_buffer_load(const float* p_src_block, - index_t src_thread_data_offset, - index_t src_const_data_offset) +__device__ float amd_buffer_load(const float* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) { - float dst; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); - index_t src_const_addr_offset = src_const_data_offset * sizeof(float); - - BufferLoadStoreDwordConfig src_block_config; + BufferAddressConfig src_block_config; // fill in byte 0 - 1 src_block_config.address[0] = const_cast(p_src_block); @@ -106,33 +183,19 @@ __device__ float amd_intrinsic_buffer_load(const float* p_src_block, // fill in byte 3 src_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - dst = __llvm_amdgcn_buffer_load( - src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); -#else - asm volatile( - "\n \ - buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \ - s_waitcnt 0 \n \ - " - : "=v"(dst) - : "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset)); -#endif - - return dst; -} - -template <> -__device__ float2_t amd_intrinsic_buffer_load(const float* p_src_block, - index_t src_thread_data_offset, - index_t src_const_data_offset) -{ - float2_t dst; - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_const_addr_offset = src_const_data_offset * sizeof(float); - BufferLoadStoreDwordConfig src_block_config; + return __llvm_amdgcn_buffer_load_f32( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +} + +template <> +__device__ float2_t amd_buffer_load(const float* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; // fill in byte 0 - 1 src_block_config.address[0] = const_cast(p_src_block); @@ -141,33 +204,19 @@ __device__ float2_t amd_intrinsic_buffer_load(const float* p_src_block // fill in byte 3 src_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - dst = __llvm_amdgcn_buffer_loadx2( - src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); -#else - asm volatile( - "\n \ - buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \ - s_waitcnt 0 \n \ - " - : "=v"(dst) - : "v"(src_thread_addr_offset), "s"(src_block_config.data), 
"s"(src_const_addr_offset)); -#endif - - return dst; -} - -template <> -__device__ float4_t amd_intrinsic_buffer_load(const float* p_src_block, - index_t src_thread_data_offset, - index_t src_const_data_offset) -{ - float4_t dst; - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); index_t src_const_addr_offset = src_const_data_offset * sizeof(float); - BufferLoadStoreDwordConfig src_block_config; + return __llvm_amdgcn_buffer_load_f32x2( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +} + +template <> +__device__ float4_t amd_buffer_load(const float* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; // fill in byte 0 - 1 src_block_config.address[0] = const_cast(p_src_block); @@ -176,32 +225,236 @@ __device__ float4_t amd_intrinsic_buffer_load(const float* p_src_block // fill in byte 3 src_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - dst = __llvm_amdgcn_buffer_loadx4( + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); + index_t src_const_addr_offset = src_const_data_offset * sizeof(float); + + return __llvm_amdgcn_buffer_load_f32x4( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +} + +template <> +__device__ half_t amd_buffer_load(const half_t* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + +#if !CK_WORKAROUND_SWDEV_231101 + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); + index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t); + + return __llvm_amdgcn_buffer_load_f16( src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); #else - asm volatile( - "\n \ - buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \ - s_waitcnt 0 \n \ - " - : "=v"(dst) - : "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset)); + return p_src_block[src_thread_data_offset + src_const_data_offset]; #endif - - return dst; } template <> -__device__ void amd_intrinsic_buffer_store(const float& src, - float* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset) +__device__ half2_t amd_buffer_load(const half_t* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) { - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); + BufferAddressConfig src_block_config; - BufferLoadStoreDwordConfig dst_block_config; + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); + index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t); + +#if !CK_WORKAROUND_SWDEV_231101 + return __llvm_amdgcn_buffer_load_f16x2( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +#else + float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + 
+ return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ half4_t amd_buffer_load(const half_t* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); + index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t); + +#if !CK_WORKAROUND_SWDEV_231101 + return __llvm_amdgcn_buffer_load_f16x4( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +#else + float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + + return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ half8_t amd_buffer_load(const half_t* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); + index_t src_const_addr_offset = src_const_data_offset * sizeof(half_t); + +#if !CK_WORKAROUND_SWDEV_231101 + static_assert(false, "wrong! not supported"); +#else + float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + + return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ ushort amd_buffer_load(const ushort* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + +#if !CK_WORKAROUND_SWDEV_231101 + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); + index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort); + + return __llvm_amdgcn_buffer_load_bf16( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +#else + return p_src_block[src_thread_data_offset + src_const_data_offset]; +#endif +} + +template <> +__device__ ushort2_t amd_buffer_load(const ushort* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); + index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort); + +#if !CK_WORKAROUND_SWDEV_231101 + return __llvm_amdgcn_buffer_load_bf16x2( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +#else + float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + + return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ ushort4_t 
amd_buffer_load(const ushort* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); + index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort); + +#if !CK_WORKAROUND_SWDEV_231101 + return __llvm_amdgcn_buffer_load_bf16x4( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); +#else + float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + + return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ ushort8_t amd_buffer_load(const ushort* p_src_block, + index_t src_thread_data_offset, + index_t src_const_data_offset) +{ + BufferAddressConfig src_block_config; + + // fill in byte 0 - 1 + src_block_config.address[0] = const_cast(p_src_block); + // fill in byte 2 + src_block_config.range[2] = -1; + // fill in byte 3 + src_block_config.range[3] = 0x00027000; + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); + index_t src_const_addr_offset = src_const_data_offset * sizeof(ushort); + +#if !CK_WORKAROUND_SWDEV_231101 + static_assert(false, "wrong! not implemented"); +#else + float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( + src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false); + + return *reinterpret_cast(&dst_out_tmp); +#endif +} + +template <> +__device__ void amd_buffer_store(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; // fill in byte 0 - 1 dst_block_config.address[0] = p_dst_block; @@ -210,35 +463,24 @@ __device__ void amd_intrinsic_buffer_store(const float& src, // fill in byte 3 dst_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - __llvm_amdgcn_buffer_store(src, - dst_block_config.data, - 0, - dst_thread_addr_offset + dst_const_addr_offset, - false, - false); -#else - asm volatile("\n \ - buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \ - " - : - : "s"(dst_block_config.data), - "v"(src), - "v"(dst_thread_addr_offset), - "s"(dst_const_addr_offset)); -#endif -} - -template <> -__device__ void amd_intrinsic_buffer_store(const float2_t& src, - float* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset) -{ index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); - BufferLoadStoreDwordConfig dst_block_config; + __llvm_amdgcn_buffer_store_f32(*p_src, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +} + +template <> +__device__ void amd_buffer_store(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; // fill in byte 0 - 1 dst_block_config.address[0] = p_dst_block; @@ -247,35 +489,24 @@ __device__ void amd_intrinsic_buffer_store(const float2_t& src, // fill in byte 3 dst_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - __llvm_amdgcn_buffer_storex2(src, - dst_block_config.data, - 0, - 
dst_thread_addr_offset + dst_const_addr_offset, - false, - false); -#else - asm volatile("\n \ - buffer_store_dwordx2 %1, %2, %0, %3 offen offset:0 \n \ - " - : - : "s"(dst_block_config.data), - "v"(src), - "v"(dst_thread_addr_offset), - "s"(dst_const_addr_offset)); -#endif -} - -template <> -__device__ void amd_intrinsic_buffer_store(const float4_t& src, - float* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset) -{ index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); - BufferLoadStoreDwordConfig dst_block_config; + __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src), + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +} + +template <> +__device__ void amd_buffer_store(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; // fill in byte 0 - 1 dst_block_config.address[0] = p_dst_block; @@ -284,35 +515,24 @@ __device__ void amd_intrinsic_buffer_store(const float4_t& src, // fill in byte 3 dst_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - __llvm_amdgcn_buffer_storex4(src, - dst_block_config.data, - 0, - dst_thread_addr_offset + dst_const_addr_offset, - false, - false); -#else - asm volatile("\n \ - buffer_store_dwordx4 %1, %2, %0, %3 offen offset:0 \n \ - " - : - : "s"(dst_block_config.data), - "v"(src), - "v"(dst_thread_addr_offset), - "s"(dst_const_addr_offset)); -#endif -} - -template <> -__device__ void amd_intrinsic_buffer_atomic_add(const float& src, - float* p_dst_block, - index_t dst_thread_data_offset, - index_t dst_const_data_offset) -{ index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); - BufferLoadStoreDwordConfig dst_block_config; + __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src), + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +} + +template <> +__device__ void amd_buffer_store(const half_t* p_src, + half_t* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; // fill in byte 0 - 1 dst_block_config.address[0] = p_dst_block; @@ -321,13 +541,246 @@ __device__ void amd_intrinsic_buffer_atomic_add(const float& src, // fill in byte 3 dst_block_config.range[3] = 0x00027000; -#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC - __llvm_amdgcn_buffer_atomic_add( - src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false); +#if !CK_WORKAROUND_SWDEV_231101 + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t); + + __llvm_amdgcn_buffer_store_f16(*p_src, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); #else - static_assert(false, " wrong! 
not implemented"); + p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src; #endif } +template <> +__device__ void amd_buffer_store(const half_t* p_src, + half_t* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t); + +#if !CK_WORKAROUND_SWDEV_231101 + __llvm_amdgcn_buffer_store_f16x2(*reinterpret_cast(p_src), + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#else + const float* p_src_tmp = reinterpret_cast(p_src); + + __llvm_amdgcn_buffer_store_f32(*p_src_tmp, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#endif +} + +template <> +__device__ void amd_buffer_store(const half_t* p_src, + half_t* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(half_t); + + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + +#if !CK_WORKAROUND_SWDEV_231101 + __llvm_amdgcn_buffer_store_f16x4(*reinterpret_cast(p_src), + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#else + const float2_t* p_src_tmp = reinterpret_cast(p_src); + + __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#endif +} + +template <> +__device__ void amd_buffer_store(const ushort* p_src, + ushort* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + +#if !CK_WORKAROUND_SWDEV_231101 + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort); + + __llvm_amdgcn_buffer_store_bf16(*p_src, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#else + p_dst_block[dst_thread_data_offset + dst_const_data_offset] = *p_src; +#endif +} + +template <> +__device__ void amd_buffer_store(const ushort* p_src, + ushort* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort); + +#if !CK_WORKAROUND_SWDEV_231101 + __llvm_amdgcn_buffer_store_bf16x2(*p_src, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#else + const 
float* p_src_tmp = reinterpret_cast(p_src); + + __llvm_amdgcn_buffer_store_f32(*p_src_tmp, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#endif +} + +template <> +__device__ void amd_buffer_store(const ushort* p_src, + ushort* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(ushort); + +#if !CK_WORKAROUND_SWDEV_231101 + __llvm_amdgcn_buffer_store_bf16x4(*p_src, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#else + const float2_t* p_src_tmp = reinterpret_cast(p_src); + + __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, + dst_block_config.data, + 0, + dst_thread_addr_offset + dst_const_addr_offset, + false, + false); +#endif +} + +template <> +__device__ void amd_buffer_atomic_add(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + BufferAddressConfig dst_block_config; + + // fill in byte 0 - 1 + dst_block_config.address[0] = p_dst_block; + // fill in byte 2 + dst_block_config.range[2] = -1; + // fill in byte 3 + dst_block_config.range[3] = 0x00027000; + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); + index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float); + + __llvm_amdgcn_buffer_atomic_add_f32( + *p_src, dst_block_config.data, 0, dst_thread_addr_offset + dst_const_addr_offset, false); +} + +template <> +__device__ void amd_buffer_atomic_add(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + for(index_t i = 0; i < 2; ++i) + { + amd_buffer_atomic_add( + &p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i); + } +} + +template <> +__device__ void amd_buffer_atomic_add(const float* p_src, + float* p_dst_block, + index_t dst_thread_data_offset, + index_t dst_const_data_offset) +{ + for(index_t i = 0; i < 4; ++i) + { + amd_buffer_atomic_add( + &p_src[i], p_dst_block, dst_thread_data_offset, dst_const_data_offset + i); + } +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 45750bbed0..27098cb3e8 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -16,15 +16,12 @@ #include "functional3.hpp" #include "functional4.hpp" #include "in_memory_operation.hpp" +#include "synchronization.hpp" #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" #endif -#if CK_USE_AMD_BUFFER_ADDRESSING -#include "amd_buffer_addressing.hpp" -#endif - #if CK_USE_AMD_XDLOPS #include "amd_xdlops.hpp" #endif diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 7ff99e0af6..89a8fd5f60 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -25,11 +25,7 @@ #define CK_USE_AMD_BUFFER_ADDRESSING 1 #endif -#ifndef CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC -#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1 -#endif - -// only support gfx908 
+// only gfx908 supports native floating point atomic add #ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD #define CK_USE_AMD_BUFFER_ATOMIC_ADD 0 #endif @@ -47,6 +43,11 @@ #define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes #endif +// block synchronization uses only s_waitcnt lgkmcnt(0), not vmcnt(0) +#ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM +#define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 +#endif + // experimental implementation #define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1 #define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0 @@ -54,8 +55,24 @@ #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0 #define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0 + +#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK #define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0 +#endif + +#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK #define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0 +#endif + +// workaround: put all workarounds here +// workaround for unnecessary VGPR <--> AGPR data movement when using the mfma LLVM intrinsic +#ifndef CK_WORKAROUND_SWDEV_229564 +#define CK_WORKAROUND_SWDEV_229564 1 +#endif +// workaround for buffer load/store fp16/bfp16 intrinsic bug +#ifndef CK_WORKAROUND_SWDEV_231101 +#define CK_WORKAROUND_SWDEV_231101 1 +#endif namespace ck { diff --git a/composable_kernel/include/utility/config.nvidia.hpp.in b/composable_kernel/include/utility/config.nvidia.hpp.in index 08757e0a02..2c26d4d624 100644 --- a/composable_kernel/include/utility/config.nvidia.hpp.in +++ b/composable_kernel/include/utility/config.nvidia.hpp.in @@ -1,10 +1,9 @@ #ifndef CK_CONFIG_NVIDIA_HPP #define CK_CONFIG_NVIDIA_HPP -#include "cuda_runtime.h" -#include "cuda_fp16.h" -#include "nvToolsExt.h" -#include "helper_cuda.h" +#include +#include +#include // index type: unsigned or signed #define CK_UNSIGNED_INDEX_TYPE 0 @@ -19,6 +18,7 @@ #define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0 #define CK_USE_AMD_XDLOPS 0 #define CK_USE_AMD_XDLOPS_INLINE_ASM 0 +#define CK_USE_AMD_XDLOPS_EMULATE 0 // experimental implementation #define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0 @@ -32,16 +32,16 @@ namespace ck { enum AddressSpace { - generic, - global, - lds, - vgpr + Generic, + Global, + Lds, + Vgpr }; enum InMemoryDataOperation { - none, - atomic_add + Set, + AtomicAdd }; #if CK_UNSIGNED_INDEX_TYPE diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in index fd9c0029bc..058bfcca02 100644 --- a/composable_kernel/include/utility/float_type.amd.hpp.in +++ b/composable_kernel/include/utility/float_type.amd.hpp.in @@ -11,12 +11,15 @@ typedef float float16_t __attribute__((ext_vector_type(16))); typedef float float32_t __attribute__((ext_vector_type(32))); // float16 +typedef _Float16 half_t; typedef _Float16 half2_t __attribute__((ext_vector_type(2))); typedef _Float16 half4_t __attribute__((ext_vector_type(4))); +typedef _Float16 half8_t __attribute__((ext_vector_type(8))); // bfloat16 typedef ushort ushort2_t __attribute__((ext_vector_type(2))); typedef ushort ushort4_t __attribute__((ext_vector_type(4))); +typedef ushort ushort8_t __attribute__((ext_vector_type(8))); template struct vector_type @@ -83,37 +86,37 @@ struct vector_type };
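The vector_type specializations above and below all follow the same union-based pattern: a Clang ext_vector_type alias is the packed in-memory type, and a union overlays a scalar array on the same storage so individual lanes can be written before the packed value is consumed. A minimal host-compilable sketch of that pattern, with float lanes standing in for _Float16/ushort and the __host__ __device__ qualifiers dropped (illustrative names, not the library's exact code):

// builds with clang; ext_vector_type is a Clang extension
typedef float float4_t __attribute__((ext_vector_type(4)));

template <typename T, int N>
struct vector_type;

template <>
struct vector_type<float, 4>
{
    using MemoryType = float4_t;

    union DataType
    {
        MemoryType vector; // the packed 4-lane value
        float scalar[4];   // scalar view of the same bytes
    };

    // assemble a packed vector lane by lane, as the Pack() members above do
    static MemoryType Pack(float s0, float s1, float s2, float s3)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        data.scalar[2] = s2;
        data.scalar[3] = s3;
        return data.vector;
    }
};

int main()
{
    float4_t v = vector_type<float, 4>::Pack(0.f, 1.f, 2.f, 3.f);
    return static_cast<int>(v.w); // lanes are addressable with .x/.y/.z/.w
}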
template <> -struct vector_type +struct vector_type { - using MemoryType = half; + using MemoryType = half_t; template - __host__ __device__ static void SetScalar(MemoryType& v, half s, Number) + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) { static_assert(I < 1, "wrong"); - *(reinterpret_cast(&v) + I) = s; + *(reinterpret_cast(&v) + I) = s; } }; template <> -struct vector_type +struct vector_type { using MemoryType = half2_t; union DataType { MemoryType vector; - half scalar[2]; + half_t scalar[2]; }; template - __host__ __device__ static void SetScalar(MemoryType& v, half s, Number) + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) { static_assert(I < 2, "wrong"); - *(reinterpret_cast(&v) + I) = s; + *(reinterpret_cast(&v) + I) = s; } - __host__ __device__ static MemoryType Pack(half s0, half s1) + __host__ __device__ static MemoryType Pack(half_t s0, half_t s1) { DataType data; data.scalar[0] = s0; @@ -123,24 +126,24 @@ struct vector_type }; template <> -struct vector_type +struct vector_type { using MemoryType = half4_t; union DataType { MemoryType vector; - half scalar[4]; + half_t scalar[4]; }; template - __host__ __device__ static void SetScalar(MemoryType& v, half s, Number) + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) { static_assert(I < 4, "wrong"); - *(reinterpret_cast(&v) + I) = s; + *(reinterpret_cast(&v) + I) = s; } - __host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3) + __host__ __device__ static MemoryType Pack(half_t s0, half_t s1, half_t s2, half_t s3) { DataType data; data.scalar[0] = s0; @@ -151,6 +154,25 @@ struct vector_type } }; +template <> +struct vector_type +{ + using MemoryType = half8_t; + + union DataType + { + MemoryType vector; + half_t scalar[8]; + }; + + template + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) + { + static_assert(I < 8, "wrong"); + *(reinterpret_cast(&v) + I) = s; + } +}; + template <> struct vector_type { @@ -220,6 +242,25 @@ struct vector_type } }; +template <> +struct vector_type +{ + using MemoryType = ushort8_t; + + union DataType + { + MemoryType vector; + ushort scalar[8]; + }; + + template + __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number) + { + static_assert(I < 8, "wrong"); + *(reinterpret_cast(&v) + I) = s; + } +}; + // data type conversion template struct type_convert @@ -250,12 +291,40 @@ struct inner_product_with_conversion { static constexpr auto convert = type_convert(); + __device__ T operator()(float4_t a, float4_t b) const + { + const float* p_a_float = reinterpret_cast(&a); + const float* p_b_float = reinterpret_cast(&b); + + T acc = 0; + for(index_t v = 0; v < 4; ++v) + { + acc += convert(p_a_float[v]) * convert(p_b_float[v]); + } + + return acc; + } + + __device__ T operator()(float2_t a, float2_t b) const + { + const float* p_a_float = reinterpret_cast(&a); + const float* p_b_float = reinterpret_cast(&b); + + T acc = 0; + for(index_t v = 0; v < 2; ++v) + { + acc += convert(p_a_float[v]) * convert(p_b_float[v]); + } + + return acc; + } + __device__ T operator()(float a, float b) const { return convert(a) * convert(b); } __device__ T operator()(half2_t a, half2_t b) const { - const half* p_a_half = reinterpret_cast(&a); - const half* p_b_half = reinterpret_cast(&b); + const half_t* p_a_half = reinterpret_cast(&a); + const half_t* p_b_half = reinterpret_cast(&b); T acc = 0; for(index_t v = 0; v < 2; ++v) @@ -268,8 +337,8 @@ struct 
inner_product_with_conversion __device__ T operator()(half4_t a, half4_t b) const { - const half* p_a_half = reinterpret_cast(&a); - const half* p_b_half = reinterpret_cast(&b); + const half_t* p_a_half = reinterpret_cast(&a); + const half_t* p_b_half = reinterpret_cast(&b); T acc = 0; for(index_t v = 0; v < 4; ++v) @@ -279,6 +348,19 @@ struct inner_product_with_conversion return acc; } + __device__ T operator()(half8_t a, half8_t b) const + { + const half_t* p_a_half = reinterpret_cast(&a); + const half_t* p_b_half = reinterpret_cast(&b); + + T acc = 0; + for(index_t v = 0; v < 8; ++v) + { + acc += convert(p_a_half[v]) * convert(p_b_half[v]); + } + return acc; + } + __device__ T operator()(ushort2_t a, ushort2_t b) const { const ushort* p_a_bfloat16 = reinterpret_cast(&a); @@ -305,6 +387,19 @@ struct inner_product_with_conversion } return acc; } + + __device__ T operator()(ushort8_t a, ushort8_t b) const + { + const ushort* p_a_bfloat16 = reinterpret_cast(&a); + const ushort* p_b_bfloat16 = reinterpret_cast(&b); + + T acc = 0; + for(index_t v = 0; v < 8; ++v) + { + acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]); + } + return acc; + } }; } // namespace ck diff --git a/composable_kernel/include/utility/float_type.nvidia.hpp.in b/composable_kernel/include/utility/float_type.nvidia.hpp.in index 8be8c704a0..f4a0a47c67 100644 --- a/composable_kernel/include/utility/float_type.nvidia.hpp.in +++ b/composable_kernel/include/utility/float_type.nvidia.hpp.in @@ -13,8 +13,18 @@ namespace ck { using float2_t = float2; using float4_t = float4; -// float16 +// float +typedef float float32_t __attribute__((ext_vector_type(32))); + +// bfloat16 +typedef ushort ushort2_t __attribute__((ext_vector_type(2))); +typedef ushort ushort4_t __attribute__((ext_vector_type(4))); +typedef ushort ushort8_t __attribute__((ext_vector_type(8))); + +// fp16 +using half_t = half; using half2_t = half2; +using half4_t = float2; template struct vector_type @@ -81,37 +91,37 @@ struct vector_type }; template <> -struct vector_type +struct vector_type { - using MemoryType = half; + using MemoryType = half_t; template - __host__ __device__ static void SetScalar(MemoryType& v, half s, Number) + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) { static_assert(I < 1, "wrong"); - *(reinterpret_cast(&v) + I) = s; + *(reinterpret_cast(&v) + I) = s; } }; template <> -struct vector_type +struct vector_type { using MemoryType = half2_t; union DataType { MemoryType vector; - half scalar[2]; + half_t scalar[2]; }; template - __host__ __device__ static void SetScalar(MemoryType& v, half s, Number) + __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) { static_assert(I < 2, "wrong"); - *(reinterpret_cast(&v) + I) = s; + *(reinterpret_cast(&v) + I) = s; } - __host__ __device__ static MemoryType Pack(half s0, half s1) + __host__ __device__ static MemoryType Pack(half_t s0, half_t s1) { DataType data; data.scalar[0] = s0; @@ -140,8 +150,8 @@ struct inner_product_with_conversion __device__ T operator()(half2_t a, half2_t b) const { - const half* p_a_half = reinterpret_cast(&a); - const half* p_b_half = reinterpret_cast(&b); + const half_t* p_a_half = reinterpret_cast(&a); + const half_t* p_b_half = reinterpret_cast(&b); T acc = 0; for(index_t v = 0; v < 2; ++v) @@ -151,6 +161,19 @@ struct inner_product_with_conversion return acc; } + + __device__ T operator()(half4_t a, half4_t b) const + { + const half_t* p_a_half = reinterpret_cast(&a); + const half_t* p_b_half = 
reinterpret_cast(&b); + + T acc = 0; + for(index_t v = 0; v < 4; ++v) + { + acc += convert(p_a_half[v]) * convert(p_b_half[v]); + } + return acc; + } }; } // namespace ck diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 2ba30a183b..4f99531044 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -2,91 +2,159 @@ #define CK_IN_MEMORY_OPERATION_AMD_HPP #include "float_type.hpp" + +#if CK_USE_AMD_BUFFER_ADDRESSING #include "amd_buffer_addressing.hpp" +#endif namespace ck { -template -__device__ void set_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +template +__device__ void atomic_add_impl(T* p_dst, T src) +{ + atomicAdd(p_dst, src); +} + +// atomicAdd for float does not support vector type +template <> +__device__ void atomic_add_impl(float2_t* p_dst, float2_t src) +{ + float* p_dst_float = reinterpret_cast(p_dst); + const float* p_src_float = reinterpret_cast(&src); + + for(index_t i = 0; i < 2; ++i) + { + atomicAdd(&(p_dst_float[i]), p_src_float[i]); + } +} + +template <> +__device__ void atomic_add_impl(float4_t* p_dst, float4_t src) +{ + float* p_dst_float = reinterpret_cast(p_dst); + const float* p_src_float = reinterpret_cast(&src); + + for(index_t i = 0; i < 4; ++i) + { + atomicAdd(&(p_dst_float[i]), p_src_float[i]); + } +} + +template +struct SetData { using vector_t = typename vector_type::MemoryType; + // This version is for compatibility only; don't use it if possible + template + __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const + { + *reinterpret_cast(&p_dst[dst_offset]) = + *reinterpret_cast(&p_src[src_offset]); + } + #if CK_USE_AMD_BUFFER_ADDRESSING - // TODO: use static_if::ElseIf, instead of nested static_if - static_if{}([&](auto) { - // buffer_load requires: - // 1) p_src must be in global memory space, d_dst must be vgpr - // 2) p_src to be a block-invariant pointer. - // It is user's responsibility to make sure that is true. + // buffer_load requires: + // 1) p_src must be in global memory space, p_dst must be vgpr + // 2) p_src to be a block-invariant pointer. + // It is the user's responsibility to make sure that is true. + template <> + __device__ void Run(const T* p_src, + index_t src_offset, + T* p_dst, + index_t dst_offset) const + { *reinterpret_cast(&p_dst[dst_offset]) = - amd_intrinsic_buffer_load(p_src, src_offset, 0); - }).Else([&](auto) { - static_if{}([&](auto) { - // buffer_store requires: - // 1) p_src must be in vgpr space, d_dst must be global memory - // 2) p_dst to be a block-invariant pointer. - // It is user's responsibility to make sure that is true. - amd_intrinsic_buffer_store( - *reinterpret_cast(&p_src[src_offset]), p_dst, dst_offset, 0); - }).Else([&](auto) { - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); - }); - }); -#else - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); -#endif -} + amd_buffer_load(p_src, src_offset, 0); + } -template -__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) + // buffer_store requires: + // 1) p_src must be in vgpr space, p_dst must be global memory + // 2) p_dst to be a block-invariant pointer. + // It is the user's responsibility to make sure that is true.
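The atomic_add_impl overloads above exist because atomicAdd only accepts scalar operands: a packed float2_t or float4_t is decomposed and added lane by lane. Each 32-bit lane update is atomic, but the vector update as a whole is not, so a concurrent reader may observe a partially accumulated vector. A HIP/CUDA-style restatement of the two-lane case (a sketch of the pattern, not the library's exact code):

// compile as HIP or CUDA device code; atomicAdd(float*, float) is a built-in
typedef float float2_t __attribute__((ext_vector_type(2)));

__device__ void atomic_add_float2(float2_t* p_dst, float2_t src)
{
    // view the packed values as two plain floats
    float* p_dst_float       = reinterpret_cast<float*>(p_dst);
    const float* p_src_float = reinterpret_cast<const float*>(&src);

    for(int i = 0; i < 2; ++i)
    {
        // one hardware atomic per lane
        atomicAdd(&p_dst_float[i], p_src_float[i]);
    }
}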
+ template <> + __device__ void Run(const T* p_src, + index_t src_offset, + T* p_dst, + index_t dst_offset) const + { + amd_buffer_store(&(p_src[src_offset]), p_dst, dst_offset, 0); + } +#endif +}; + +template +struct AtomicAddData { using vector_t = typename vector_type::MemoryType; - static_if{}([&](auto) { -#if CK_USE_AMD_BUFFER_ATOMIC_ADD - amd_intrinsic_buffer_atomic_add( - *reinterpret_cast(&p_src[src_offset]), p_dst, dst_offset, 0); -#else - atomicAdd(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); + // This version is for compatibility only; don't use it if possible + template + __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const + { + atomic_add_impl(reinterpret_cast(&p_dst[dst_offset]), + *reinterpret_cast(&p_src[src_offset])); + } + +#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_ADD + // buffer_atomic_add requires: + // 1) p_src must be in vgpr space, p_dst must be global memory + // 2) p_dst to be a block-invariant pointer. + // It is the user's responsibility to make sure that is true. + template <> + __device__ void Run(const T* p_src, + index_t src_offset, + T* p_dst, + index_t dst_offset) const + { + amd_buffer_atomic_add(&(p_src[src_offset]), p_dst, dst_offset, 0); + } #endif - }).Else([&](auto fwd) { - static_assert(fwd(false), "atomic_add doesn't support this memory space"); - }); -} +}; template + InMemoryDataOperation DstInMemOp, + index_t SrcDataStride = 1, + index_t DstDataStride = 1> __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { static_assert(DstInMemOp == InMemoryDataOperation::Set || DstInMemOp == InMemoryDataOperation::AtomicAdd, "wrong! InMemoryDataOperation not supported!"); - // TODO: use static_if::ElseIf - static_if{}([&](auto) { - set_data( - p_src, src_offset, p_dst, dst_offset); - }); + // keep it simple, don't use static_if here, otherwise compiler will do weird things + if(SrcDataStride == 1 && DstDataStride == 1) + { + // TODO: use static_if::ElseIf + static_if{}([&](auto) { + SetData{}.template Run( + p_src, src_offset, p_dst, dst_offset); + }); - static_if{}([&](auto) { - atomic_add_data( - p_src, src_offset, p_dst, dst_offset); - }); + static_if{}([&](auto) { + AtomicAddData{}.template Run( + p_src, src_offset, p_dst, dst_offset); + }); + } + else + { + for(index_t i = 0; i < DataPerAccess; i++) + { + // TODO: use static_if::ElseIf + static_if{}([&](auto) { + SetData{}.template Run( + p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); + }); + + static_if{}([&](auto) { + AtomicAddData{}.template Run( + p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); + }); + } + } } } // namespace ck diff --git a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in index 4061aff125..0e2c7e9603 100644 --- a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in @@ -3,56 +3,106 @@ namespace ck { -template -__device__ void copy_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +template +__device__ void atomic_add_impl(T* p_dst, T src) +{ + atomicAdd(p_dst, src); +} + +// atomicAdd for float does not support vector type +template <> +__device__ void atomic_add_impl(float2_t* p_dst, float2_t src) +{ + float* p_dst_float = reinterpret_cast(p_dst); + const float* p_src_float =
reinterpret_cast(&src); + + for(index_t i = 0; i < 2; ++i) + { + atomicAdd(&(p_dst_float[i]), p_src_float[i]); + } +} + +template <> +__device__ void atomic_add_impl(float4_t* p_dst, float4_t src) +{ + float* p_dst_float = reinterpret_cast(p_dst); + const float* p_src_float = reinterpret_cast(&src); + + for(index_t i = 0; i < 4; ++i) + { + atomicAdd(&(p_dst_float[i]), p_src_float[i]); + } +} + +template +struct SetData { using vector_t = typename vector_type::MemoryType; - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); -} + template + __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const + { + *reinterpret_cast(&p_dst[dst_offset]) = + *reinterpret_cast(&p_src[src_offset]); + } +}; -template -__device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) +template +struct AtomicAddData { using vector_t = typename vector_type::MemoryType; - static_if{}([&](auto) { - atomicAdd(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); - }).Else([&](auto fwd) { - static_assert(fwd(false), "atomic_add doesn't support this memory space"); - }); -} + template + __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const + { + atomic_add_impl(reinterpret_cast(&p_dst[dst_offset]), + *reinterpret_cast(&p_src[src_offset])); + } +}; template + InMemoryDataOperation DstInMemOp, + index_t SrcDataStride = 1, + index_t DstDataStride = 1> __device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) { static_assert(DstInMemOp == InMemoryDataOperation::Set || DstInMemOp == InMemoryDataOperation::AtomicAdd, "wrong! InMemoryDataOperation not supported!"); - // TODO: use static_if::ElseIf - static_if{}([&](auto) { - copy_data( - p_src, src_offset, p_dst, dst_offset); - }); + // keep it simple, don't use static_if here, otherwise compiler will do weird things + if(SrcDataStride == 1 && DstDataStride == 1) + { + // TODO: use static_if::ElseIf + static_if{}([&](auto) { + SetData{}.template Run( + p_src, src_offset, p_dst, dst_offset); + }); - static_if{}([&](auto) { - atomic_add_data( - p_src, src_offset, p_dst, dst_offset); - }); + static_if{}([&](auto) { + AtomicAddData{}.template Run( + p_src, src_offset, p_dst, dst_offset); + }); + } + else + { + for(index_t i = 0; i < DataPerAccess; i++) + { + // TODO: use static_if::ElseIf + static_if{}([&](auto) { + SetData{}.template Run( + p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); + }); + + static_if{}([&](auto) { + AtomicAddData{}.template Run( + p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); + }); + } + } } } // namespace ck diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 4c9cd85d5e..7ce60d74cf 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -3,6 +3,7 @@ #include "config.hpp" #include "integral_constant.hpp" +#include "number.hpp" #include "type.hpp" namespace ck { diff --git a/composable_kernel/include/utility/synchronization.amd.hpp.in b/composable_kernel/include/utility/synchronization.amd.hpp.in new file mode 100644 index 0000000000..4e899baa95 --- /dev/null +++ b/composable_kernel/include/utility/synchronization.amd.hpp.in @@ -0,0 +1,25 @@ +#ifndef CK_SYNCHRONIZATION_AMD_HPP +#define CK_SYNCHRONIZATION_AMD_HPP + +#include "config.hpp" + +namespace ck { + +__device__ void 
__llvm_amdgcn_s_barrier() __asm("llvm.amdgcn.s.barrier"); + +__device__ void block_sync_lds() +{ +#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM + asm volatile("\ + s_waitcnt lgkmcnt(0) \n \ + s_barrier \ + " ::); +#else + __llvm_amdgcn_s_barrier(); +#endif +} + +__device__ void block_sync_lds_vmem() { __llvm_amdgcn_s_barrier(); } + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/synchronization.nvidia.hpp.in b/composable_kernel/include/utility/synchronization.nvidia.hpp.in new file mode 100644 index 0000000000..030b86e12d --- /dev/null +++ b/composable_kernel/include/utility/synchronization.nvidia.hpp.in @@ -0,0 +1,13 @@ +#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP +#define CK_SYNCHRONIZATION_NVIDIA_HPP + +#include "config.hpp" + +namespace ck { + +__device__ void block_sync_lds() { __syncthreads(); } + +__device__ void block_sync_lds_vmem() { __syncthreads(); } + +} // namespace ck +#endif diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index d64d0adbc9..a986b14e1d 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -1,5 +1,5 @@ set(TENSOR_SOURCE - src/tensor.cpp; + src/host_tensor.cpp; src/device.cpp; ) @@ -25,8 +25,6 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA") endif() add_executable(conv_driver ${CONV_SOURCE}) -add_executable(col2im_driver ${COL2IM_SOURCE}) add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) target_link_libraries(conv_driver PRIVATE host) -target_link_libraries(col2im_driver PRIVATE host) target_link_libraries(conv_bwd_data_driver PRIVATE host) diff --git a/driver/include/conv_common.hpp b/driver/include/conv_common.hpp index 3213d7de9e..2c09622e5e 100644 --- a/driver/include/conv_common.hpp +++ b/driver/include/conv_common.hpp @@ -1,56 +1,8 @@ #ifndef CONV_COMMON_HPP #define CONV_COMMON_HPP -#include "ConstantTensorDescriptor_deprecated.hpp" #include "tensor_descriptor.hpp" -template -constexpr auto get_convolution_output_default_4d_tensor_descriptor_deprecated( - InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads) -{ - using namespace ck; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4"); - static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4"); - static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1), - "input & weight dimension not consistent"); - - constexpr index_t N = in_desc.GetLength(I0); - constexpr index_t Hi = in_desc.GetLength(I2); - constexpr index_t Wi = in_desc.GetLength(I3); - - constexpr index_t K = wei_desc.GetLength(I0); - constexpr index_t Y = wei_desc.GetLength(I2); - constexpr index_t X = wei_desc.GetLength(I3); - - constexpr index_t HPadLow = LowerPads{}.Get(I0); - constexpr index_t WPadLow = LowerPads{}.Get(I1); - - constexpr index_t HPadUp = UpperPads{}.Get(I0); - constexpr index_t WPadUp = UpperPads{}.Get(I1); - - constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; - constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - - constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1; - constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1; - - return make_ConstantTensorDescriptor_packed(Sequence{}); -} - template -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_col2im_eb_nchw.hpp" - -template -void 
device_col2im_eb_nchw(ColDesc, - const Tensor& col_eb, - ImgDesc, - Tensor& img_nchw, - FilterSizes, - OutputSizes, - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - std::size_t nrepeat) -{ - using namespace ck; - - constexpr auto col_eb_desc = ColDesc{}; - constexpr auto img_nchw_desc = ImgDesc{}; - - constexpr index_t N = img_nchw_desc.GetLengths()[0]; - constexpr index_t C = img_nchw_desc.GetLengths()[1]; - constexpr index_t Hi = img_nchw_desc.GetLengths()[2]; - constexpr index_t Wi = img_nchw_desc.GetLengths()[3]; - - constexpr index_t E = col_eb_desc.GetLengths()[0]; - constexpr index_t B = col_eb_desc.GetLengths()[1]; - - std::size_t data_sz = sizeof(T); - DeviceMem col_eb_device_buf(data_sz * col_eb.mDesc.GetElementSpace()); - DeviceMem img_nchw_device_buf(data_sz * img_nchw.mDesc.GetElementSpace()); - - col_eb_device_buf.ToDevice(col_eb.mData.data()); - img_nchw_device_buf.ToDevice(img_nchw.mData.data()); - -#if 1 - constexpr index_t BlockSize = 256; - - constexpr index_t EPerBlock = 128; - constexpr index_t BPerBlock = 128; - - using BlockCopySubLengths_E_B = Sequence<8, 8>; - using BlockCopyClusterLengths_E_B = Sequence<16, 16>; - using BlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B] - using BlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B] - using BlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B] - - constexpr index_t BlockCopyDataPerAccess_B = 1; -#endif - - constexpr index_t GridSize = - ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - constexpr auto gridwise_col2im = GridwiseCol2Im_eb_nchw{}; - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - gridwise_col2im, - const_cast( - static_cast(col_eb_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(img_nchw_device_buf.GetDeviceBuffer()))); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - img_nchw_device_buf.FromDevice(img_nchw.mData.data()); -} diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 00dcbdc832..7357563eb5 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -1,7 +1,7 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" @@ -49,16 +49,16 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 0 +#if 1 // BlockSize = 256, each thread hold 64 data constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; 
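In the driver hunks that follow, the old pattern of timing every iteration separately with launch_and_time_kernel is replaced by a single KernelTimer that brackets nrepeat back-to-back launch_kernel calls and reports the average, which amortizes per-launch measurement overhead. A host-side analogue using std::chrono (the lambda is a hypothetical stand-in for one kernel launch; a real GPU version must also synchronize the device before stopping the timer, which KernelTimer is assumed to do):

#include <chrono>
#include <cstdio>

template <typename F>
float average_time_ms(F&& launch_once, int nrepeat)
{
    auto t0 = std::chrono::steady_clock::now(); // KernelTimer::Start()

    for(int j = 0; j < nrepeat; ++j)
        launch_once(); // one launch_kernel(...) call in the drivers

    auto t1 = std::chrono::steady_clock::now(); // KernelTimer::End()

    float total_ms = std::chrono::duration<float, std::milli>(t1 - t0).count();
    return total_ms / nrepeat;
}

int main()
{
    volatile double sink = 0.0;

    float ave_time = average_time_ms(
        [&] {
            for(int k = 0; k < 1000000; ++k) // dummy workload
                sink = sink + 1.0;
        },
        100);

    // same unit arithmetic as the drivers: flop / (1e9 * ms) == TFlop/s
    double flops = 2.0e9; // hypothetical flop count of one run
    printf("Average time : %f ms, %f TFlop/s\n", ave_time, flops / (1e9 * ave_time));

    return 0;
}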
@@ -83,6 +83,36 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i // BlockSize = 256, each thread hold 64 data constexpr index_t BlockSize = 256; + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 1 + // BlockSize = 256, each thread hold 64 data + constexpr index_t BlockSize = 256; + constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 16; @@ -119,7 +149,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw< + using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw< GridSize, BlockSize, T, @@ -151,28 +181,38 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + GemmCThreadCopyDstDataPerWrite_GemmN1>; - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; i < 5; ++i) { - float time = launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } in_nchw_device_buf.FromDevice(in_nchw.mData.data()); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp index 89f19725bf..aeeef9ab87 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp @@ -1,7 +1,7 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" @@ -55,25 +55,27 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i constexpr index_t BPerBlock = 32; constexpr index_t EPerBlock = 32; - constexpr index_t KPerBlock = 8; + constexpr index_t KPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - using OutBlockCopySubLengths_K_B_N0 = Sequence<1, 1, 4>; + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>; using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>; constexpr index_t OutBlockCopySrcDataPerRead_B = 1; constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4; - using WeiBlockCopySubLengths_K_E_C0 = Sequence<1, 4, 1>; + using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>; using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>; constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; @@ -82,8 +84,8 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i constexpr index_t InThreadCopyDstDataPerWrite_B = 1; #endif - constexpr index_t C0 = GemmMPerThreadSubC; - constexpr index_t N0 = GemmNPerThreadSubC; + constexpr index_t C0 = GemmMPerThread; + constexpr index_t N0 = GemmNPerThread; constexpr index_t C1 = C / C0; constexpr index_t N1 = N / N0; @@ -96,7 +98,7 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - constexpr auto gridwise_conv = + using 
gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer< GridSize, BlockSize, @@ -112,13 +114,13 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i EPerBlock, BPerBlock, KPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, GemmMLevel0Cluster, GemmNLevel0Cluster, GemmMLevel1Cluster, GemmNLevel1Cluster, - GemmKPerThreadLoop, GemmDataPerReadA, GemmDataPerReadB, OutBlockCopySubLengths_K_B_N0, @@ -129,28 +131,38 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i WeiBlockCopyClusterLengths_K_E_C0, WeiBlockCopySrcDataPerRead_E, WeiBlockCopyDstDataPerWrite_C0, - InThreadCopyDstDataPerWrite_B>{}; + InThreadCopyDstDataPerWrite_B>; - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; i < 5; ++i) { - float time = launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + std::cout << "Start running " << nrepeat << " times..." << std::endl; - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } in_nchw_device_buf.FromDevice(in_nchw.mData.data()); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp index 622062018d..92ad30c568 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp @@ -1,7 +1,7 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp" @@ -185,7 +185,7 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw< + using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw< GridSize, BlockSize, T, @@ -217,28 +217,38 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + GemmCThreadCopyDstDataPerWrite_GemmN1>; - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; 
i < 5; ++i) { - float time = launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + std::cout << "Start running " << nrepeat << " times..." << std::endl; - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } in_nchw_device_buf.FromDevice(in_nchw.mData.data()); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp index 2fec94b08b..ba68390326 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp @@ -1,7 +1,7 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp" @@ -124,7 +124,7 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw< + using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw< GridSize, BlockSize, T, @@ -156,28 +156,38 @@ void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc i GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, GemmBBlockCopySrcDataPerRead_GemmN, GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; + GemmCThreadCopyDstDataPerWrite_GemmN1>; - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; i < 5; ++i) { - float time = launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } in_nchw_device_buf.FromDevice(in_nchw.mData.data()); diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 8ae1c72527..e870990a72 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -1,19 +1,14 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" +#include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" namespace launcher { using namespace ck; -template -__global__ void run_gridwise_convolution_backward_data_v4r1(Xs... xs) -{ - GridwiseOp::template Run(xs...); -} - template ; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #endif @@ -157,78 +122,82 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; i < 5; ++i) { - using GridwiseConvBwdData = GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; KernelTimer timer; timer.Start(); - static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id_) { - constexpr index_t gemm_id = decltype(gemm_id_){}; + for(index_t i = 0; i < nrepeat; ++i) + { + using GridwiseConvBwdData = + GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< + GridSize, + BlockSize, + T, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmThreadGemmDataPerReadM, + GemmThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmM, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmN, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmN1>; - constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); - constexpr index_t gemm_k = gemm_sizes.At(2); - constexpr bool is_gemm_not_empty = gemm_k > 0; + static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) { + constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); + constexpr index_t gemm_k = gemm_sizes.At(2); + constexpr bool is_gemm_not_empty = gemm_k > 0; - // only compile and run if GEMM is no empty - static_if{}([&](auto fwd) { - launch_kernel( - run_gridwise_convolution_backward_data_v4r1, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + // only compile and run if GEMM is no empty + static_if{}([&](auto fwd) { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer()), + fwd(gemm_id)); + }); }); - }); + } timer.End(); - float time = timer.GetElapsedTime(); - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } in_nchw_device_buf.FromDevice(in_nchw.mData.data()); diff --git a/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 5840947a45..0000000000 --- a/driver/include/device_convolution_direct_v2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,98 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_direct_v2_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -template -void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc, - const Tensor& in, - WeiDesc, - const Tensor& wei, - OutDesc, - Tensor& out, - index_t nrepeat) -{ - 
std::size_t data_sz = sizeof(T); - DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); - DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace()); - DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace()); - - int num_thread = std::thread::hardware_concurrency(); - - in_device_buf.ToDevice(in.mData.data()); - wei_device_buf.ToDevice(wei.mData.data()); - out_device_buf.ToDevice(out.mData.data()); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - -#if 1 - // 3x3, 34x34, 128 thread - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 1; - constexpr index_t WeiBlockCopyDataPerRead = 1; - - constexpr index_t BlockSize = 128; -#endif - - constexpr index_t GridSize = - (out_desc.GetLength(I0) / NPerBlock) * (out_desc.GetLength(I1) / KPerBlock) * - (out_desc.GetLength(I2) / HoPerBlock) * (out_desc.GetLength(I3) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - using gridwise_conv = GridwiseConvolutionDirect_v2_nchw_kcyx_nkhw; - float time = launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_device_buf.FromDevice(out.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp deleted file mode 100644 index 39a05db992..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,486 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hpp" -#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hpp" -#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hpp" -#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = 
wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - // reorder weight - auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); - - Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); - - auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { - wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - // reorder input - auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { - in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // output - auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); - DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - -#if 0 - // for 3x3, 34x34, v1r1, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; -#elif 1 - // for 3x3, 34x34, v1r3, Pascal - // for 3x3, 28x28, v1r3, Pascal - // for 3x3, 14x14, v1r3, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - 
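// [editor's note] A quick audit for these deleted tuning blocks: the block tile must be
// exactly covered by the per-thread tiles. For the v1r3 Pascal block above,
// NPerBlock * KPerBlock * HoPerBlock * WoPerBlock = 16 * 128 * 2 * 2 = 8192 outputs per
// workgroup, i.e. 8192 / BlockSize(128) = 64 per thread, matching
// NPerThread * KPerThread * HoPerThread * WoPerThread = 4 * 8 * 1 * 2 = 64.
// A hedged compile-time guard (not in the source) would read:
static_assert(NPerBlock * KPerBlock * HoPerBlock * WoPerBlock ==
                  BlockSize * NPerThread * KPerThread * HoPerThread * WoPerThread,
              "block tile must be exactly covered by per-thread tiles");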
constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - using WeiBlockCopySubLengths_CK = Sequence<2, 4>; - using WeiBlockCopyClusterLengths_CK = Sequence<4, 32>; - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; -#elif 0 - // for 3x3, 34x34, v1r1, Vega 20 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>; - constexpr index_t InBlockCopyDataPerAccess_N = 2; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 2; - - constexpr index_t OutThreadCopyDataPerAccess_N = 4; -#elif 1 - // for 3x3, 34x34, v1r3, Vega 20 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - using WeiBlockCopySubLengths_CK = Sequence<1, 4>; - using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>; - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 4; -#elif 0 - // for 3x3, 56x56, v1r1, Pascal - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t 
GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 56x56, v1r2, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 1; - constexpr index_t GemmDataPerReadB = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - constexpr index_t OutThreadCopyDataPerAccess_N = 4; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 28x28, v1r1, Pacal - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopy_ThreadPerDimC = 1; - constexpr index_t InBlockCopy_ThreadPerDimH = 4; - constexpr index_t InBlockCopy_ThreadPerDimW = 4; - constexpr index_t InBlockCopy_ThreadPerDimN = 8; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 3x3, 28x28, v1r2, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopyClusterLengths_CHWN = Sequence<4, 2, 4, 4>; - constexpr index_t InBlockCopyDataPerAccess_N = 
4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; -#elif 0 - // for 1x1, 28x28, v1r1, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 16; - constexpr index_t CPerThread = 1; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t OutThreadCopyDataPerAccess_N = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // for 1x1, 14x14, v1r1, Pascal - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 8; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 1; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t InBlockCopy_ThreadPerDimC = 8; - constexpr index_t InBlockCopy_ThreadPerDimH = 2; - constexpr index_t InBlockCopy_ThreadPerDimW = 2; - constexpr index_t InBlockCopy_ThreadPerDimN = 4; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - constexpr index_t OutThreadCopyDataPerAccess_N = 2; - - constexpr index_t BlockSize = 128; -#endif - - constexpr index_t GridSize = - (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - constexpr auto gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn -#elif 0 - GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn -#elif 0 - GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn -#elif 1 - GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer -#endif - {}; - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // reorder output - auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { - out_nkhw(n, k, 
ho, wo) = out_khwn(k, ho, wo, n); - }; - - make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( - std::thread::hardware_concurrency()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp b/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp deleted file mode 100644 index 34a10e2d46..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp +++ /dev/null @@ -1,189 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn_padded.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - LeftPads, - RightPads, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - // reorder weight - auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); - - Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); - - auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { - wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - // reorder input - auto in_chwn_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - auto f_reorder_nchw2chwn = [&](auto n, auto c, auto hi, auto wi) { - in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); - }; - - make_ParallelTensorFunctor(f_reorder_nchw2chwn, N, C, Hi, Wi)( - std::thread::hardware_concurrency()); - - // output - auto out_khwn_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * in_chwn.mDesc.GetElementSpace()); - DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - -#if 1 - // v1r3, 3x3, 32x32, 1x1 pad - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 
8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_CHWN = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 8>; - constexpr index_t InBlockCopyDataPerAccess_N = 4; - - using WeiBlockCopySubLengths_CK = Sequence<1, 4>; - using WeiBlockCopyClusterLengths_CK = Sequence<8, 32>; - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; - - constexpr index_t OutThreadCopyDataPerAccess_N = 4; -#endif - -#if 1 // debug - constexpr index_t GridSize = - (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock); -#else - constexpr index_t GridSize = 1; -#endif - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_padded{}; - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // reorder output - auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) { - out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); - }; - - make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)( - std::thread::hardware_concurrency()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp deleted file mode 100644 index 3b192c9a86..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp +++ /dev/null @@ -1,374 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw.hpp" -#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); 
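// [editor's note] Dimension naming used throughout this diff: N batch, C input channels,
// K output channels, Y/X filter height/width, Hi/Wi and Ho/Wo input/output spatial
// sizes. For the padded chwn driver deleted just above, which took LeftPads/RightPads,
// Ho/Wo follow from Hi/Wi by the usual convolution arithmetic; a hedged helper (not in
// the source; strides and dilations generalized):
constexpr index_t conv_out_len(index_t in_len, index_t filter_len, index_t left_pad,
                               index_t right_pad, index_t stride, index_t dilation)
{
    // e.g. Ho = conv_out_len(Hi, Y, left_pad_h, right_pad_h, stride_h, dilation_h)
    return (in_len + left_pad + right_pad - dilation * (filter_len - 1) - 1) / stride + 1;
}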
- constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - // reorder weight - auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); - - Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); - - auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { - wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 0 - // for 3x3, 34x34, v1r3, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 16; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 1; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 2; -#elif 0 - // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 32 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 1; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 1; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<1, 2, 2, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 4, 2, 32>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 1; - - using WeiBlockCopyClusterLengths = 
void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 4; -#elif 1 - // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 16 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 16; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 4; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 2, 16>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 2; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 2; -#elif 0 - // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 4; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 8; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 4, 8>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 4; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 1; -#elif 0 - // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 4 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 8; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 4; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA 
= 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<2, 8, 4, 4>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 4; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 1; -#elif 0 - // for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 2 - constexpr index_t BlockSize = 256; - - constexpr index_t NPerBlock = 32; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<8, 8, 2, 2>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 4; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 1; -#elif 1 - // for 3x3, 28x28, v1r3, Pascal - constexpr index_t BlockSize = 128; - - constexpr index_t NPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 2; - - constexpr index_t NPerThread = 4; - constexpr index_t KPerThread = 8; - constexpr index_t HoPerThread = 1; - constexpr index_t WoPerThread = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>; - using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>; - using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>; - constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load NCHW - constexpr index_t InBlockReorderDataPerWrite_N = 4; - - using WeiBlockCopyClusterLengths = void; - constexpr index_t WeiBlockCopyDataPerRead_K = 4; - - constexpr index_t OutThreadCopyDataPerWrite_W = 2; -#endif - - constexpr index_t GridSize = - ((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) * - ((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto 
gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw -#else - GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw_lds_double_buffer -#endif - {}; - - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp b/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp deleted file mode 100644 index 50da0a7df5..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp +++ /dev/null @@ -1,334 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" -#include "gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t N = in_nchw_desc.GetLength(I0); - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - constexpr index_t BGhostRead = (Y - 1) * Wi + (X - 1); - - // convert in_nchw to in_cnhw - auto in_chwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_chwn_desc, std::cout << "in_chwn_desc: "); - - Tensor in_chwn(make_TensorDescriptor(in_chwn_desc)); - - make_ParallelTensorFunctor( - [&](auto n, auto c, auto hi, auto wi) { in_chwn(c, hi, wi, n) = in_nchw(n, c, hi, wi); }, - N, - C, - Hi, - Wi)(std::thread::hardware_concurrency()); - - // convert wei_kcyx to wei_cyxk - auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); - - Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); - - make_ParallelTensorFunctor( - [&](auto k, auto c, auto y, auto x) { wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); }, - K, - C, - Y, - X)(std::thread::hardware_concurrency()); - - // conver out_nkhw to out_knhw - auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: "); - - Tensor out_khwn(make_TensorDescriptor(out_khwn_desc)); - -#if 0 - // 3x3, 34x34 - // need to use register double buffer for GEMM - constexpr index_t BPerBlock = 
128; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 4; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; - - constexpr index_t BlockSize = 128; -#elif 0 - // 1x1, 28x28, 64 threads - constexpr index_t BPerBlock = 64; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 8; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t GemmThreadPerColumnPerCluster = 8; - constexpr index_t GemmThreadPerRowPerCluster = 8; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t BlockSize = 64; -#elif 0 - // 1x1, 28x28, 128 threads, no lds-double-buffer - // 1x1, 28x28, 128 threads, with lds-double-buffer, max_register = 128 - constexpr index_t BPerBlock = 64; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t GemmThreadPerColumnPerCluster = 8; - constexpr index_t GemmThreadPerRowPerCluster = 8; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t BlockSize = 128; -#elif 0 - // 1x1, 28x28, 256 thread - constexpr index_t BPerBlock = 128; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - - constexpr index_t GemmThreadPerColumnPerCluster = 8; - constexpr index_t GemmThreadPerRowPerCluster = 
8; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - - constexpr index_t BlockSize = 256; -#elif 0 - // 1x1, 14x14, Pascal, enable lds_double_buffer, disable register double buffer - constexpr index_t BPerBlock = 64; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; - - constexpr index_t BlockSize = 128; -#elif 1 - // 1x1, 14x14, Vega 20, enable lds_double_buffer, disable register_double_buffer - constexpr index_t BPerBlock = 128; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - - constexpr index_t BPerThread = 8; - constexpr index_t KPerThread = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - constexpr index_t InBlockCopyThreadPerDim0 = 4; - constexpr index_t InBlockCopyThreadPerDim1 = 16; - - constexpr index_t WeiBlockCopyThreadPerDim0 = 4; - constexpr index_t WeiBlockCopyThreadPerDim1 = 16; - - constexpr index_t InBlockCopyDataPerRead = 4; - constexpr index_t WeiBlockCopyDataPerRead = 4; - constexpr index_t OutThreadCopyDataPerWrite = 4; - - constexpr index_t BlockSize = 256; -#endif - - constexpr index_t GridSize = - ((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - // mem - std::size_t data_sz = sizeof(T); - DeviceMem in_chwn_device_buf(data_sz * (in_chwn.mDesc.GetElementSpace() + BGhostRead + - BPerBlock)); // reserve extra space for BGhostRead - DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); - DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace()); - - in_chwn_device_buf.ToDevice(in_chwn.mData.data()); - wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); - out_khwn_device_buf.ToDevice(out_khwn.mData.data()); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn -#else - GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer -#endif - {}; - - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 
static_cast(in_chwn_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_khwn_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_khwn_device_buf.FromDevice(out_khwn.mData.data()); - - // convert out_khwn to out_nkhw - make_ParallelTensorFunctor( - [&](auto n, auto k, auto ho, auto wo) { out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n); }, - N, - K, - Ho, - Wo)(std::thread::hardware_concurrency()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp deleted file mode 100644 index 23cef570fc..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp +++ /dev/null @@ -1,155 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" -#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp" - -template -void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - index_t nrepeat) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - // reorder weight - auto wei_cyxk_desc = make_ConstantTensorDescriptor_packed(Sequence{}); - ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: "); - - Tensor wei_cyxk(make_TensorDescriptor(wei_cyxk_desc)); - - auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) { - wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x); - }; - - make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)( - std::thread::hardware_concurrency()); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - - constexpr index_t N1 = 2; - constexpr index_t N2 = 4; - - constexpr index_t B = (N * Ho * Wo) / (N1 * N2); - -#if 1 - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t CPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t 
GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_C_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_C_N1_B_N2 = Sequence<8, 2, 16, 1>; - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_C_K = Sequence<1, 4>; - using WeiBlockCopyClusterLengths_C_K = Sequence<8, 32>; - - constexpr index_t WeiBlockCopyDataPerAccess_K = 4; -#endif - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw -#else - GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer -#endif - {}; - - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_cyxk_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 07a3659856..32d01136ca 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -1,9 +1,8 @@ #pragma once #include #include "device.hpp" -#include "tensor.hpp" +#include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" -#include "convolution_common.hpp" #include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" template ::value, half_t, T>::type; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -55,25 +56,105 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); #if 0 - // BlockSize = 256, EperBlock = 8, each thread hold 64 data + // cdata = 64, BlockSize = 256, 64x256x8 constexpr index_t BlockSize = 256; - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 32; constexpr index_t EPerBlock = 8; constexpr index_t GemmNRepeat = 2; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmDataPerReadA = 4; constexpr index_t GemmDataPerReadB = 4; + using 
InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 32, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x4 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 4; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] @@ -91,25 +172,27 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - // BlockSize = 256, EPerBlock = 16, each thread hold 64 data +#elif 0 + // 
cdata = 64, BlockSize = 256, 128x128x16 constexpr index_t BlockSize = 256; - constexpr index_t BPerBlock = 16; constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 16; constexpr index_t EPerBlock = 16; constexpr index_t GemmNRepeat = 2; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>; @@ -128,26 +211,28 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 1 - // BlockSize = 256, EPerBlock = 16, each thread hold 64 data +#elif 0 + // cdata = 4, BlockSize = 256, 128x128x16 // for 1x1 constexpr index_t BlockSize = 256; - constexpr index_t BPerBlock = 16; constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 16; constexpr index_t EPerBlock = 16; constexpr index_t GemmNRepeat = 2; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; @@ -166,25 +251,261 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 1 - // BlockSize = 64, each thread hold 64 data - constexpr index_t BlockSize = 64; +#elif 0 + // cdata = 64, BlockSize = 128, 64x128x4 + constexpr index_t BlockSize = 128; - constexpr index_t BPerBlock = 8; constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 4; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + 
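// [editor's note] Two invariants make these new #elif blocks auditable: elementwise,
// SubLengths * ClusterLengths must equal the copied block tile
// [EPerBlock, N1, BPerBlock, N2] (N1 = GemmNRepeat; N2 = 4 here, as in the v3 driver),
// and the cluster-length product must equal BlockSize so every thread takes part in the
// copy. For this 128x128x8 block: <1,1,1,4> * <8,2,16,1> = <8,2,16,4>, and
// 8 * 2 * 16 * 1 = 256 == BlockSize. A hedged guard (not in the source):
static_assert(8 * 2 * 16 * 1 == BlockSize, "in-block copy cluster must use every thread");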
constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 128, 64x128x8 + constexpr index_t BlockSize = 128; + + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 16; constexpr index_t EPerBlock = 8; constexpr index_t GemmNRepeat = 2; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 128, 64x128x16 + constexpr index_t BlockSize = 128; + + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 16; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; 
// [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x4 + constexpr index_t BlockSize = 128; + + constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 4; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 2>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x8 + constexpr index_t BlockSize = 128; + + constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x16 + constexpr index_t BlockSize = 128; + + constexpr index_t KPerBlock = 128; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 16; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + 
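// Editorial note (inferred from the surrounding configs, not stated by the
// patch): the Level0/Level1 constants below form a two-level thread grid in
// the block-wise GEMM. Level0 clusters of threads, each owning a
// GemmMPerThread x GemmNPerThread sub-tile of C, are tiled again by Level1
// clusters, so their product must equal BlockSize (2 * 2 * 8 * 4 == 128 for
// this block). The "cdata" figure in each header comment is the per-thread
// C footprint, (KPerBlock * BPerBlock * GemmNRepeat * GemmNPerThread) /
// BlockSize == (128 * 8 * 2 * 4) / 128 == 64 here.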
constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 4>; + using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4; +#elif 0 + // cdata = 64, BlockSize = 64, 64x64x8 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 2; constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>; @@ -204,24 +525,221 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; #elif 0 - // BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data - constexpr index_t BlockSize = 256; + // cdata = 64, BlockSize = 32, 32x64x3 + constexpr index_t BlockSize = 32; + constexpr index_t KPerBlock = 32; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 3; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 2; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>; + 
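// Editorial note: per-thread SubLengths [3, 1] times ClusterLengths [1, 32]
// covers the whole [EPerBlock, KPerBlock] == [3, 32] weight tile with all 32
// threads of this block. Because each thread's contiguous E-slice has odd
// length 3, the global read cannot be vectorized, which is why
// WeiBlockCopySrcDataPerRead_E = 1 below.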
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x3 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 32; constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 3; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; + + using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 64x64x3 + constexpr index_t BlockSize = 64; + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 8; + constexpr index_t EPerBlock = 3; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 1>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1; + + using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<1, 64>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x4 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 32; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 
4; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x8 + constexpr index_t BlockSize = 64; + + constexpr index_t KPerBlock = 32; + constexpr index_t BPerBlock = 16; constexpr index_t EPerBlock = 8; constexpr index_t GemmNRepeat = 2; - constexpr index_t GemmMPerThreadSubC = 2; - constexpr index_t GemmNPerThreadSubC = 4; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t GemmDataPerReadA = 4; + constexpr index_t GemmDataPerReadB = 4; + + using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>; + using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>; + using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] + using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] + using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] + + constexpr index_t InBlockCopySrcDataPerRead_B = 1; + constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; + + using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; + using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; + using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] + using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] + + constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; + constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; +#elif 0 + // cdata = 32, BlockSize = 256, 64x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t KPerBlock = 64; + constexpr index_t BPerBlock = 16; + constexpr index_t EPerBlock = 8; + + constexpr index_t GemmNRepeat = 2; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 4; constexpr index_t GemmNLevel0Cluster = 4; constexpr index_t GemmMLevel1Cluster = 4; constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA 
= 2; - constexpr index_t GemmDataPerReadB = 4; + + constexpr index_t GemmDataPerReadA = 2; + constexpr index_t GemmDataPerReadB = 4; using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; @@ -243,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, #endif constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThreadSubC; + constexpr index_t N2 = GemmNPerThread; constexpr index_t B = (N * Ho * Wo) / (N1 * N2); @@ -252,72 +770,76 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - ConvolutionDirection::Forward, - BPerBlock, - KPerBlock, - EPerBlock, - GemmNRepeat, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N1_B_N2, - InBlockCopyClusterLengths_E_N1_B_N2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2, - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>{}; + using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< + GridSize, + BlockSize, + T, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + LeftPads, + RightPads, + BPerBlock, + KPerBlock, + EPerBlock, + GemmNRepeat, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmDataPerReadA, + GemmDataPerReadB, + InBlockCopySubLengths_E_N1_B_N2, + InBlockCopyClusterLengths_E_N1_B_N2, + InBlockCopyThreadClusterArrangeOrder, + InBlockCopySrcAccessOrder, + InBlockCopyDstAccessOrder, + InBlockCopySrcDataPerRead_B, + InBlockCopyDstDataPerWrite_N2, + WeiBlockCopySubLengths_E_K, + WeiBlockCopyClusterLengths_E_K, + WeiBlockCopyThreadClusterArrangeOrder, + WeiBlockCopySrcAccessOrder, + WeiBlockCopyDstAccessOrder, + WeiBlockCopySrcDataPerRead_E, + WeiBlockCopyDstDataPerWrite_K>; - for(index_t i = 0; i < nrepeat; ++i) + for(index_t i = 0; i < 5; ++i) { - float time = - launch_and_time_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - gridwise_conv, - const_cast( - static_cast(in_nchw_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(wei_kcyx_device_buf.GetDeviceBuffer())), - const_cast( - static_cast(out_nkhw_device_buf.GetDeviceBuffer()))); + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nchw_device_buf.GetDeviceBuffer()), + static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), + static_cast(out_nkhw_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; } out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp b/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp deleted file mode 100644 index ab309670c6..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp +++ /dev/null @@ -1,305 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp" - -template -void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - ck::index_t nrepeat) -{ - using namespace ck; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t K = out_nkhw_desc.GetLength(I1); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 0 - // BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // 
[E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // BlockSize = 256, EPerBlock = 16, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 16; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data - constexpr index_t BlockSize = 64; - - constexpr index_t BPerBlock = 8; - constexpr index_t KPerBlock = 64; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 2; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; - using 
WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 64; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThreadSubC = 2; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 2; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - constexpr index_t BlockSize = 64; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 32; - constexpr index_t EPerBlock = 4; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 1; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<1, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 16>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#endif - - constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThreadSubC; - - constexpr index_t B = (N * Ho * Wo) / (N1 * N2); - - constexpr index_t GridSize = - ((B + 
BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - ConvolutionDirection::Forward, - BPerBlock, - KPerBlock, - EPerBlock, - GemmNRepeat, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N1_B_N2, - InBlockCopyClusterLengths_E_N1_B_N2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2, - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>{}; - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 1a67f48477..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,220 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * 
out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 0 - // 1x1 filter, 8x8 image - constexpr index_t N0 = 1; - constexpr index_t Ho0 = 2; - constexpr index_t Wo0 = 1; - - constexpr index_t N2 = 4; - constexpr index_t Ho2 = 1; - constexpr index_t Wo2 = 1; - - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 4, 1, 1>; - using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 2, 1, 16, 1, 1, 1>; - using InBlockCopyThreadClusterArrangeOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2] - using InBlockCopySrcAccessOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2] - using InBlockCopyDstAccessOrder = - Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2] - - constexpr index_t InBlockCopyDataPerAccess_W2 = 1; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - // 1x1 filter, 8x8 image - constexpr index_t N0 = 1; - constexpr index_t Ho0 = 2; - constexpr index_t Wo0 = 1; - - constexpr index_t N2 = 2; - constexpr index_t Ho2 = 2; - constexpr index_t Wo2 = 1; - - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<1, 1, 2, 1, 1, 2, 1, 1>; - using InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2 = Sequence<8, 1, 1, 1, 16, 1, 2, 1>; - using InBlockCopyThreadClusterArrangeOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2] - using InBlockCopySrcAccessOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N0, N2, Ho0, Ho2, Wo0, B, Wo2] - using InBlockCopyDstAccessOrder = - Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N0, Ho0, Wo0, B, N2, Ho2, Wo2] - - constexpr index_t InBlockCopyDataPerAccess_W2 = 1; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, 
E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#endif - - constexpr index_t N1 = N / (N0 * N2); - constexpr index_t Ho1 = Ho / (Ho0 * Ho2); - constexpr index_t Wo1 = Wo / (Wo0 * Wo2); - - constexpr index_t B = N1 * Ho1 * Wo1; - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - N1, - N2, - Ho1, - Ho2, - Wo1, - Wo2, - BPerBlock, - KPerBlock, - EPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2, - InBlockCopyClusterLengths_E_N0_Ho0_Wo0_B_N2_Ho2_Wo2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopyDataPerAccess_W2, - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>{}; - - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp deleted file mode 100644 index f905eaec5a..0000000000 --- a/driver/include/device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,178 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer.hpp" - -using namespace ck; - -template -void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - index_t nrepeat) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = 
wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // 1x1 filter, 8x8 image - constexpr index_t N1 = 2; - constexpr index_t Ho1 = 1; - constexpr index_t Wo1 = 1; - - constexpr index_t N2 = 1; - constexpr index_t Ho2 = 1; - constexpr index_t Wo2 = 4; - - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 16; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2 = Sequence<1, 1, 1, 1, 1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2 = Sequence<8, 2, 1, 1, 16, 1, 1, 1>; - using InBlockCopyThreadClusterArrangeOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N1, N2, Ho1, Ho2, Wo1, B, Wo2] - using InBlockCopySrcAccessOrder = - Sequence<0, 1, 5, 2, 6, 3, 4, 7>; // [E, N1, N2, Ho1, Ho2, Wo1, B, Wo2] - using InBlockCopyDstAccessOrder = - Sequence<0, 1, 2, 3, 4, 5, 6, 7>; // [E, N1, Ho1, Wo1, B, N2, Ho2, Wo2] - - constexpr index_t InBlockCopyDataPerAccess_W2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#endif - - constexpr index_t N0 = N / (N1 * N2); - constexpr index_t Ho0 = Ho / (Ho1 * Ho2); - constexpr index_t Wo0 = Wo / (Wo1 * Wo2); - - constexpr index_t B = N0 * Ho0 * Wo0; - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - constexpr auto gridwise_conv = - GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - N0, - N1, - N2, - Ho0, - Ho1, - Ho2, - Wo0, - Wo1, - Wo2, - BPerBlock, - KPerBlock, - EPerBlock, - GemmMPerThreadSubC, - GemmNPerThreadSubC, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmKPerThreadLoop, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2, - InBlockCopyClusterLengths_E_N1_Ho1_Wo1_B_N2_Ho2_Wo2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopyDataPerAccess_W2, - WeiBlockCopySubLengths_E_K, - 
WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>{}; - - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index f775054b58..4c887e9322 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -1,8 +1,7 @@ -#pragma once #include #include "device.hpp" -#include "tensor.hpp" -#include "gridwise_convolution_kernel_wrapper.hpp" +#include "host_tensor.hpp" +#include "gridwise_operation_wrapper.hpp" #include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" template ::value, half_t, T>::type; + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; @@ -54,20 +55,88 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); #if 0 - // BlockSize = 256, GemmKPerBlock = 8 + // cdata = 64, BlockSize = 256, 64x256x8 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 256>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x4 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using 
GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -85,20 +154,56 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 - // BlockSize = 256, GemmKPerBlock = 16 + // cdata = 64, BlockSize = 256, 128x128x8 + // vector 4 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + 
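// Editorial note on the v4r4 naming (inferred from the implicit-GEMM
// formulation used elsewhere in this repo, so treat as an assumption):
// GemmMPerBlock/GemmNPerBlock/GemmKPerBlock take over the role of v4r1's
// KPerBlock/BPerBlock/EPerBlock, with GemmM = K (output channels),
// GemmN = N * Ho * Wo, and GemmK = C * Y * X. The per-thread C footprint is
// GemmMPerBlock * GemmNPerBlock / BlockSize, e.g. 128 * 128 / 256 == 64,
// matching the "cdata = 64" header comments.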
+ constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -116,21 +221,24 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 - // BlockSize = 256, GemmKPerBlock = 8 - // for 1x1 filter, vector-read-b = 4 + // cdata = 64, BlockSize = 256, 128x128x8 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -147,22 +255,25 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#elif 1 - // BlockSize = 256, GemmKPerBlock = 16 - // for 1x1 filter, vector-read-b = 4 +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x16 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -179,37 +290,674 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#elif 1 - // 1x1 filter, 14x14 image - constexpr index_t BlockSize = 256; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x4 + constexpr index_t BlockSize = 128; constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t 
GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>; constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 2 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>; constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x8 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2; + + using 
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x8 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 16>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 0 + // cdata = 64, BlockSize = 128, 128x64x16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 128, 64x128x4 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using 
GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // BlockSize = 128, 64x128x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 0 + // cdata = 64, BlockSize = 128, 64x128x8 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // BlockSize = 128, 64x128x8 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t 
GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 0 + // BlockSize = 128, 64x128x16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<16, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 64x64x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 2; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 64x64x8 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t 
GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 2; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x2 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 32, 32x64x3 + constexpr index_t BlockSize = 32; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 3; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t 
GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x3 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 3; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 64, 64x64x3 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 3; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 64, 32x128x8 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t 
GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 2;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<8, 2>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 1
+    // cdata = 32, BlockSize = 128, 32x128x8
+    constexpr index_t BlockSize = 128;
+
+    constexpr index_t GemmMPerBlock = 32;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;
+
+    constexpr index_t GemmMPerThread = 2;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t ThreadGemmDataPerReadM = 2;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<2, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 2;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<8, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
+#elif 0
+    // cdata = 32, BlockSize = 128, 32x128x16
+    constexpr index_t BlockSize = 128;
+
+    constexpr index_t GemmMPerBlock = 32;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 16;
+
+    constexpr index_t GemmMPerThread = 2;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
+
+    constexpr index_t GemmMLevel0Cluster = 4;
+    constexpr index_t GemmNLevel0Cluster = 4;
+    constexpr index_t GemmMLevel1Cluster = 2;
+    constexpr index_t GemmNLevel1Cluster = 4;
+
+    constexpr index_t ThreadGemmDataPerReadM = 2;
+    constexpr index_t ThreadGemmDataPerReadN = 4;
+
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<16, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif
 
     constexpr index_t GemmM = K;
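A note on how the tuning blocks above relate: they form a single `#if/#elif` chain, so only the first branch that evaluates to true is compiled. Of the branches visible in this hunk, the `cdata = 64, BlockSize = 64, 64x64x3` block marked `#elif 1` wins; the later `#elif 1` (`cdata = 32, BlockSize = 128, 32x128x8`) is dead unless the earlier one is toggled to 0. Every block also satisfies the same decomposition invariants, which makes a new configuration easy to sanity-check. Below is a stand-alone sketch for the `cdata = 64, BlockSize = 128, 128x64x8` block, in plain C++ with all names local to the sketch; it is illustrative only, not part of the patch:

    // sketch only: mirrors the 128x64x8 tuning block's parameters
    constexpr int BlockSize = 128;
    constexpr int MPerBlock = 128, NPerBlock = 64, KPerBlock = 8;
    constexpr int MLevel0 = 2, NLevel0 = 2, MLevel1 = 4, NLevel1 = 8;

    // the two-level M/N thread cluster must account for every thread in the block
    static_assert(MLevel0 * NLevel0 * MLevel1 * NLevel1 == BlockSize, "");

    // "cdata" in the comments: C-tile elements accumulated per thread
    static_assert(MPerBlock * NPerBlock / BlockSize == 64, "");

    // A copy: ThreadSliceLengths<4,2> x ThreadClusterLengths<2,64> must tile the
    // KPerBlock x MPerBlock LDS block, and the cluster must use every thread
    static_assert(4 * 2 == KPerBlock && 2 * 64 == MPerBlock, "");
    static_assert(2 * 64 == BlockSize, "");

    // B copy: ThreadSliceLengths<4,1> x ThreadClusterLengths<2,64> likewise
    static_assert(4 * 2 == KPerBlock && 1 * 64 == NPerBlock, "");

The same checks pass for every branch in the chain, including the `GemmKPerBlock = 3` variants, whose copy slice lengths carry the odd factor 3 for exactly this reason.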
@@ -220,11 +968,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
 
     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
 
-    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
+    using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
         GridSize,
         BlockSize,
-        T,
-        T,
+        TDevice,
+        TDevice,
         decltype(in_nchw_desc),
         decltype(wei_kcyx_desc),
         decltype(out_nkhw_desc),
@@ -235,13 +983,13 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
-        GemmKPerThreadLoop,
         ThreadGemmDataPerReadM,
         ThreadGemmDataPerReadN,
         GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
@@ -252,25 +1000,38 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
         GemmBBlockCopySrcDataPerRead_GemmN,
         GemmBBlockCopyDstDataPerWrite_GemmN,
-        GemmCThreadCopyDstDataPerWrite_GemmN1>{};
+        GemmCThreadCopyDstDataPerWrite_GemmN1>;
 
-    for(index_t i = 0; i < nrepeat; ++i)
+    for(index_t i = 0; i < 5; ++i)
     {
-        float time =
-            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
-                                   dim3(GridSize),
-                                   dim3(BlockSize),
-                                   0,
-                                   0,
-                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
-                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        std::cout << "Start running " << nrepeat << " times..." << std::endl;
 
-        printf("Elapsed time : %f ms, %f TFlop/s\n",
-               time,
-               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                   (std::size_t(1000) * 1000 * 1000) / time);
-        usleep(std::min(time * 1000, float(10000)));
+        KernelTimer timer;
+        timer.Start();
+
+        for(index_t j = 0; j < nrepeat; ++j)
+        {
+            launch_kernel(run_gridwise_operation<gridwise_conv,
+                                                 const TDevice* const __restrict__,
+                                                 const TDevice* const __restrict__,
+                                                 TDevice* const __restrict__>,
+                          dim3(GridSize),
+                          dim3(BlockSize),
+                          0,
+                          0,
+                          static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
+                          static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
+                          static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
+        }
+
+        timer.End();
+
+        float ave_time = timer.GetElapsedTime() / nrepeat;
+
+        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
     }
 
     out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
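Two things are worth noting about the rewritten benchmark loop above. First, instead of timing every launch separately and sleeping between iterations, it now enqueues `nrepeat` identical launches, times the whole batch with one `KernelTimer`, and reports the mean; for short kernels this is far less sensitive to launch overhead and timer resolution. Second, the units: dividing the flop count by 10^9 and then by a time in milliseconds yields TFlop/s, since 10^9 x 10^3 = 10^12. For example, 10^10 flops at an average of 1.25 ms per launch gives 10 / 1.25 = 8 TFlop/s. A host-side sketch of the batching pattern, assuming only `<chrono>` (the real `KernelTimer` wraps device-side timing; a plain host clock like this one must synchronize the device before reading the end time, or it measures only the enqueue cost):

    #include <chrono>

    // sketch only: time nrepeat launches as one batch and report the mean in ms
    template <class F>
    float average_ms(F&& launch, int nrepeat)
    {
        auto t0 = std::chrono::high_resolution_clock::now();
        for(int j = 0; j < nrepeat; ++j)
            launch(); // back-to-back identical launches
        // a real harness synchronizes the device here before stopping the clock
        auto t1 = std::chrono::high_resolution_clock::now();
        return std::chrono::duration<float, std::milli>(t1 - t0).count() / nrepeat;
    }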
diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp b/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp
deleted file mode 100644
index 646d59dbf4..0000000000
--- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp
+++ /dev/null
@@ -1,225 +0,0 @@
-#pragma once
-#include <unistd.h>
-#include "device.hpp"
-#include "tensor.hpp"
-#include "gridwise_convolution_kernel_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"
-
-using namespace ck;
-
-template <class T,
-          class InDesc,
-          class WeiDesc,
-          class OutDesc,
-          class ConvStrides,
-          class ConvDilations>
-void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
-                                                                     const Tensor<T>& in_nchw,
-                                                                     WeiDesc,
-                                                                     const Tensor<T>& wei_kcyx,
-                                                                     OutDesc,
-                                                                     Tensor<T>& out_nkhw,
-                                                                     ConvStrides,
-                                                                     ConvDilations,
-                                                                     ck::index_t nrepeat)
-{
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
-    constexpr auto in_nchw_desc  = InDesc{};
-    constexpr auto wei_kcyx_desc = WeiDesc{};
-    constexpr auto out_nkhw_desc = OutDesc{};
-
-    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
-    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
-    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
-    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
-
-    std::size_t data_sz = sizeof(T);
-    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
-    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
-    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
-
-    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
-    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
-    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
-
-#if 1
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
-
-    using InBlockCopySubLengths_E_B            = Sequence<4, 1>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<2, 128>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 1;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 1;
-#elif 1
-    // 1x1 filter, 8x8 image
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
-
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
-
-    using InBlockCopySubLengths_E_B            = Sequence<1, 4>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<8, 32>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 4;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
-#elif 0
-    // 1x1 filter, 14x14 
image - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 128; - constexpr index_t KPerBlock = 128; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmMPerThreadSubC = 4; - constexpr index_t GemmNPerThreadSubC = 4; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmKPerThreadLoop = 1; - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_B = Sequence<2, 2>; - using InBlockCopyClusterLengths_E_B = Sequence<4, 64>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B] - using InBlockCopySrcAccessOrder = Sequence<0, 1>; // [E, B] - using InBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, B] - - constexpr index_t InBlockCopyDataPerAccess_B = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; - - constexpr index_t OutThreadCopyDataPerAccess_B = 2; -#endif - - constexpr index_t B = N * Ho * Wo; - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - constexpr auto gridwise_conv = -#if 0 - GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw -#else - GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated -#endif - {}; - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = - launch_and_time_kernel(run_gridwise_convolution_kernel, - dim3(GridSize), - dim3(BlockSize), - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms, %f TFlop/s\n", - time, - (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp b/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 7158032e8e..0000000000 --- a/driver/include/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,214 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "tensor.hpp" -#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp" - -using namespace ck; - -template -void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - index_t nrepeat) -{ - // this suppose in / wei data type is int8x4 - constexpr index_t NVector = 4; - using accum_t = int32_t; - using vector_t = vector_type; - using vector_mem_t = typename vector_t::MemoryType; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = InDesc{}; - constexpr auto wei_kcyx_desc = WeiDesc{}; - constexpr 
auto out_nkhw_desc = OutDesc{}; - - constexpr index_t Hi = in_nchw_desc.GetLength(I2); - constexpr index_t Wi = in_nchw_desc.GetLength(I3); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - constexpr index_t K = wei_kcyx_desc.GetLength(I0); - constexpr index_t C = wei_kcyx_desc.GetLength(I1); - constexpr index_t Y = wei_kcyx_desc.GetLength(I2); - constexpr index_t X = wei_kcyx_desc.GetLength(I3); - - // vectorized input - auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(in_nchw_vec_desc, std::cout << "in_nchw_vec_desc: "); - - Tensor in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc)); - - auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) { -#if 0 - in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w); -#elif 0 - in_nchw_vec(n, c, h, w) = - vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w)); -#elif 1 - in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w), - in_nchw(n, 4 * c + 1, h, w), - in_nchw(n, 4 * c + 2, h, w), - in_nchw(n, 4 * c + 3, h, w)); -#endif - }; - - make_ParallelTensorFunctor(f_vectorized_nchw, N, C / NVector, Hi, Wi)( - std::thread::hardware_concurrency()); - - // vectorize weight - auto wei_kcyx_vec_desc = make_ConstantTensorDescriptor(Sequence{}); - ostream_ConstantTensorDescriptor(wei_kcyx_vec_desc, std::cout << "wei_kcyx_vec_desc: "); - - Tensor wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc)); - - auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) { -#if 0 - wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x); -#elif 0 - wei_kcyx_vec(k, c, y, x) = - vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x)); -#elif 1 - wei_kcyx_vec(k, c, y, x) = vector_t::Pack(wei_kcyx(k, 4 * c, y, x), - wei_kcyx(k, 4 * c + 1, y, x), - wei_kcyx(k, 4 * c + 2, y, x), - wei_kcyx(k, 4 * c + 3, y, x)); -#endif - }; - - make_ParallelTensorFunctor(f_vectorized_kcyx, K, C / NVector, Y, X)( - std::thread::hardware_concurrency()); - - // - DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(sizeof(TOut) * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data()); - wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 0 - // 3x3, 34x34, 128 thread, fp32, vector = 1 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 4; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // 3x3, 34x34, 128 thread, fp32, vector = 2 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 2; - constexpr index_t HoPerBlock = 2; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 2; - constexpr index_t KPerThread = 4; - constexpr index_t CPerThread = 1; - constexpr index_t HoPerThread = 2; - constexpr index_t WoPerThread = 2; - - constexpr index_t 
InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 2; - - constexpr index_t BlockSize = 128; -#elif 0 - // 3x3, 34x34, 128 thread, int8, vector = 4 - constexpr index_t NPerBlock = 2; - constexpr index_t KPerBlock = 32; - constexpr index_t CPerBlock = 8; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 1; - constexpr index_t KPerThread = 8; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 4; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 2; - - constexpr index_t BlockSize = 128; -#elif 1 - // 1x1, 32x32, 128 thread, int8, vector = 4 - constexpr index_t NPerBlock = 1; - constexpr index_t KPerBlock = 64; - constexpr index_t CPerBlock = 16; - constexpr index_t HoPerBlock = 4; - constexpr index_t WoPerBlock = 32; - - constexpr index_t NPerThread = 1; - constexpr index_t KPerThread = 8; - constexpr index_t CPerThread = 2; - constexpr index_t HoPerThread = 4; - constexpr index_t WoPerThread = 2; - - constexpr index_t InBlockCopyDataPerRead = 2; - constexpr index_t WeiBlockCopyDataPerRead = 2; - - constexpr index_t BlockSize = 128; -#endif - - constexpr index_t GridSize = - (N / NPerBlock) * (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < nrepeat; ++i) - { - float time = launch_and_time_kernel( - gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_nchw_vec_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_vec_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - - printf("Elapsed time : %f ms\n", time); - usleep(std::min(time * 1000, float(10000))); - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_tensor.hpp b/driver/include/device_tensor.hpp index e3a1c25821..07d98f87a3 100644 --- a/driver/include/device_tensor.hpp +++ b/driver/include/device_tensor.hpp @@ -1,28 +1,26 @@ #pragma once -#include "tensor.hpp" +#include "host_tensor.hpp" #include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" #include "tensor_descriptor.hpp" -template -auto make_TensorDescriptor_impl(ConstTensorDesc, std::integer_sequence) +template +auto make_HostTensorDescriptor_impl(TensorDesc, std::integer_sequence) { - std::initializer_list lengths = {ConstTensorDesc::GetLengths()[Is]...}; - std::initializer_list strides = {ConstTensorDesc::GetStrides()[Is]...}; + std::initializer_list lengths = {TensorDesc::GetLengths()[Is]...}; + std::initializer_list strides = {TensorDesc::GetStrides()[Is]...}; - return TensorDescriptor(lengths, strides); + return HostTensorDescriptor(lengths, strides); } -template -auto make_TensorDescriptor(ConstTensorDesc) +template +auto make_HostTensorDescriptor(TensorDesc) { - return make_TensorDescriptor_impl( - ConstTensorDesc{}, - std::make_integer_sequence{}); + return make_HostTensorDescriptor_impl( + TensorDesc{}, std::make_integer_sequence{}); } -template -void ostream_ConstantTensorDescriptor(ConstTensorDesc, std::ostream& os = std::cout) +template +void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) { - ostream_TensorDescriptor(make_TensorDescriptor(ConstTensorDesc{}), os); + ostream_HostTensorDescriptor(make_HostTensorDescriptor(TensorDesc{}), os); } diff --git a/driver/include/host_col2im.hpp 
b/driver/include/host_col2im.hpp deleted file mode 100644 index e23540d8e0..0000000000 --- a/driver/include/host_col2im.hpp +++ /dev/null @@ -1,71 +0,0 @@ -#pragma once -#include "tensor.hpp" - -template -void host_col2im(const Tensor& in_eb, - Tensor& in_nchw, - FilterSizes, - OutputSizes, - ConvStrides, - ConvDilations, - LeftPads, - RightPads) -{ - using namespace ck; - - int N = in_nchw.mDesc.GetLengths()[0]; - int C = in_nchw.mDesc.GetLengths()[1]; - int HI = in_nchw.mDesc.GetLengths()[2]; - int WI = in_nchw.mDesc.GetLengths()[3]; - - int Y = FilterSizes{}[0]; - int X = FilterSizes{}[1]; - - int HO = OutputSizes{}[0]; - int WO = OutputSizes{}[1]; - - auto f = [&](auto n, auto c, auto hi, auto wi) { - double v = 0; - - for(int y = 0; y < Y; ++y) - { - int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0]; - - if(h_tmp >= 0 && h_tmp < HI && h_tmp % ConvStrides{}[0] == 0) - { - int ho = h_tmp / ConvStrides{}[0]; - - for(int x = 0; x < X; ++x) - { - int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1]; - - if(w_tmp >= 0 && w_tmp < WI && w_tmp % ConvStrides{}[1] == 0) - { - int wo = w_tmp / ConvStrides{}[1]; - - int e = c * (Y * X) + y * X + x; - int b = n * (HO * WO) + ho * WO + wo; - - v += in_eb(e, b); - } - } - } - } - - in_nchw(n, c, hi, wi) = v; - }; - - auto f_par = make_ParallelTensorFunctor(f, - in_nchw.mDesc.GetLengths()[0], - in_nchw.mDesc.GetLengths()[1], - in_nchw.mDesc.GetLengths()[2], - in_nchw.mDesc.GetLengths()[3]); - - f_par(std::thread::hardware_concurrency()); -} diff --git a/driver/include/host_conv.hpp b/driver/include/host_conv.hpp index ab932bb2c6..5ce822e70a 100644 --- a/driver/include/host_conv.hpp +++ b/driver/include/host_conv.hpp @@ -1,5 +1,5 @@ #pragma once -#include "tensor.hpp" +#include "host_tensor.hpp" template & in_nchw, if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && wi < in_nchw.mDesc.GetLengths()[3]) { - v += double(in_nchw(n, c, hi, wi)) * double(wei_kcyx(k, c, y, x)); + v += static_cast(in_nchw(n, c, hi, wi)) * + static_cast(wei_kcyx(k, c, y, x)); } } } diff --git a/driver/include/host_conv_bwd_data.hpp b/driver/include/host_conv_bwd_data.hpp index ce0fb789c9..fbcfcd004f 100644 --- a/driver/include/host_conv_bwd_data.hpp +++ b/driver/include/host_conv_bwd_data.hpp @@ -1,5 +1,5 @@ #pragma once -#include "tensor.hpp" +#include "host_tensor.hpp" template #include @@ -65,26 +65,26 @@ auto construct_f_unpack_args(F, T args) return construct_f_unpack_args_impl(args, std::make_index_sequence{}); } -struct TensorDescriptor +struct HostTensorDescriptor { - TensorDescriptor() = delete; + HostTensorDescriptor() = delete; template - TensorDescriptor(std::vector lens); + HostTensorDescriptor(std::vector lens); template - TensorDescriptor(std::vector lens, std::vector strides); + HostTensorDescriptor(std::vector lens, std::vector strides); void CalculateStrides(); template - TensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) + HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } template - TensorDescriptor(const Range1& lens, const Range2& strides) + HostTensorDescriptor(const Range1& lens, const Range2& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { } @@ -205,7 +205,7 @@ struct Tensor { } - Tensor(const TensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} + Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} template void GenerateTensorValue(G g, std::size_t num_thread = 1) @@ 
-267,11 +267,11 @@ struct Tensor typename std::vector::const_iterator end() const { return mData.end(); } - TensorDescriptor mDesc; + HostTensorDescriptor mDesc; std::vector mData; }; -void ostream_TensorDescriptor(const TensorDescriptor& desc, std::ostream& os = std::cout) +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout) { os << "dim " << desc.GetNumOfDimension() << ", "; diff --git a/driver/include/tensor_generator.hpp b/driver/include/host_tensor_generator.hpp similarity index 94% rename from driver/include/tensor_generator.hpp rename to driver/include/host_tensor_generator.hpp index 15469ba67a..84ff1bfff2 100644 --- a/driver/include/tensor_generator.hpp +++ b/driver/include/host_tensor_generator.hpp @@ -1,5 +1,5 @@ -#ifndef TENSOR_GENERATOR_HPP -#define TENSOR_GENERATOR_HPP +#ifndef HOST_TENSOR_GENERATOR_HPP +#define HOST_TENSOR_GENERATOR_HPP #include "config.hpp" diff --git a/driver/src/col2im_driver.cpp b/driver/src/col2im_driver.cpp deleted file mode 100644 index 2c460d6ce1..0000000000 --- a/driver/src/col2im_driver.cpp +++ /dev/null @@ -1,385 +0,0 @@ -#include -#include -#include -#include -#include -#include "config.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "print_array.hpp" -#include "print_sequence.hpp" -#include "device.hpp" -#include "tensor_generator.hpp" -#include "device_tensor.hpp" -#include "conv_common.hpp" -#include "host_col2im.hpp" -#include "device_col2im_eb_nchw.hpp" - -int main(int argc, char* argv[]) -{ - using namespace ck; - -#if 1 - constexpr index_t N = 2; - constexpr index_t C = 8; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 128; - constexpr index_t Y = 4; - constexpr index_t X = 4; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<2, 2>; -#elif 0 - // 3x3, 34x34 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 34; - constexpr index_t WI = 34; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51% - constexpr index_t N = 128; - constexpr index_t C = 2048; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64% - constexpr index_t N = 128; - constexpr index_t C = 832; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = 
Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65% - constexpr index_t N = 128; - constexpr index_t C = 1280; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 14x14 image - // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50% - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61% - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 384; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 28x28 image - // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69% - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62% - constexpr index_t N = 128; - constexpr index_t C = 832; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 17x17 input - // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76% - constexpr index_t N = 128; - constexpr index_t C = 768; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 14x14 image - // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64% - constexpr index_t N = 128; - constexpr index_t C = 528; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 14x14 image - // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75% - constexpr index_t N = 128; - constexpr index_t C = 528; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - 
using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52% - constexpr index_t N = 128; - constexpr index_t C = 832; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output - // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81% - constexpr index_t N = 128; - constexpr index_t C = 288; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 5x5 filter, 2x2 pad, 7x7 input - constexpr index_t N = 128; - constexpr index_t C = 48; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 128; - constexpr index_t Y = 5; - constexpr index_t X = 5; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<2, 2>; - using RightPads = Sequence<2, 2>; -#elif 0 - // 7x1 filter, 3x0 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; -#elif 1 - // 1x7 filter, 0x3 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 7; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 3>; - using RightPads = Sequence<0, 3>; -#endif - - constexpr auto img_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( - img_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); - - constexpr index_t HO = out_nkhw_desc.GetLengths()[2]; - constexpr index_t WO = out_nkhw_desc.GetLengths()[3]; - - constexpr auto col_eb_desc = - make_native_tensor_descriptor_packed(Sequence{}); - - using FilterSizes = Sequence; - using OutputSizes = Sequence; - - ostream_ConstantTensorDescriptor(col_eb_desc, std::cout << "col_eb_desc: "); - ostream_ConstantTensorDescriptor(img_nchw_desc, std::cout << "img_nchw_desc: "); - print_sequence("FilterSizes", FilterSizes{}); - print_sequence("OutputSizes", OutputSizes{}); - print_sequence("LeftPads", LeftPads{}); - print_sequence("LeftPads", LeftPads{}); - print_sequence("RightPads", RightPads{}); - print_sequence("ConvStrides", ConvStrides{}); - print_sequence("ConvDilations", ConvDilations{}); - - Tensor col_eb(make_TensorDescriptor(col_eb_desc)); - Tensor img_nchw_host(make_TensorDescriptor(img_nchw_desc)); - Tensor img_nchw_device(make_TensorDescriptor(img_nchw_desc)); - - std::size_t num_thread = 
std::thread::hardware_concurrency(); - - if(argc != 3) - { - printf("arg1: do_verification, arg2: nrepeat\n"); - exit(1); - } - - bool do_verification = atoi(argv[1]); - std::size_t nrepeat = atoi(argv[2]); - - if(do_verification) - { -#if 0 - col_eb.GenerateTensorValue(GeneratorTensor_1{}, num_thread); -#else - col_eb.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); -#endif - } - - device_col2im_eb_nchw(col_eb_desc, - col_eb, - img_nchw_desc, - img_nchw_device, - FilterSizes{}, - OutputSizes{}, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); - - if(do_verification) - { - host_col2im(col_eb, - img_nchw_host, - FilterSizes{}, - OutputSizes{}, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}); - - check_error(img_nchw_host, img_nchw_device); - -#if 0 - LogRange(std::cout << "col_eb : ", col_eb.mData, ",") << std::endl; - LogRange(std::cout << "img_nchw_host : ", img_nchw_host.mData, ",") << std::endl; - LogRange(std::cout << "img_nchw_device : ", img_nchw_device.mData, ",") << std::endl; -#endif - } -} diff --git a/driver/src/col2im_driver.cu b/driver/src/col2im_driver.cu deleted file mode 120000 index 8d388393e4..0000000000 --- a/driver/src/col2im_driver.cu +++ /dev/null @@ -1 +0,0 @@ -col2im_driver.cpp \ No newline at end of file diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp index a94dcb55bf..5cc7d6621b 100644 --- a/driver/src/conv_bwd_data_driver.cpp +++ b/driver/src/conv_bwd_data_driver.cpp @@ -9,7 +9,7 @@ #include "print_array.hpp" #include "print_sequence.hpp" #include "device.hpp" -#include "tensor_generator.hpp" +#include "host_tensor_generator.hpp" #include "device_tensor.hpp" #include "conv_common.hpp" #include "host_conv_bwd_data.hpp" @@ -23,7 +23,7 @@ int main(int argc, char* argv[]) { using namespace launcher; -#if 1 +#if 0 constexpr index_t N = 64; constexpr index_t C = 256; constexpr index_t HI = 56; @@ -160,10 +160,10 @@ int main(int argc, char* argv[]) #elif 0 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 128; + constexpr index_t C = 1024; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 128; + constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -190,10 +190,10 @@ int main(int argc, char* argv[]) #elif 1 // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 128; constexpr index_t HI = 35; constexpr index_t WI = 35; - constexpr index_t K = 128; + constexpr index_t K = 1024; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -209,19 +209,19 @@ int main(int argc, char* argv[]) constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); - ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); - ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); - ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); + ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); + ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); + ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); print_sequence("LeftPads", LeftPads{}); print_sequence("LeftPads", LeftPads{}); print_sequence("RightPads", RightPads{}); print_sequence("ConvStrides", ConvStrides{}); print_sequence("ConvDilations", 
ConvDilations{}); - Tensor in_nchw_device(make_TensorDescriptor(in_nchw_desc)); - Tensor in_nchw_host(make_TensorDescriptor(in_nchw_desc)); - Tensor wei_kcyx(make_TensorDescriptor(wei_kcyx_desc)); - Tensor out_nkhw(make_TensorDescriptor(out_nkhw_desc)); + Tensor in_nchw_device(make_HostTensorDescriptor(in_nchw_desc)); + Tensor in_nchw_host(make_HostTensorDescriptor(in_nchw_desc)); + Tensor wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); + Tensor out_nkhw(make_HostTensorDescriptor(out_nkhw_desc)); std::size_t num_thread = std::thread::hardware_concurrency(); @@ -245,9 +245,9 @@ int main(int argc, char* argv[]) #endif } -#if 1 +#if 0 device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw -#elif 0 +#elif 1 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw @@ -256,17 +256,17 @@ int main(int argc, char* argv[]) #elif 1 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw #endif - (in_nchw_desc, - in_nchw_device, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); + (in_nchw_desc, + in_nchw_device, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); if(do_verification) { diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index ae0dda3d4c..7317bd6a1c 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -3,38 +3,28 @@ #include #include #include +#include #include "config.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" #include "print_array.hpp" #include "print_sequence.hpp" #include "device.hpp" -#include "tensor_generator.hpp" +#include "host_tensor_generator.hpp" #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" -//#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp" -//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp" -//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp" -//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp" -//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp" -//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp" #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" -//#include "device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw.hpp" -//#include "device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp" #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" int main(int argc, char* argv[]) { using namespace ck; -#if 1 - // 1x1 - constexpr index_t N = 64; - constexpr index_t C = 64; - constexpr index_t HI = 56; - constexpr index_t WI = 56; +#if 0 + // 1x1, 17x17 + constexpr index_t N = 128; + constexpr index_t C = 1024; + constexpr index_t HI = 17; + constexpr index_t WI = 17; constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -45,12 +35,87 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 - // 1x7 + // 1x1, 8x8 constexpr index_t N = 128; - constexpr index_t C = 1024; + constexpr index_t C = 1536; + constexpr index_t HI = 8; + constexpr index_t WI = 8; + constexpr index_t K = 256; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides 
= Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 1x1, 73x73 + constexpr index_t N = 128; + constexpr index_t C = 160; + constexpr index_t HI = 73; + constexpr index_t WI = 73; + constexpr index_t K = 64; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 3x3, 35x35 + constexpr index_t N = 128; + constexpr index_t C = 96; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 96; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + // 3x3, 71x71 + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t HI = 71; + constexpr index_t WI = 71; + constexpr index_t K = 192; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + // 7x1, 17x17 + constexpr index_t N = 128; + constexpr index_t C = 128; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 1024; + constexpr index_t K = 128; + constexpr index_t Y = 7; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<3, 0>; + using RightPads = Sequence<3, 0>; +#elif 1 + // 1x7, 17x17 + constexpr index_t N = 128; + constexpr index_t C = 128; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 128; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -60,27 +125,12 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 3>; using RightPads = Sequence<0, 3>; #elif 0 - // 3x3, 34x34 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 34; - constexpr index_t WI = 34; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output + // 3x3, 299x299 stride=2 constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 128; + constexpr index_t C = 3; + constexpr index_t HI = 299; + constexpr index_t WI = 299; + constexpr index_t K = 32; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -90,47 +140,31 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42% - constexpr index_t N = 64; - constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 8x8 image - // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51% + // 3x3, 147x147 + // v4r4@v100 xx.xx%, cudnn@v100 xx.xx% constexpr index_t N = 128; - constexpr index_t C = 2048; - 
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K = 384;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
+    constexpr index_t C = 32;
+    constexpr index_t HI = 147;
+    constexpr index_t WI = 147;
+    constexpr index_t K = 64;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;

     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
 #elif 0
-    // 1x1 filter, 7x7 image
-    // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
+    // 3x3, 149x149
+    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
     constexpr index_t N = 128;
-    constexpr index_t C = 832;
-    constexpr index_t HI = 7;
-    constexpr index_t WI = 7;
-    constexpr index_t K = 384;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
+    constexpr index_t C = 32;
+    constexpr index_t HI = 149;
+    constexpr index_t WI = 149;
+    constexpr index_t K = 32;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;

     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;
@@ -138,109 +172,27 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
-    // 1x1 filter, 8x8 image
-    // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
+    // 3x3, 17x17, stride 2
     constexpr index_t N = 128;
-    constexpr index_t C = 1280;
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K = 384;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 14x14 image
-    // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
-    constexpr index_t N = 128;
-    constexpr index_t C = 512;
-    constexpr index_t HI = 14;
-    constexpr index_t WI = 14;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 8x8 image
-    // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
-    constexpr index_t N = 64;
-    constexpr index_t C = 1536;
-    constexpr index_t HI = 8;
-    constexpr index_t WI = 8;
-    constexpr index_t K = 384;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 28x28 image
-    // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
-    constexpr index_t N = 128;
-    constexpr index_t C = 256;
-    constexpr index_t HI = 28;
-    constexpr index_t WI = 28;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 7x7 image
-    // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
-    constexpr index_t N = 128;
-    constexpr index_t C = 832;
-    constexpr index_t HI = 7;
-    constexpr index_t WI = 7;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 17x17 input
-    // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
-    constexpr index_t N = 128;
-    constexpr index_t C = 768;
+    constexpr index_t C = 192;
     constexpr index_t HI = 17;
     constexpr index_t WI = 17;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
+    constexpr index_t K = 192;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;

-    using ConvStrides = Sequence<1, 1>;
+    using ConvStrides = Sequence<2, 2>;
     using ConvDilations = Sequence<1, 1>;

     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
-    // 1x1 filter, 14x14 image
-    // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
+    // 1x1, 35x35
     constexpr index_t N = 128;
-    constexpr index_t C = 528;
-    constexpr index_t HI = 14;
-    constexpr index_t WI = 14;
-    constexpr index_t K = 128;
+    constexpr index_t C = 384;
+    constexpr index_t HI = 35;
+    constexpr index_t WI = 35;
+    constexpr index_t K = 96;
     constexpr index_t Y = 1;
     constexpr index_t X = 1;
@@ -249,41 +201,8 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 14x14 image
-    // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
-    constexpr index_t N = 128;
-    constexpr index_t C = 528;
-    constexpr index_t HI = 14;
-    constexpr index_t WI = 14;
-    constexpr index_t K = 256;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 1x1 filter, 7x7 image
-    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
-    constexpr index_t N = 128;
-    constexpr index_t C = 832;
-    constexpr index_t HI = 7;
-    constexpr index_t WI = 7;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 1;
-    constexpr index_t X = 1;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<0, 0>;
-    using RightPads = Sequence<0, 0>;
-#elif 0
-    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
-    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
+#elif 1
+    // 3x3, 35x35, stride 2
     constexpr index_t N = 128;
     constexpr index_t C = 288;
     constexpr index_t HI = 35;
@@ -298,42 +217,73 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
 #elif 0
-    // 5x5 filter, 2x2 pad, 7x7 input
+    // 1x3, 8x8
     constexpr index_t N = 128;
-    constexpr index_t C = 48;
-    constexpr index_t HI = 7;
-    constexpr index_t WI = 7;
-    constexpr index_t K = 128;
-    constexpr index_t Y = 5;
-    constexpr index_t X = 5;
-
-    using ConvStrides = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
-
-    using LeftPads = Sequence<2, 2>;
-    using RightPads = Sequence<2, 2>;
-#elif 0
-    // 1x7 filter, 0x3 pad, 17x17 input
-    constexpr index_t N = 128;
-    constexpr index_t C = 128;
-    constexpr index_t HI = 17;
-    constexpr index_t WI = 17;
-    constexpr index_t K = 128;
+    constexpr index_t C = 384;
+    constexpr index_t HI = 8;
+    constexpr index_t WI = 8;
+    constexpr index_t K = 448;
     constexpr index_t Y = 1;
-    constexpr index_t X = 7;
+    constexpr index_t X = 3;

     using ConvStrides = Sequence<1, 1>;
     using ConvDilations = Sequence<1, 1>;

-    using LeftPads = Sequence<0, 3>;
-    using RightPads = Sequence<0, 3>;
-#elif 1
-    // 7x1 filter, 3x0 pad, 17x17 input
+    using LeftPads = Sequence<0, 1>;
+    using RightPads = Sequence<0, 1>;
+#elif 0
+    // 3x1, 8x8
     constexpr index_t N = 128;
-    constexpr index_t C = 128;
-    constexpr index_t HI = 17;
-    constexpr index_t WI = 17;
-    constexpr index_t K = 128;
+    constexpr index_t C = 448;
+    constexpr index_t HI = 8;
+    constexpr index_t WI = 8;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 0>;
+    using RightPads = Sequence<1, 0>;
+#elif 0
+    // 3x1, 8x8
+    constexpr index_t N = 128;
+    constexpr index_t C = 448;
+    constexpr index_t HI = 8;
+    constexpr index_t WI = 8;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 0>;
+    using RightPads = Sequence<1, 0>;
+#elif 0
+    // 3x3, 147x147
+    constexpr index_t N = 128;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 147;
+    constexpr index_t WI = 147;
+    constexpr index_t K = 96;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 7x1, 73x73
+    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
+    constexpr index_t N = 128;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 73;
+    constexpr index_t WI = 73;
+    constexpr index_t K = 64;
     constexpr index_t Y = 7;
     constexpr index_t X = 1;
@@ -342,27 +292,243 @@ int main(int argc, char* argv[])
     using LeftPads = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
+#elif 0
+    // 3x3, 73x73
+    constexpr index_t N = 128;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 73;
+    constexpr index_t WI = 73;
+    constexpr index_t K = 96;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 14x14, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 1024;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 2048;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 14x14
+    constexpr index_t N = 128;
+    constexpr index_t C = 1024;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 256;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 14x14, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 1024;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 3x3, 28x28
+    constexpr index_t N = 128;
+    constexpr index_t C = 128;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 128;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 1
+    // 3x3, 14x14
+    constexpr index_t N = 128;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 14;
+    constexpr index_t WI = 14;
+    constexpr index_t K = 256;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 1
+    // 1x1, 56x56, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 256;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K = 128;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 7x7, 230x230 stride=2
+    constexpr index_t N = 128;
+    constexpr index_t C = 3;
+    constexpr index_t HI = 230;
+    constexpr index_t WI = 230;
+    constexpr index_t K = 64;
+    constexpr index_t Y = 7;
+    constexpr index_t X = 7;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 28x28, stride = 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 1024;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 28x28, stride 2
+    constexpr index_t N = 128;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 28;
+    constexpr index_t WI = 28;
+    constexpr index_t K = 256;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<2, 2>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 1x1, 7x7
+    constexpr index_t N = 128;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 7;
+    constexpr index_t WI = 7;
+    constexpr index_t K = 2048;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 0
+    // 3x3, 7x7
+    constexpr index_t N = 128;
+    constexpr index_t C = 512;
+    constexpr index_t HI = 7;
+    constexpr index_t WI = 7;
+    constexpr index_t K = 512;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
+#elif 1
+    // 1x1, 56x56
+    constexpr index_t N = 128;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K = 64;
+    constexpr index_t Y = 1;
+    constexpr index_t X = 1;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<0, 0>;
+    using RightPads = Sequence<0, 0>;
+#elif 1
+    // 3x3, 56x56
+    constexpr index_t N = 128;
+    constexpr index_t C = 64;
+    constexpr index_t HI = 56;
+    constexpr index_t WI = 56;
+    constexpr index_t K = 64;
+    constexpr index_t Y = 3;
+    constexpr index_t X = 3;
+
+    using ConvStrides = Sequence<1, 1>;
+    using ConvDilations = Sequence<1, 1>;
+
+    using LeftPads = Sequence<1, 1>;
+    using RightPads = Sequence<1, 1>;
 #endif
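Reviewer note: for sanity-checking the configs above, the output spatial size follows the usual convolution formula, `Ho = (HI + LeftPad + RightPad - ConvDilation*(Y - 1) - 1) / ConvStride + 1` (and likewise for `Wo`). A small stand-alone check, assuming that formula (it is not lifted verbatim from this repo):

```cpp
// Recompute output height/width for two of the configs above:
// 3x3 filter on 35x35 input with stride 2, no padding -> 17x17 output;
// 7x1 filter on 17x17 input with 3x0 padding, stride 1 -> 17x17 output.
#include <cassert>

constexpr int out_size(int in, int pad_l, int pad_r, int filter, int dilation, int stride)
{
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    static_assert(out_size(35, 0, 0, 3, 1, 2) == 17, "3x3, 35x35, stride 2");
    static_assert(out_size(17, 3, 3, 7, 1, 1) == 17, "7x1, 17x17, pad 3x0");
    return 0;
}
```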

-    auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
-    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
-    auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor_deprecated(
+    auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
+    auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
+    auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
         in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

-    ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
-    ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
-    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
+    ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
+    ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
+    ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");

     print_sequence("LeftPads", LeftPads{});
     print_sequence("RightPads", RightPads{});
     print_sequence("ConvStrides", ConvStrides{});
     print_sequence("ConvDilations", ConvDilations{});

+#if 1
     using in_data_t = float;
     using out_data_t = float;
-    Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
-    Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
-    Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
-    Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
+#else
+    using in_data_t = half_float::half;
+    using out_data_t = half_float::half;
+#endif
+
+    Tensor<in_data_t> in_nchw(make_HostTensorDescriptor(in_nchw_desc));
+    Tensor<in_data_t> wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc));
+    Tensor<out_data_t> out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc));
+    Tensor<out_data_t> out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc));

     std::size_t num_thread = std::thread::hardware_concurrency();
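Reviewer note: the `#if 1` switch above is what the newly vendored half.hpp enables — flipping it to the `#else` branch runs the whole host side in fp16. A hedged sketch of basic half.hpp usage (the printed value is what IEEE half rounding should produce, not a claim about this driver's output):

```cpp
#include "half.hpp"
#include <iostream>

int main()
{
    half_float::half h(3.14159f); // 16 bits: 5-bit exponent, 11-bit significand
    float back = h;               // widens back to float implicitly
    std::cout << back << '\n';    // prints 3.14062, the nearest representable half
    return 0;
}
```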
@@ -399,42 +565,7 @@ int main(int argc, char* argv[])
 #endif
     }

-#if 0
-    device_convolution_direct_v2_nchw_kcyx_nkhw
-    (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(
-        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(in_nchw_desc,
-                                                              in_nchw,
-                                                              wei_kcyx_desc,
-                                                              wei_kcyx,
-                                                              out_nkhw_desc,
-                                                              out_nkhw_device,
-                                                              LeftPads{},
-                                                              RightPads{},
-                                                              nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(
-        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(
-        in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
-        (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(in_nchw_desc,
-                                                                    in_nchw,
-                                                                    wei_kcyx_desc,
-                                                                    wei_kcyx,
-                                                                    out_nkhw_desc,
-                                                                    out_nkhw_device,
-                                                                    ConvStrides{},
-                                                                    ConvDilations{},
-                                                                    nrepeat);
-#elif 0
+#if 1
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
@@ -446,36 +577,6 @@ int main(int argc, char* argv[])
                                                          LeftPads{},
                                                          RightPads{},
                                                          nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(in_nchw_desc,
-                                                         in_nchw,
-                                                         wei_kcyx_desc,
-                                                         wei_kcyx,
-                                                         out_nkhw_desc,
-                                                         out_nkhw_device,
-                                                         ConvStrides{},
-                                                         ConvDilations{},
-                                                         nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(in_nchw_desc,
-                                                         in_nchw,
-                                                         wei_kcyx_desc,
-                                                         wei_kcyx,
-                                                         out_nkhw_desc,
-                                                         out_nkhw_device,
-                                                         ConvStrides{},
-                                                         ConvDilations{},
-                                                         nrepeat);
-#elif 0
-    device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(in_nchw_desc,
-                                                                    in_nchw,
-                                                                    wei_kcyx_desc,
-                                                                    wei_kcyx,
-                                                                    out_nkhw_desc,
-                                                                    out_nkhw_device,
-                                                                    ConvStrides{},
-                                                                    ConvDilations{},
-                                                                    nrepeat);
 #elif 1
     device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
@@ -492,7 +593,7 @@ int main(int argc, char* argv[])

     if(do_verification)
     {
-#if 1
+#if 0
         if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
            ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
         {
diff --git a/driver/src/device.cpp b/driver/src/device.cpp
index 76cb19f466..14f4792d26 100644
--- a/driver/src/device.cpp
+++ b/driver/src/device.cpp
@@ -6,7 +6,7 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
+    cudaMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize);
 #endif
 }

@@ -18,8 +18,7 @@ void DeviceMem::ToDevice(const void* p)
     hipGetErrorString(
         hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(
-        cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice));
+    cudaMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, cudaMemcpyHostToDevice);
 #endif
 }

@@ -28,7 +27,7 @@ void DeviceMem::FromDevice(void* p)
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost));
+    cudaMemcpy(p, mpDeviceBuf, mMemSize, cudaMemcpyDeviceToHost);
 #endif
 }

@@ -37,7 +36,7 @@ DeviceMem::~DeviceMem()
 #if CK_DEVICE_BACKEND_AMD
     hipGetErrorString(hipFree(mpDeviceBuf));
 #elif CK_DEVICE_BACKEND_NVIDIA
-    checkCudaErrors(cudaFree(mpDeviceBuf));
+    cudaFree(mpDeviceBuf);
 #endif
 }

@@ -68,8 +67,10 @@ struct KernelTimerImpl
     void Start()
     {
 #if CK_DEVICE_BACKEND_AMD
+        hipDeviceSynchronize();
         hipEventRecord(mStart, 0);
 #elif CK_DEVICE_BACKEND_NVIDIA
+        cudaDeviceSynchronize();
         cudaEventRecord(mStart, 0);
 #endif
     }
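Reviewer note: the `cudaDeviceSynchronize()` / `hipDeviceSynchronize()` added before the start event matters for timing accuracy — without draining previously queued work, the start event could be recorded while earlier kernels are still running, inflating or deflating the measured interval. A minimal stand-alone sketch of the same pattern (not the repo's KernelTimer, just the idiom, shown with the CUDA runtime API):

```cpp
// Event-based kernel timing: synchronize first so the measured interval
// excludes whatever was already queued on the stream.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void kernel() {}

int main()
{
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaDeviceSynchronize(); // drain earlier work (the fix this diff adds)
    cudaEventRecord(start, 0);
    kernel<<<1, 1>>>();
    cudaEventRecord(stop, 0);
    cudaEventSynchronize(stop); // wait until the stop event has happened

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    std::printf("%f ms\n", ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    return 0;
}
```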
diff --git a/driver/src/tensor.cpp b/driver/src/host_tensor.cpp
similarity index 56%
rename from driver/src/tensor.cpp
rename to driver/src/host_tensor.cpp
index 24d2c77233..10f358d2f6 100644
--- a/driver/src/tensor.cpp
+++ b/driver/src/host_tensor.cpp
@@ -1,21 +1,21 @@
 #include
 #include

-#include "tensor.hpp"
+#include "host_tensor.hpp"

 template
-TensorDescriptor::TensorDescriptor(std::vector lens) : mLens(lens)
+HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens)
 {
     this->CalculateStrides();
 }

 template
-TensorDescriptor::TensorDescriptor(std::vector lens, std::vector strides)
+HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides)
     : mLens(lens), mStrides(strides)
 {
 }

-void TensorDescriptor::CalculateStrides()
+void HostTensorDescriptor::CalculateStrides()
 {
     mStrides.clear();
     mStrides.resize(mLens.size(), 0);
@@ -27,21 +27,21 @@ void TensorDescriptor::CalculateStrides()
         mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
 }

-std::size_t TensorDescriptor::GetNumOfDimension() const { return mLens.size(); }
+std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); }

-std::size_t TensorDescriptor::GetElementSize() const
+std::size_t HostTensorDescriptor::GetElementSize() const
 {
     assert(mLens.size() == mStrides.size());
     return std::accumulate(
         mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
 }

-std::size_t TensorDescriptor::GetElementSpace() const
+std::size_t HostTensorDescriptor::GetElementSpace() const
 {
     auto ls = mLens | boost::adaptors::transformed([](std::size_t v) { return v - 1; });
     return std::inner_product(ls.begin(), ls.end(), mStrides.begin(), std::size_t{0}) + 1;
 }

-const std::vector<std::size_t>& TensorDescriptor::GetLengths() const { return mLens; }
+const std::vector<std::size_t>& HostTensorDescriptor::GetLengths() const { return mLens; }

-const std::vector<std::size_t>& TensorDescriptor::GetStrides() const { return mStrides; }
+const std::vector<std::size_t>& HostTensorDescriptor::GetStrides() const { return mStrides; }
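Reviewer note: `GetElementSpace()` computes the smallest buffer that contains the highest-addressed element, `1 + sum_i (len_i - 1) * stride_i`, which differs from `GetElementSize()` whenever the strides are not packed. A small check of that formula with made-up lengths and strides (values are illustrative only):

```cpp
// Illustrative check of the element-space formula for a non-packed layout:
// a hypothetical 2x3 tensor whose row stride is padded to 4.
#include <cstddef>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{2, 3}, strides{4, 1};
    std::size_t space = 1;
    for(std::size_t i = 0; i < lens.size(); ++i)
        space += (lens[i] - 1) * strides[i];
    // space == 7: elements sit at offsets {0,1,2, 4,5,6}, so 7 slots are needed,
    // while GetElementSize() would report only 2*3 == 6.
    return space == 7 ? 0 : 1;
}
```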
diff --git a/external/half/include/half.hpp b/external/half/include/half.hpp
new file mode 100644
index 0000000000..1172a2c564
--- /dev/null
+++ b/external/half/include/half.hpp
@@ -0,0 +1,5670 @@
+// half - IEEE 754-based half-precision floating-point library.
+//
+// Copyright (c) 2012-2019 Christian Rau
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
+// associated documentation
+// files (the "Software"), to deal in the Software without restriction, including without limitation
+// the rights to use, copy,
+// modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+// persons to whom the
+// Software is furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in all copies or
+// substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
+// NOT LIMITED TO THE
+// WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
+// SHALL THE AUTHORS OR
+// COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF
+// CONTRACT, TORT OR OTHERWISE,
+// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+// SOFTWARE.
+
+// Version 2.1.0
+
+/// \file
+/// Main header file for half-precision functionality.
+
+#ifndef HALF_HALF_HPP
+#define HALF_HALF_HPP
+
+#define HALF_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__)
+
+#if defined(__INTEL_COMPILER)
+#define HALF_ICC_VERSION __INTEL_COMPILER
+#elif defined(__ICC)
+#define HALF_ICC_VERSION __ICC
+#elif defined(__ICL)
+#define HALF_ICC_VERSION __ICL
+#else
+#define HALF_ICC_VERSION 0
+#endif
+
+// check C++11 language features
+#if defined(__clang__) // clang
+#if __has_feature(cxx_static_assert) && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if __has_feature(cxx_constexpr) && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if __has_feature(cxx_noexcept) && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if __has_feature(cxx_user_literals) && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if __has_feature(cxx_thread_local) && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL)
+#define HALF_ENABLE_CPP11_THREAD_LOCAL 1
+#endif
+#if(defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L) && \
+    !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#elif HALF_ICC_VERSION && defined(__INTEL_CXX11_MODE__) // Intel C++
+#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL)
+#define HALF_ENABLE_CPP11_THREAD_LOCAL 1
+#endif
+#if HALF_ICC_VERSION >= 1500 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if HALF_ICC_VERSION >= 1400 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if HALF_ICC_VERSION >= 1110 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#elif defined(__GNUC__) // gcc
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L
+#if HALF_GCC_VERSION >= 408 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL)
+#define HALF_ENABLE_CPP11_THREAD_LOCAL 1
+#endif
+#if HALF_GCC_VERSION >= 407 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if HALF_GCC_VERSION >= 406 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#endif
+#define HALF_TWOS_COMPLEMENT_INT 1
+#elif defined(_MSC_VER) // Visual C++
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_THREAD_LOCAL)
+#define HALF_ENABLE_CPP11_THREAD_LOCAL 1
+#endif
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_USER_LITERALS)
+#define HALF_ENABLE_CPP11_USER_LITERALS 1
+#endif
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_CONSTEXPR)
+#define HALF_ENABLE_CPP11_CONSTEXPR 1
+#endif
+#if _MSC_VER >= 1900 && !defined(HALF_ENABLE_CPP11_NOEXCEPT)
+#define HALF_ENABLE_CPP11_NOEXCEPT 1
+#endif
+#if _MSC_VER >= 1600 && !defined(HALF_ENABLE_CPP11_STATIC_ASSERT)
+#define HALF_ENABLE_CPP11_STATIC_ASSERT 1
+#endif
+#if _MSC_VER >= 1310 && !defined(HALF_ENABLE_CPP11_LONG_LONG)
+#define HALF_ENABLE_CPP11_LONG_LONG 1
+#endif
+#define HALF_TWOS_COMPLEMENT_INT 1
+#define HALF_POP_WARNINGS 1
+#pragma warning(push)
+#pragma warning(disable : 4099 4127 4146) // struct vs class, constant in if, negative unsigned
+#endif
+
+// check C++11 library features
+#include
+#if defined(_LIBCPP_VERSION) // libc++
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+#ifndef HALF_ENABLE_CPP11_TYPE_TRAITS
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CSTDINT
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CMATH
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#ifndef HALF_ENABLE_CPP11_HASH
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#ifndef HALF_ENABLE_CPP11_CFENV
+#define HALF_ENABLE_CPP11_CFENV 1
+#endif
+#endif
+#elif defined(__GLIBCXX__) // libstdc++
+#if defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103
+#ifdef __clang__
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#if __GLIBCXX__ >= 20080606 && !defined(HALF_ENABLE_CPP11_CFENV)
+#define HALF_ENABLE_CPP11_CFENV 1
+#endif
+#else
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#if HALF_GCC_VERSION >= 403 && !defined(HALF_ENABLE_CPP11_CFENV)
+#define HALF_ENABLE_CPP11_CFENV 1
+#endif
+#endif
+#endif
+#elif defined(_CPPLIB_VER) // Dinkumware/Visual C++
+#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_TYPE_TRAITS)
+#define HALF_ENABLE_CPP11_TYPE_TRAITS 1
+#endif
+#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_CSTDINT)
+#define HALF_ENABLE_CPP11_CSTDINT 1
+#endif
+#if _CPPLIB_VER >= 520 && !defined(HALF_ENABLE_CPP11_HASH)
+#define HALF_ENABLE_CPP11_HASH 1
+#endif
+#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CMATH)
+#define HALF_ENABLE_CPP11_CMATH 1
+#endif
+#if _CPPLIB_VER >= 610 && !defined(HALF_ENABLE_CPP11_CFENV)
+#define HALF_ENABLE_CPP11_CFENV 1
+#endif
+#endif
+#undef HALF_GCC_VERSION
+#undef HALF_ICC_VERSION
+
+// any error throwing C++ exceptions?
+#if defined(HALF_ERRHANDLING_THROW_INVALID) || defined(HALF_ERRHANDLING_THROW_DIVBYZERO) || \
+    defined(HALF_ERRHANDLING_THROW_OVERFLOW) || defined(HALF_ERRHANDLING_THROW_UNDERFLOW) || \
+    defined(HALF_ERRHANDLING_THROW_INEXACT)
+#define HALF_ERRHANDLING_THROWS 1
+#endif
+
+// any error handling enabled?
+#define HALF_ERRHANDLING \
+    (HALF_ERRHANDLING_FLAGS || HALF_ERRHANDLING_ERRNO || HALF_ERRHANDLING_FENV || \
+     HALF_ERRHANDLING_THROWS)
+
+#if HALF_ERRHANDLING
+#define HALF_UNUSED_NOERR(name) name
+#else
+#define HALF_UNUSED_NOERR(name)
+#endif
+
+// support constexpr
+#if HALF_ENABLE_CPP11_CONSTEXPR
+#define HALF_CONSTEXPR constexpr
+#define HALF_CONSTEXPR_CONST constexpr
+#if HALF_ERRHANDLING
+#define HALF_CONSTEXPR_NOERR
+#else
+#define HALF_CONSTEXPR_NOERR constexpr
+#endif
+#else
+#define HALF_CONSTEXPR
+#define HALF_CONSTEXPR_CONST const
+#define HALF_CONSTEXPR_NOERR
+#endif
+
+// support noexcept
+#if HALF_ENABLE_CPP11_NOEXCEPT
+#define HALF_NOEXCEPT noexcept
+#define HALF_NOTHROW noexcept
+#else
+#define HALF_NOEXCEPT
+#define HALF_NOTHROW throw()
+#endif
+
+// support thread storage
+#if HALF_ENABLE_CPP11_THREAD_LOCAL
+#define HALF_THREAD_LOCAL thread_local
+#else
+#define HALF_THREAD_LOCAL static
+#endif
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#if HALF_ENABLE_CPP11_TYPE_TRAITS
+#include
+#endif
+#if HALF_ENABLE_CPP11_CSTDINT
+#include
+#endif
+#if HALF_ERRHANDLING_ERRNO
+#include
+#endif
+#if HALF_ENABLE_CPP11_CFENV
+#include
+#endif
+#if HALF_ENABLE_CPP11_HASH
+#include
+#endif
+#if HALF_ENABLE_F16C_INTRINSICS
+#include
+#endif
+
+#ifndef HALF_ENABLE_F16C_INTRINSICS
+/// Enable F16C instruction set intrinsics.
+/// Defining this to 1 enables the use of [F16C compiler
+/// intrinsics](https://en.wikipedia.org/wiki/F16C) for converting between
+/// half-precision and single-precision values which may result in improved performance. This will
+/// not perform additional checks
+/// for support of the F16C instruction set, so an appropriate target platform is required when
+/// enabling this feature.
+///
+/// Unless predefined it will be enabled automatically when the `__F16C__` symbol is defined, which
+/// some compilers do on supporting platforms.
+#define HALF_ENABLE_F16C_INTRINSICS __F16C__
+#endif
+
+#ifdef HALF_DOXYGEN_ONLY
+/// Type for internal floating-point computations.
+/// This can be predefined to a built-in floating-point type (`float`, `double` or `long double`) to
+/// override the internal
+/// half-precision implementation to use this type for computing arithmetic operations and
+/// mathematical functions (if available).
+/// This can result in improved performance for arithmetic operators and mathematical functions but
+/// might cause results to
+/// deviate from the specified half-precision rounding mode and inhibits proper detection of
+/// half-precision exceptions.
+#define HALF_ARITHMETIC_TYPE (undefined)
+
+/// Enable internal exception flags.
+/// Defining this to 1 causes operations on half-precision values to raise internal floating-point
+/// exception flags according to
+/// the IEEE 754 standard. These can then be cleared and checked with clearexcept(), testexcept().
+#define HALF_ERRHANDLING_FLAGS 0
+
+/// Enable exception propagation to `errno`.
+/// Defining this to 1 causes operations on half-precision values to propagate floating-point
+/// exceptions to
+/// [errno](https://en.cppreference.com/w/cpp/error/errno) from `<cerrno>`. Specifically this will
+/// propagate domain errors as
+/// [EDOM](https://en.cppreference.com/w/cpp/error/errno_macros) and pole, overflow and underflow
+/// errors as
+/// [ERANGE](https://en.cppreference.com/w/cpp/error/errno_macros). Inexact errors won't be
+/// propagated.
+#define HALF_ERRHANDLING_ERRNO 0
+
+/// Enable exception propagation to built-in floating-point platform.
+/// Defining this to 1 causes operations on half-precision values to propagate floating-point
+/// exceptions to the built-in
+/// single- and double-precision implementation's exception flags using the
+/// [C++11 floating-point environment control](https://en.cppreference.com/w/cpp/numeric/fenv) from
+/// `<cfenv>`. However, this
+/// does not work in reverse and single- or double-precision exceptions will not raise the
+/// corresponding half-precision
+/// exception flags, nor will explicitly clearing flags clear the corresponding built-in flags.
+#define HALF_ERRHANDLING_FENV 0
+
+/// Throw C++ exception on domain errors.
+/// Defining this to a string literal causes operations on half-precision values to throw a
+/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified
+/// message on domain errors.
+#define HALF_ERRHANDLING_THROW_INVALID (undefined)
+
+/// Throw C++ exception on pole errors.
+/// Defining this to a string literal causes operations on half-precision values to throw a
+/// [std::domain_error](https://en.cppreference.com/w/cpp/error/domain_error) with the specified
+/// message on pole errors.
+#define HALF_ERRHANDLING_THROW_DIVBYZERO (undefined)
+
+/// Throw C++ exception on overflow errors.
+/// Defining this to a string literal causes operations on half-precision values to throw a
+/// [std::overflow_error](https://en.cppreference.com/w/cpp/error/overflow_error) with the specified
+/// message on overflows.
+#define HALF_ERRHANDLING_THROW_OVERFLOW (undefined)
+
+/// Throw C++ exception on underflow errors.
+/// Defining this to a string literal causes operations on half-precision values to throw a
+/// [std::underflow_error](https://en.cppreference.com/w/cpp/error/underflow_error) with the
+/// specified message on underflows.
+#define HALF_ERRHANDLING_THROW_UNDERFLOW (undefined)
+
+/// Throw C++ exception on rounding errors.
+/// Defining this to 1 causes operations on half-precision values to throw a
+/// [std::range_error](https://en.cppreference.com/w/cpp/error/range_error) with the specified
+/// message on general rounding errors.
+#define HALF_ERRHANDLING_THROW_INEXACT (undefined)
+#endif
+
+#ifndef HALF_ERRHANDLING_OVERFLOW_TO_INEXACT
+/// Raise INEXACT exception on overflow.
+/// Defining this to 1 (default) causes overflow errors to automatically raise inexact exceptions in
+/// addition.
+/// These will be raised after any possible handling of the underflow exception.
+#define HALF_ERRHANDLING_OVERFLOW_TO_INEXACT 1
+#endif
+
+#ifndef HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT
+/// Raise INEXACT exception on underflow.
+/// Defining this to 1 (default) causes underflow errors to automatically raise inexact exceptions
+/// in addition.
+/// These will be raised after any possible handling of the underflow exception.
+///
+/// **Note:** This will actually cause underflow (and the accompanying inexact) exceptions to be
+/// raised *only* when the result
+/// is inexact, while if disabled bare underflow errors will be raised for *any* (possibly exact)
+/// subnormal result.
+#define HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT 1
+#endif
+
+/// Default rounding mode.
+/// This specifies the rounding mode used for all conversions between [half](\ref half_float::half)s
+/// and more precise types
+/// (unless using half_cast() and specifying the rounding mode directly) as well as in arithmetic
+/// operations and mathematical
+/// functions. It can be redefined (before including half.hpp) to one of the standard rounding modes
+/// using their respective
+/// constants or the equivalent values of
+/// [std::float_round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/float_round_style):
+///
+/// `std::float_round_style`         | value | rounding
+/// ---------------------------------|-------|-------------------------
+/// `std::round_indeterminate`       | -1    | fastest
+/// `std::round_toward_zero`         | 0     | toward zero
+/// `std::round_to_nearest`          | 1     | to nearest (default)
+/// `std::round_toward_infinity`     | 2     | toward positive infinity
+/// `std::round_toward_neg_infinity` | 3     | toward negative infinity
+///
+/// By default this is set to `1` (`std::round_to_nearest`), which rounds results to the nearest
+/// representable value. It can even
+/// be set to
+/// [std::numeric_limits<float>::round_style](https://en.cppreference.com/w/cpp/types/numeric_limits/round_style)
+/// to synchronize
+/// the rounding mode with that of the built-in single-precision implementation (which is likely
+/// `std::round_to_nearest`, though).
+#ifndef HALF_ROUND_STYLE
+#define HALF_ROUND_STYLE 1 // = std::round_to_nearest
+#endif
+
+/// Value signaling overflow.
+/// In correspondence with `HUGE_VAL[F|L]` from `<cmath>` this symbol expands to a positive value
+/// signaling the overflow of an
+/// operation, in particular it just evaluates to positive infinity.
+///
+/// **See also:** Documentation for
+/// [HUGE_VAL](https://en.cppreference.com/w/cpp/numeric/math/HUGE_VAL)
+#define HUGE_VALH std::numeric_limits<half_float::half>::infinity()
+
+/// Fast half-precision fma function.
+/// This symbol is defined if the fma() function generally executes as fast as, or faster than, a
+/// separate
+/// half-precision multiplication followed by an addition, which is always the case.
+///
+/// **See also:** Documentation for
+/// [FP_FAST_FMA](https://en.cppreference.com/w/cpp/numeric/math/fma)
+#define FP_FAST_FMAH 1
+
+/// Half rounding mode.
+/// In correspondence with `FLT_ROUNDS` from `<cfloat>` this symbol expands to the rounding mode
+/// used for
+/// half-precision operations. It is an alias for [HALF_ROUND_STYLE](\ref HALF_ROUND_STYLE).
+///
+/// **See also:** Documentation for
+/// [FLT_ROUNDS](https://en.cppreference.com/w/cpp/types/climits/FLT_ROUNDS)
+#define HLF_ROUNDS HALF_ROUND_STYLE
+
+#ifndef FP_ILOGB0
+#define FP_ILOGB0 INT_MIN
+#endif
+#ifndef FP_ILOGBNAN
+#define FP_ILOGBNAN INT_MAX
+#endif
+#ifndef FP_SUBNORMAL
+#define FP_SUBNORMAL 0
+#endif
+#ifndef FP_ZERO
+#define FP_ZERO 1
+#endif
+#ifndef FP_NAN
+#define FP_NAN 2
+#endif
+#ifndef FP_INFINITE
+#define FP_INFINITE 3
+#endif
+#ifndef FP_NORMAL
+#define FP_NORMAL 4
+#endif
+
+#if !HALF_ENABLE_CPP11_CFENV && !defined(FE_ALL_EXCEPT)
+#define FE_INVALID 0x10
+#define FE_DIVBYZERO 0x08
+#define FE_OVERFLOW 0x04
+#define FE_UNDERFLOW 0x02
+#define FE_INEXACT 0x01
+#define FE_ALL_EXCEPT (FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW | FE_INEXACT)
+#endif
+
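Reviewer note: since `HALF_ROUND_STYLE` is only defaulted under `#ifndef`, a client translation unit can pick another rounding mode by defining it before the first include. A hedged example using truncation (round toward zero), values illustrative:

```cpp
// Select truncation instead of the default round-to-nearest. The define must
// precede the first inclusion of half.hpp in this translation unit.
#define HALF_ROUND_STYLE 0 // std::round_toward_zero
#include "half.hpp"

int main()
{
    half_float::half h(2.999f);               // truncation keeps this below 3
    return static_cast<float>(h) < 3.0f ? 0 : 1; // 2.998046875 with this mode
}
```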
+/// Main namespace for half-precision functionality.
+/// This namespace contains all the functionality provided by the library.
+namespace half_float {
+class half;
+
+#if HALF_ENABLE_CPP11_USER_LITERALS
+/// Library-defined half-precision literals.
+/// Import this namespace to enable half-precision floating-point literals:
+/// ~~~~{.cpp}
+/// using namespace half_float::literal;
+/// half_float::half = 4.2_h;
+/// ~~~~
+namespace literal {
+half operator"" _h(long double);
+}
+#endif
+
+/// \internal
+/// \brief Implementation details.
+namespace detail {
+#if HALF_ENABLE_CPP11_TYPE_TRAITS
+/// Conditional type.
+template <bool B, typename T, typename F>
+struct conditional : std::conditional<B, T, F>
+{
+};
+
+/// Helper for tag dispatching.
+template <bool B>
+struct bool_type : std::integral_constant<bool, B>
+{
+};
+using std::true_type;
+using std::false_type;
+
+/// Type traits for floating-point types.
+template <typename T>
+struct is_float : std::is_floating_point<T>
+{
+};
+#else
+/// Conditional type.
+template <bool, typename T, typename F>
+struct conditional
+{
+    typedef T type;
+};
+template <typename T, typename F>
+struct conditional<false, T, F>
+{
+    typedef F type;
+};
+
+/// Helper for tag dispatching.
+template <bool>
+struct bool_type
+{
+};
+typedef bool_type<true> true_type;
+typedef bool_type<false> false_type;
+
+/// Type traits for floating-point types.
+template <typename>
+struct is_float : false_type
+{
+};
+template <typename T>
+struct is_float<const T> : is_float<T>
+{
+};
+template <typename T>
+struct is_float<volatile T> : is_float<T>
+{
+};
+template <typename T>
+struct is_float<const volatile T> : is_float<T>
+{
+};
+template <>
+struct is_float<float> : true_type
+{
+};
+template <>
+struct is_float<double> : true_type
+{
+};
+template <>
+struct is_float<long double> : true_type
+{
+};
+#endif
+
+/// Type traits for floating-point bits.
+template <typename T>
+struct bits
+{
+    typedef unsigned char type;
+};
+template <typename T>
+struct bits<const T> : bits<T>
+{
+};
+template <typename T>
+struct bits<volatile T> : bits<T>
+{
+};
+template <typename T>
+struct bits<const volatile T> : bits<T>
+{
+};
+
+#if HALF_ENABLE_CPP11_CSTDINT
+/// Unsigned integer of (at least) 16 bits width.
+typedef std::uint_least16_t uint16;
+
+/// Fastest unsigned integer of (at least) 32 bits width.
+typedef std::uint_fast32_t uint32;
+
+/// Fastest signed integer of (at least) 32 bits width.
+typedef std::int_fast32_t int32;
+
+/// Unsigned integer of (at least) 32 bits width.
+template <>
+struct bits<float>
+{
+    typedef std::uint_least32_t type;
+};
+
+/// Unsigned integer of (at least) 64 bits width.
+template <>
+struct bits<double>
+{
+    typedef std::uint_least64_t type;
+};
+#else
+/// Unsigned integer of (at least) 16 bits width.
+typedef unsigned short uint16;
+
+/// Fastest unsigned integer of (at least) 32 bits width.
+typedef unsigned long uint32;
+
+/// Fastest unsigned integer of (at least) 32 bits width.
+typedef long int32;
+
+/// Unsigned integer of (at least) 32 bits width.
+template <>
+struct bits<float>
+    : conditional<std::numeric_limits<unsigned int>::digits >= 32, unsigned int, unsigned long>
+{
+};
+
+#if HALF_ENABLE_CPP11_LONG_LONG
+/// Unsigned integer of (at least) 64 bits width.
+template <>
+struct bits<double> : conditional<std::numeric_limits<unsigned long>::digits >= 64,
+                                  unsigned long,
+                                  unsigned long long>
+{
+};
+#else
+/// Unsigned integer of (at least) 64 bits width.
+template <>
+struct bits<double>
+{
+    typedef unsigned long type;
+};
+#endif
+#endif
+
+#ifdef HALF_ARITHMETIC_TYPE
+/// Type to use for arithmetic computations and mathematical functions internally.
+typedef HALF_ARITHMETIC_TYPE internal_t;
+#endif
+
+/// Tag type for binary construction.
+struct binary_t
+{
+};
+
+/// Tag for binary construction.
+HALF_CONSTEXPR_CONST binary_t binary = binary_t();
+
+/// \name Implementation defined classification and arithmetic
+/// \{
+
+/// Check for infinity.
+/// \tparam T argument type (builtin floating-point type)
+/// \param arg value to query
+/// \retval true if infinity
+/// \retval false else
+template <typename T>
+bool builtin_isinf(T arg)
+{
+#if HALF_ENABLE_CPP11_CMATH
+    return std::isinf(arg);
+#elif defined(_MSC_VER)
+    return !::_finite(static_cast<double>(arg)) && !::_isnan(static_cast<double>(arg));
+#else
+    return arg == std::numeric_limits<T>::infinity() || arg == -std::numeric_limits<T>::infinity();
+#endif
+}
+
+/// Check for NaN.
+/// \tparam T argument type (builtin floating-point type)
+/// \param arg value to query
+/// \retval true if not a number
+/// \retval false else
+template <typename T>
+bool builtin_isnan(T arg)
+{
+#if HALF_ENABLE_CPP11_CMATH
+    return std::isnan(arg);
+#elif defined(_MSC_VER)
+    return ::_isnan(static_cast<double>(arg)) != 0;
+#else
+    return arg != arg;
+#endif
+}
+
+/// Check sign.
+/// \tparam T argument type (builtin floating-point type)
+/// \param arg value to query
+/// \retval true if signbit set
+/// \retval false else
+template <typename T>
+bool builtin_signbit(T arg)
+{
+#if HALF_ENABLE_CPP11_CMATH
+    return std::signbit(arg);
+#else
+    return arg < T() || (arg == T() && T(1) / arg < T());
+#endif
+}
+
+/// Platform-independent sign mask.
+/// \param arg integer value in two's complement
+/// \retval -1 if \a arg negative
+/// \retval 0 if \a arg positive
+inline uint32 sign_mask(uint32 arg)
+{
+    static const int N = std::numeric_limits<uint32>::digits - 1;
+#if HALF_TWOS_COMPLEMENT_INT
+    return static_cast<int32>(arg) >> N;
+#else
+    return -((arg >> N) & 1);
+#endif
+}
+
+/// Platform-independent arithmetic right shift.
+/// \param arg integer value in two's complement
+/// \param i shift amount (at most 31)
+/// \return \a arg right shifted for \a i bits with possible sign extension
+inline uint32 arithmetic_shift(uint32 arg, int i)
+{
+#if HALF_TWOS_COMPLEMENT_INT
+    return static_cast<int32>(arg) >> i;
+#else
+    return static_cast<int32>(arg) / (static_cast<int32>(1) << i) -
+           ((arg >> (std::numeric_limits<uint32>::digits - 1)) & 1);
+#endif
+}
+
+/// \}
+/// \name Error handling
+/// \{
+
+/// Internal exception flags.
+/// \return reference to global exception flags
+inline int& errflags()
+{
+    HALF_THREAD_LOCAL int flags = 0;
+    return flags;
+}
+
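Reviewer note: `sign_mask()` and `arithmetic_shift()` exist because sign-extending right shifts of negative values are only guaranteed on two's-complement targets (which is what `HALF_TWOS_COMPLEMENT_INT` asserts). The all-ones/all-zeros mask the first one returns later lets fixed-point conversion take an absolute value branch-free via `(m ^ msign) - msign`. A quick local reimplementation of the trick, assuming 32-bit arithmetic shifts:

```cpp
// Hypothetical demonstration of the sign-mask trick on 32-bit values.
#include <cassert>
#include <cstdint>

int main()
{
    auto sign_mask = [](std::uint32_t v) -> std::uint32_t {
        return static_cast<std::uint32_t>(static_cast<std::int32_t>(v) >> 31);
    };
    assert(sign_mask(0x80000000u) == 0xFFFFFFFFu); // "negative" -> all ones
    assert(sign_mask(0x00000001u) == 0x00000000u); // "positive" -> all zeros

    // Branch-free absolute value using the mask, as in fixed2half():
    std::uint32_t m = static_cast<std::uint32_t>(-5), msign = sign_mask(m);
    assert(((m ^ msign) - msign) == 5u);
    return 0;
}
```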
+/// Raise floating-point exception.
+/// \param flags exceptions to raise
+/// \param cond condition to raise exceptions for
+inline void raise(int HALF_UNUSED_NOERR(flags), bool HALF_UNUSED_NOERR(cond) = true)
+{
+#if HALF_ERRHANDLING
+    if(!cond)
+        return;
+#if HALF_ERRHANDLING_FLAGS
+    errflags() |= flags;
+#endif
+#if HALF_ERRHANDLING_ERRNO
+    if(flags & FE_INVALID)
+        errno = EDOM;
+    else if(flags & (FE_DIVBYZERO | FE_OVERFLOW | FE_UNDERFLOW))
+        errno = ERANGE;
+#endif
+#if HALF_ERRHANDLING_FENV && HALF_ENABLE_CPP11_CFENV
+    std::feraiseexcept(flags);
+#endif
+#ifdef HALF_ERRHANDLING_THROW_INVALID
+    if(flags & FE_INVALID)
+        throw std::domain_error(HALF_ERRHANDLING_THROW_INVALID);
+#endif
+#ifdef HALF_ERRHANDLING_THROW_DIVBYZERO
+    if(flags & FE_DIVBYZERO)
+        throw std::domain_error(HALF_ERRHANDLING_THROW_DIVBYZERO);
+#endif
+#ifdef HALF_ERRHANDLING_THROW_OVERFLOW
+    if(flags & FE_OVERFLOW)
+        throw std::overflow_error(HALF_ERRHANDLING_THROW_OVERFLOW);
+#endif
+#ifdef HALF_ERRHANDLING_THROW_UNDERFLOW
+    if(flags & FE_UNDERFLOW)
+        throw std::underflow_error(HALF_ERRHANDLING_THROW_UNDERFLOW);
+#endif
+#ifdef HALF_ERRHANDLING_THROW_INEXACT
+    if(flags & FE_INEXACT)
+        throw std::range_error(HALF_ERRHANDLING_THROW_INEXACT);
+#endif
+#if HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT
+    if((flags & FE_UNDERFLOW) && !(flags & FE_INEXACT))
+        raise(FE_INEXACT);
+#endif
+#if HALF_ERRHANDLING_OVERFLOW_TO_INEXACT
+    if((flags & FE_OVERFLOW) && !(flags & FE_INEXACT))
+        raise(FE_INEXACT);
+#endif
+#endif
+}
+
+/// Check and signal for any NaN.
+/// \param x first half-precision value to check
+/// \param y second half-precision value to check
+/// \retval true if either \a x or \a y is NaN
+/// \retval false else
+/// \exception FE_INVALID if \a x or \a y is NaN
+inline HALF_CONSTEXPR_NOERR bool compsignal(unsigned int x, unsigned int y)
+{
+#if HALF_ERRHANDLING
+    raise(FE_INVALID, (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00);
+#endif
+    return (x & 0x7FFF) > 0x7C00 || (y & 0x7FFF) > 0x7C00;
+}
+
+/// Signal and silence signaling NaN.
+/// \param nan half-precision NaN value
+/// \return quiet NaN
+/// \exception FE_INVALID if \a nan is signaling NaN
+inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int nan)
+{
+#if HALF_ERRHANDLING
+    raise(FE_INVALID, !(nan & 0x200));
+#endif
+    return nan | 0x200;
+}
+
+/// Signal and silence signaling NaNs.
+/// \param x first half-precision value to check
+/// \param y second half-precision value to check
+/// \return quiet NaN
+/// \exception FE_INVALID if \a x or \a y is signaling NaN
+inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y)
+{
+#if HALF_ERRHANDLING
+    raise(FE_INVALID,
+          ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)));
+#endif
+    return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : (y | 0x200);
+}
+
+/// Signal and silence signaling NaNs.
+/// \param x first half-precision value to check
+/// \param y second half-precision value to check
+/// \param z third half-precision value to check
+/// \return quiet NaN
+/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN
+inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, unsigned int z)
+{
+#if HALF_ERRHANDLING
+    raise(FE_INVALID,
+          ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) ||
+              ((z & 0x7FFF) > 0x7C00 && !(z & 0x200)));
+#endif
+    return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200)
+                                                                           : (z | 0x200);
+}
+
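Reviewer note: the bit tests above all rely on the IEEE half encoding — `0x7C00` is the all-ones exponent (infinity), any strictly larger magnitude pattern is a NaN, and mantissa bit `0x200` is the quiet flag that `signal()` sets. Reimplemented locally for illustration:

```cpp
// Local re-derivation of the half bit tests used by compsignal()/signal().
#include <cassert>

int main()
{
    auto is_nan   = [](unsigned v) { return (v & 0x7FFF) > 0x7C00; };
    auto is_quiet = [](unsigned v) { return (v & 0x200) != 0; };

    assert(!is_nan(0x7C00));          // +infinity, not a NaN
    assert(is_nan(0x7C01));           // signaling NaN with payload 1
    assert(!is_quiet(0x7C01));
    assert(is_quiet(0x7C01 | 0x200)); // quieted, as signal() would return it
    return 0;
}
```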
+/// Select value or signaling NaN.
+/// \param x preferred half-precision value
+/// \param y ignored half-precision value except for signaling NaN
+/// \return \a y if signaling NaN, \a x otherwise
+/// \exception FE_INVALID if \a y is signaling NaN
+inline HALF_CONSTEXPR_NOERR unsigned int select(unsigned int x, unsigned int HALF_UNUSED_NOERR(y))
+{
+#if HALF_ERRHANDLING
+    return (((y & 0x7FFF) > 0x7C00) && !(y & 0x200)) ? signal(y) : x;
+#else
+    return x;
+#endif
+}
+
+/// Raise domain error and return NaN.
+/// return quiet NaN
+/// \exception FE_INVALID
+inline HALF_CONSTEXPR_NOERR unsigned int invalid()
+{
+#if HALF_ERRHANDLING
+    raise(FE_INVALID);
+#endif
+    return 0x7FFF;
+}
+
+/// Raise pole error and return infinity.
+/// \param sign half-precision value with sign bit only
+/// \return half-precision infinity with sign of \a sign
+/// \exception FE_DIVBYZERO
+inline HALF_CONSTEXPR_NOERR unsigned int pole(unsigned int sign = 0)
+{
+#if HALF_ERRHANDLING
+    raise(FE_DIVBYZERO);
+#endif
+    return sign | 0x7C00;
+}
+
+/// Check value for underflow.
+/// \param arg non-zero half-precision value to check
+/// \return \a arg
+/// \exception FE_UNDERFLOW if arg is subnormal
+inline HALF_CONSTEXPR_NOERR unsigned int check_underflow(unsigned int arg)
+{
+#if HALF_ERRHANDLING && !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT
+    raise(FE_UNDERFLOW, !(arg & 0x7C00));
+#endif
+    return arg;
+}
+
+/// \}
+/// \name Conversion and rounding
+/// \{
+
+/// Half-precision overflow.
+/// \tparam R rounding mode to use
+/// \param sign half-precision value with sign bit only
+/// \return rounded overflowing half-precision value
+/// \exception FE_OVERFLOW
+template <std::float_round_style R>
+HALF_CONSTEXPR_NOERR unsigned int overflow(unsigned int sign = 0)
+{
+#if HALF_ERRHANDLING
+    raise(FE_OVERFLOW);
+#endif
+    return (R == std::round_toward_infinity)
+               ? (sign + 0x7C00 - (sign >> 15))
+               : (R == std::round_toward_neg_infinity)
+                     ? (sign + 0x7BFF + (sign >> 15))
+                     : (R == std::round_toward_zero) ? (sign | 0x7BFF) : (sign | 0x7C00);
+}
+
+/// Half-precision underflow.
+/// \tparam R rounding mode to use
+/// \param sign half-precision value with sign bit only
+/// \return rounded underflowing half-precision value
+/// \exception FE_UNDERFLOW
+template <std::float_round_style R>
+HALF_CONSTEXPR_NOERR unsigned int underflow(unsigned int sign = 0)
+{
+#if HALF_ERRHANDLING
+    raise(FE_UNDERFLOW);
+#endif
+    return (R == std::round_toward_infinity)
+               ? (sign + 1 - (sign >> 15))
+               : (R == std::round_toward_neg_infinity) ? (sign + (sign >> 15)) : sign;
+}
+
+/// Round half-precision number.
+/// \tparam R rounding mode to use
+/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results
+/// \param value finite half-precision number to round
+/// \param g guard bit (most significant discarded bit)
+/// \param s sticky bit (or of all but the most significant discarded bits)
+/// \return rounded half-precision value
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded or \a I is `true`
+template <std::float_round_style R, bool I>
+HALF_CONSTEXPR_NOERR unsigned int rounded(unsigned int value, int g, int s)
+{
+#if HALF_ERRHANDLING
+    value += (R == std::round_to_nearest)
+                 ? (g & (s | value))
+                 : (R == std::round_toward_infinity)
+                       ? (~(value >> 15) & (g | s))
+                       : (R == std::round_toward_neg_infinity) ? ((value >> 15) & (g | s)) : 0;
+    if((value & 0x7C00) == 0x7C00)
+        raise(FE_OVERFLOW);
+    else if(value & 0x7C00)
+        raise(FE_INEXACT, I || (g | s) != 0);
+    else
+        raise(FE_UNDERFLOW, !(HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT) || I || (g | s) != 0);
+    return value;
+#else
+    return (R == std::round_to_nearest)
+               ? (value + (g & (s | value)))
+               : (R == std::round_toward_infinity)
+                     ? (value + (~(value >> 15) & (g | s)))
+                     : (R == std::round_toward_neg_infinity) ? (value + ((value >> 15) & (g | s)))
+                                                             : value;
+#endif
+}
+
+/// Round half-precision number to nearest integer value.
+/// \tparam R rounding mode to use
+/// \tparam E `true` for round to even, `false` for round away from zero
+/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it
+/// \param value half-precision value to round
+/// \return half-precision bits for nearest integral value
+/// \exception FE_INVALID for signaling NaN
+/// \exception FE_INEXACT if value had to be rounded and \a I is `true`
+template <std::float_round_style R, bool E, bool I>
+unsigned int integral(unsigned int value)
+{
+    unsigned int abs = value & 0x7FFF;
+    if(abs < 0x3C00)
+    {
+        raise(FE_INEXACT, I);
+        return ((R == std::round_to_nearest)
+                    ? (0x3C00 & -static_cast<unsigned>(abs >= (0x3800 + E)))
+                    : (R == std::round_toward_infinity)
+                          ? (0x3C00 & -(~(value >> 15) & (abs != 0)))
+                          : (R == std::round_toward_neg_infinity)
+                                ? (0x3C00 & -static_cast<unsigned>(value > 0x8000))
+                                : 0) |
+               (value & 0x8000);
+    }
+    if(abs >= 0x6400)
+        return (abs > 0x7C00) ? signal(value) : value;
+    unsigned int exp = 25 - (abs >> 10), mask = (1 << exp) - 1;
+    raise(FE_INEXACT, I && (value & mask));
+    return (((R == std::round_to_nearest)
+                 ? ((1 << (exp - 1)) - (~(value >> exp) & E))
+                 : (R == std::round_toward_infinity)
+                       ? (mask & ((value >> 15) - 1))
+                       : (R == std::round_toward_neg_infinity) ? (mask & -(value >> 15)) : 0) +
+            value) &
+           ~mask;
+}
+
+/// Convert fixed point to half-precision floating-point.
+/// \tparam R rounding mode to use
+/// \tparam F number of fractional bits (at least 11)
+/// \tparam S `true` for signed, `false` for unsigned
+/// \tparam N `true` for additional normalization step, `false` if already normalized to 1.F
+/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results
+/// \param m mantissa in Q1.F fixed point format
+/// \param exp exponent
+/// \param sign half-precision value with sign bit only
+/// \param s sticky bit (or of all but the most significant already discarded bits)
+/// \return value converted to half-precision
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded or \a I is `true`
+template <std::float_round_style R, unsigned int F, bool S, bool N, bool I>
+unsigned int fixed2half(uint32 m, int exp = 14, unsigned int sign = 0, int s = 0)
+{
+    if(S)
+    {
+        uint32 msign = sign_mask(m);
+        m = (m ^ msign) - msign;
+        sign = msign & 0x8000;
+    }
+    if(N)
+        for(; m < (static_cast<uint32>(1) << F) && exp; m <<= 1, --exp)
+            ;
+    else if(exp < 0)
+        return rounded<R, I>(sign + (m >> (F - 10 - exp)),
+                             (m >> (F - 11 - exp)) & 1,
+                             s | ((m & ((static_cast<uint32>(1) << (F - 11 - exp)) - 1)) != 0));
+    return rounded<R, I>(sign + (exp << 10) + (m >> (F - 10)),
+                         (m >> (F - 11)) & 1,
+                         s | ((m & ((static_cast<uint32>(1) << (F - 11)) - 1)) != 0));
+}
+
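Reviewer note: the round-to-nearest update in `rounded()`, `value += g & (s | value)`, is the classic guard/sticky formulation of round-to-nearest, ties-to-even: the kept bits are bumped only when the discarded part exceeds half a ULP (`g=1, s=1`), or is exactly half a ULP and the result would otherwise be odd (`g=1`, LSB of `value` set). A quick local check of that identity:

```cpp
// Local check of the round-to-nearest-even update used by rounded().
#include <cassert>

int main()
{
    auto rne = [](unsigned value, int g, int s) { return value + (g & (s | value)); };
    assert(rne(0b1010, 0, 1) == 0b1010); // below half ULP: truncate
    assert(rne(0b1010, 1, 0) == 0b1010); // tie, even result: keep
    assert(rne(0b1011, 1, 0) == 0b1100); // tie, odd result: round up to even
    assert(rne(0b1010, 1, 1) == 0b1011); // above half ULP: round up
    return 0;
}
```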
+/// Convert IEEE single-precision to half-precision.
+/// Credit for this goes to [Jeroen van der
+/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf).
+/// \tparam R rounding mode to use
+/// \param value single-precision value to convert
+/// \return rounded half-precision value
+/// \exception FE_OVERFLOW on overflows
+/// \exception FE_UNDERFLOW on underflows
+/// \exception FE_INEXACT if value had to be rounded
+template <std::float_round_style R>
+unsigned int float2half_impl(float value, true_type)
+{
+#if HALF_ENABLE_F16C_INTRINSICS
+    return _mm_cvtsi128_si32(_mm_cvtps_ph(_mm_set_ss(value),
+                                          (R == std::round_to_nearest)
+                                              ? _MM_FROUND_TO_NEAREST_INT
+                                              : (R == std::round_toward_zero)
+                                                    ? _MM_FROUND_TO_ZERO
+                                                    : (R == std::round_toward_infinity)
+                                                          ? _MM_FROUND_TO_POS_INF
+                                                          : (R == std::round_toward_neg_infinity)
+                                                                ? _MM_FROUND_TO_NEG_INF
+                                                                : _MM_FROUND_CUR_DIRECTION));
+#else
+    bits<float>::type fbits;
+    std::memcpy(&fbits, &value, sizeof(float));
+#if 1
+    unsigned int sign = (fbits >> 16) & 0x8000;
+    fbits &= 0x7FFFFFFF;
+    if(fbits >= 0x7F800000)
+        return sign | 0x7C00 | ((fbits > 0x7F800000) ? (0x200 | ((fbits >> 13) & 0x3FF)) : 0);
+    if(fbits >= 0x47800000)
+        return overflow<R>(sign);
+    if(fbits >= 0x38800000)
+        return rounded<R, false>(sign | (((fbits >> 23) - 112) << 10) | ((fbits >> 13) & 0x3FF),
+                                 (fbits >> 12) & 1,
+                                 (fbits & 0xFFF) != 0);
+    if(fbits >= 0x33000000)
+    {
+        int i = 125 - (fbits >> 23);
+        fbits = (fbits & 0x7FFFFF) | 0x800000;
+        return rounded<R, false>(sign | (fbits >> (i + 1)),
+                                 (fbits >> i) & 1,
+                                 (fbits & ((static_cast<uint32>(1) << i) - 1)) != 0);
+    }
+    if(fbits != 0)
+        return underflow<R>(sign);
+    return sign;
+#else
+    static const uint16 base_table[512] = {
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000, 0x0000,
+        0x0000, 0x0000, 0x0000, 0x0000, 0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040,
+        0x0080, 0x0100, 0x0200, 0x0400, 0x0800, 0x0C00, 0x1000, 0x1400, 0x1800, 0x1C00, 0x2000,
+        0x2400, 0x2800, 0x2C00, 0x3000, 0x3400, 0x3800, 0x3C00, 0x4000, 0x4400, 0x4800, 0x4C00,
+        0x5000, 0x5400, 0x5800, 0x5C00, 0x6000, 0x6400, 0x6800, 0x6C00, 0x7000, 0x7400, 0x7800,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
+        0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF, 0x7BFF,
0x7BFF, 0x7BFF, + 0x7BFF, 0x7BFF, 0x7C00, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, + 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8000, 0x8001, 0x8002, 0x8004, 0x8008, + 0x8010, 0x8020, 0x8040, 0x8080, 0x8100, 0x8200, 0x8400, 0x8800, 0x8C00, 0x9000, 0x9400, + 0x9800, 0x9C00, 0xA000, 0xA400, 0xA800, 0xAC00, 0xB000, 0xB400, 0xB800, 0xBC00, 0xC000, + 0xC400, 0xC800, 0xCC00, 0xD000, 0xD400, 0xD800, 0xDC00, 0xE000, 0xE400, 0xE800, 0xEC00, + 0xF000, 0xF400, 0xF800, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, + 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFBFF, 0xFC00}; + static const unsigned char shift_table[256] = { + 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 24, 23, 22, 21, 20, 19, 18, 17, + 16, 15, 14, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 13}; + int sexp = fbits >> 23, exp = sexp & 0xFF, i = shift_table[exp]; + fbits &= 0x7FFFFF; + uint32 m = (fbits | ((exp != 0) << 23)) & -static_cast(exp != 0xFF); + return rounded(base_table[sexp] + (fbits >> i), + (m >> (i - 1)) & 1, + (((static_cast(1) << (i - 1)) - 1) & m) != 0); +#endif +#endif +} + +/// Convert IEEE 
double-precision to half-precision. +/// \tparam R rounding mode to use +/// \param value double-precision value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(double value, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + if(R == std::round_indeterminate) + return _mm_cvtsi128_si32( + _mm_cvtps_ph(_mm_cvtpd_ps(_mm_set_sd(value)), _MM_FROUND_CUR_DIRECTION)); +#endif + bits::type dbits; + std::memcpy(&dbits, &value, sizeof(double)); + uint32 hi = dbits >> 32, lo = dbits & 0xFFFFFFFF; + unsigned int sign = (hi >> 16) & 0x8000; + hi &= 0x7FFFFFFF; + if(hi >= 0x7FF00000) + return sign | 0x7C00 | ((dbits & 0xFFFFFFFFFFFFF) ? (0x200 | ((hi >> 10) & 0x3FF)) : 0); + if(hi >= 0x40F00000) + return overflow(sign); + if(hi >= 0x3F100000) + return rounded(sign | (((hi >> 20) - 1008) << 10) | ((hi >> 10) & 0x3FF), + (hi >> 9) & 1, + ((hi & 0x1FF) | lo) != 0); + if(hi >= 0x3E600000) + { + int i = 1018 - (hi >> 20); + hi = (hi & 0xFFFFF) | 0x100000; + return rounded(sign | (hi >> (i + 1)), + (hi >> i) & 1, + ((hi & ((static_cast(1) << i) - 1)) | lo) != 0); + } + if((hi | lo) != 0) + return underflow(sign); + return sign; +} + +/// Convert non-IEEE floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half_impl(T value, ...) +{ + unsigned int hbits = static_cast(builtin_signbit(value)) << 15; + if(value == T()) + return hbits; + if(builtin_isnan(value)) + return hbits | 0x7FFF; + if(builtin_isinf(value)) + return hbits | 0x7C00; + int exp; + std::frexp(value, &exp); + if(exp > 16) + return overflow(hbits); + if(exp < -13) + value = std::ldexp(value, 25); + else + { + value = std::ldexp(value, 12 - exp); + hbits |= ((exp + 13) << 10); + } + T ival, frac = std::modf(value, &ival); + int m = std::abs(static_cast(ival)); + return rounded(hbits + (m >> 1), m & 1, frac != T()); +} + +/// Convert floating-point to half-precision. +/// \tparam R rounding mode to use +/// \tparam T source type (builtin floating-point type) +/// \param value floating-point value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int float2half(T value) +{ + return float2half_impl(value, + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert integer to half-precision floating-point. 
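As a cross-check on the branch thresholds above, here is a minimal self-contained sketch of the same single-to-half conversion, fixed to round-to-nearest-even and with the library's rounding-mode template, FE_* exception plumbing and overflow/underflow/rounded helpers inlined away (f32_to_f16_rne is a hypothetical name, not part of the header):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical stand-alone sketch: single -> half, round-to-nearest-even only.
static std::uint16_t f32_to_f16_rne(float value)
{
    std::uint32_t f;
    std::memcpy(&f, &value, sizeof(f));
    std::uint32_t sign = (f >> 16) & 0x8000;
    f &= 0x7FFFFFFF;
    if(f >= 0x7F800000) // Inf or NaN (keep a quiet-NaN payload bit)
        return sign | 0x7C00 | ((f > 0x7F800000) ? (0x200 | ((f >> 13) & 0x3FF)) : 0);
    if(f >= 0x47800000) // too large for half: 2^16 and above go to infinity
        return sign | 0x7C00;
    std::uint32_t h, g, s; // half bits, guard bit, sticky flag
    if(f >= 0x38800000) // normal half: rebias exponent by 127 - 15 = 112
    {
        h = sign | (((f >> 23) - 112) << 10) | ((f >> 13) & 0x3FF);
        g = (f >> 12) & 1;
        s = (f & 0xFFF) != 0;
    }
    else if(f >= 0x33000000) // subnormal half: shift out i+1 mantissa bits
    {
        int i = 125 - static_cast<int>(f >> 23);
        f = (f & 0x7FFFFF) | 0x800000;
        h = sign | (f >> (i + 1));
        g = (f >> i) & 1;
        s = (f & ((1u << i) - 1)) != 0;
    }
    else // below 2^-25: underflows to (signed) zero
        return sign;
    return static_cast<std::uint16_t>(h + (g & (s | h))); // ties to even via the result's LSB
}

int main()
{
    std::printf("%04X %04X\n",
                static_cast<unsigned>(f32_to_f16_rne(1.0f)),      // 3C00
                static_cast<unsigned>(f32_to_f16_rne(65504.0f))); // 7BFF
}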
+/// \tparam R rounding mode to use +/// \tparam T type to convert (builtin integer type) +/// \param value integral value to convert +/// \return rounded half-precision value +/// \exception FE_OVERFLOW on overflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int int2half(T value) +{ + unsigned int bits = static_cast(value < 0) << 15; + if(!value) + return bits; + if(bits) + value = -value; + if(value > 0xFFFF) + return overflow(bits); + unsigned int m = static_cast(value), exp = 24; + for(; m < 0x400; m <<= 1, --exp) + ; + for(; m > 0x7FF; m >>= 1, ++exp) + ; + bits |= (exp << 10) + m; + return (exp > 24) ? rounded( + bits, (value >> (exp - 25)) & 1, (((1 << (exp - 25)) - 1) & value) != 0) + : bits; +} + +/// Convert half-precision to IEEE single-precision. +/// Credit for this goes to [Jeroen van der +/// Zijp](ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf). +/// \param value half-precision value to convert +/// \return single-precision value +inline float half2float_impl(unsigned int value, float, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(value))); +#else +#if 0 + bits::type fbits = static_cast::type>(value&0x8000) << 16; + int abs = value & 0x7FFF; + if(abs) + { + fbits |= 0x38000000 << static_cast(abs>=0x7C00); + for(; abs<0x400; abs<<=1,fbits-=0x800000) ; + fbits += static_cast::type>(abs) << 13; + } +#else + static const bits::type mantissa_table[2048] = { + 0x00000000, 0x33800000, 0x34000000, 0x34400000, 0x34800000, 0x34A00000, 0x34C00000, + 0x34E00000, 0x35000000, 0x35100000, 0x35200000, 0x35300000, 0x35400000, 0x35500000, + 0x35600000, 0x35700000, 0x35800000, 0x35880000, 0x35900000, 0x35980000, 0x35A00000, + 0x35A80000, 0x35B00000, 0x35B80000, 0x35C00000, 0x35C80000, 0x35D00000, 0x35D80000, + 0x35E00000, 0x35E80000, 0x35F00000, 0x35F80000, 0x36000000, 0x36040000, 0x36080000, + 0x360C0000, 0x36100000, 0x36140000, 0x36180000, 0x361C0000, 0x36200000, 0x36240000, + 0x36280000, 0x362C0000, 0x36300000, 0x36340000, 0x36380000, 0x363C0000, 0x36400000, + 0x36440000, 0x36480000, 0x364C0000, 0x36500000, 0x36540000, 0x36580000, 0x365C0000, + 0x36600000, 0x36640000, 0x36680000, 0x366C0000, 0x36700000, 0x36740000, 0x36780000, + 0x367C0000, 0x36800000, 0x36820000, 0x36840000, 0x36860000, 0x36880000, 0x368A0000, + 0x368C0000, 0x368E0000, 0x36900000, 0x36920000, 0x36940000, 0x36960000, 0x36980000, + 0x369A0000, 0x369C0000, 0x369E0000, 0x36A00000, 0x36A20000, 0x36A40000, 0x36A60000, + 0x36A80000, 0x36AA0000, 0x36AC0000, 0x36AE0000, 0x36B00000, 0x36B20000, 0x36B40000, + 0x36B60000, 0x36B80000, 0x36BA0000, 0x36BC0000, 0x36BE0000, 0x36C00000, 0x36C20000, + 0x36C40000, 0x36C60000, 0x36C80000, 0x36CA0000, 0x36CC0000, 0x36CE0000, 0x36D00000, + 0x36D20000, 0x36D40000, 0x36D60000, 0x36D80000, 0x36DA0000, 0x36DC0000, 0x36DE0000, + 0x36E00000, 0x36E20000, 0x36E40000, 0x36E60000, 0x36E80000, 0x36EA0000, 0x36EC0000, + 0x36EE0000, 0x36F00000, 0x36F20000, 0x36F40000, 0x36F60000, 0x36F80000, 0x36FA0000, + 0x36FC0000, 0x36FE0000, 0x37000000, 0x37010000, 0x37020000, 0x37030000, 0x37040000, + 0x37050000, 0x37060000, 0x37070000, 0x37080000, 0x37090000, 0x370A0000, 0x370B0000, + 0x370C0000, 0x370D0000, 0x370E0000, 0x370F0000, 0x37100000, 0x37110000, 0x37120000, + 0x37130000, 0x37140000, 0x37150000, 0x37160000, 0x37170000, 0x37180000, 0x37190000, + 0x371A0000, 0x371B0000, 0x371C0000, 0x371D0000, 0x371E0000, 0x371F0000, 0x37200000, + 0x37210000, 0x37220000, 0x37230000, 0x37240000, 0x37250000, 0x37260000, 0x37270000, + 
0x37280000, 0x37290000, 0x372A0000, 0x372B0000, 0x372C0000, 0x372D0000, 0x372E0000, + 0x372F0000, 0x37300000, 0x37310000, 0x37320000, 0x37330000, 0x37340000, 0x37350000, + 0x37360000, 0x37370000, 0x37380000, 0x37390000, 0x373A0000, 0x373B0000, 0x373C0000, + 0x373D0000, 0x373E0000, 0x373F0000, 0x37400000, 0x37410000, 0x37420000, 0x37430000, + 0x37440000, 0x37450000, 0x37460000, 0x37470000, 0x37480000, 0x37490000, 0x374A0000, + 0x374B0000, 0x374C0000, 0x374D0000, 0x374E0000, 0x374F0000, 0x37500000, 0x37510000, + 0x37520000, 0x37530000, 0x37540000, 0x37550000, 0x37560000, 0x37570000, 0x37580000, + 0x37590000, 0x375A0000, 0x375B0000, 0x375C0000, 0x375D0000, 0x375E0000, 0x375F0000, + 0x37600000, 0x37610000, 0x37620000, 0x37630000, 0x37640000, 0x37650000, 0x37660000, + 0x37670000, 0x37680000, 0x37690000, 0x376A0000, 0x376B0000, 0x376C0000, 0x376D0000, + 0x376E0000, 0x376F0000, 0x37700000, 0x37710000, 0x37720000, 0x37730000, 0x37740000, + 0x37750000, 0x37760000, 0x37770000, 0x37780000, 0x37790000, 0x377A0000, 0x377B0000, + 0x377C0000, 0x377D0000, 0x377E0000, 0x377F0000, 0x37800000, 0x37808000, 0x37810000, + 0x37818000, 0x37820000, 0x37828000, 0x37830000, 0x37838000, 0x37840000, 0x37848000, + 0x37850000, 0x37858000, 0x37860000, 0x37868000, 0x37870000, 0x37878000, 0x37880000, + 0x37888000, 0x37890000, 0x37898000, 0x378A0000, 0x378A8000, 0x378B0000, 0x378B8000, + 0x378C0000, 0x378C8000, 0x378D0000, 0x378D8000, 0x378E0000, 0x378E8000, 0x378F0000, + 0x378F8000, 0x37900000, 0x37908000, 0x37910000, 0x37918000, 0x37920000, 0x37928000, + 0x37930000, 0x37938000, 0x37940000, 0x37948000, 0x37950000, 0x37958000, 0x37960000, + 0x37968000, 0x37970000, 0x37978000, 0x37980000, 0x37988000, 0x37990000, 0x37998000, + 0x379A0000, 0x379A8000, 0x379B0000, 0x379B8000, 0x379C0000, 0x379C8000, 0x379D0000, + 0x379D8000, 0x379E0000, 0x379E8000, 0x379F0000, 0x379F8000, 0x37A00000, 0x37A08000, + 0x37A10000, 0x37A18000, 0x37A20000, 0x37A28000, 0x37A30000, 0x37A38000, 0x37A40000, + 0x37A48000, 0x37A50000, 0x37A58000, 0x37A60000, 0x37A68000, 0x37A70000, 0x37A78000, + 0x37A80000, 0x37A88000, 0x37A90000, 0x37A98000, 0x37AA0000, 0x37AA8000, 0x37AB0000, + 0x37AB8000, 0x37AC0000, 0x37AC8000, 0x37AD0000, 0x37AD8000, 0x37AE0000, 0x37AE8000, + 0x37AF0000, 0x37AF8000, 0x37B00000, 0x37B08000, 0x37B10000, 0x37B18000, 0x37B20000, + 0x37B28000, 0x37B30000, 0x37B38000, 0x37B40000, 0x37B48000, 0x37B50000, 0x37B58000, + 0x37B60000, 0x37B68000, 0x37B70000, 0x37B78000, 0x37B80000, 0x37B88000, 0x37B90000, + 0x37B98000, 0x37BA0000, 0x37BA8000, 0x37BB0000, 0x37BB8000, 0x37BC0000, 0x37BC8000, + 0x37BD0000, 0x37BD8000, 0x37BE0000, 0x37BE8000, 0x37BF0000, 0x37BF8000, 0x37C00000, + 0x37C08000, 0x37C10000, 0x37C18000, 0x37C20000, 0x37C28000, 0x37C30000, 0x37C38000, + 0x37C40000, 0x37C48000, 0x37C50000, 0x37C58000, 0x37C60000, 0x37C68000, 0x37C70000, + 0x37C78000, 0x37C80000, 0x37C88000, 0x37C90000, 0x37C98000, 0x37CA0000, 0x37CA8000, + 0x37CB0000, 0x37CB8000, 0x37CC0000, 0x37CC8000, 0x37CD0000, 0x37CD8000, 0x37CE0000, + 0x37CE8000, 0x37CF0000, 0x37CF8000, 0x37D00000, 0x37D08000, 0x37D10000, 0x37D18000, + 0x37D20000, 0x37D28000, 0x37D30000, 0x37D38000, 0x37D40000, 0x37D48000, 0x37D50000, + 0x37D58000, 0x37D60000, 0x37D68000, 0x37D70000, 0x37D78000, 0x37D80000, 0x37D88000, + 0x37D90000, 0x37D98000, 0x37DA0000, 0x37DA8000, 0x37DB0000, 0x37DB8000, 0x37DC0000, + 0x37DC8000, 0x37DD0000, 0x37DD8000, 0x37DE0000, 0x37DE8000, 0x37DF0000, 0x37DF8000, + 0x37E00000, 0x37E08000, 0x37E10000, 0x37E18000, 0x37E20000, 0x37E28000, 0x37E30000, + 0x37E38000, 0x37E40000, 
0x37E48000, 0x37E50000, 0x37E58000, 0x37E60000, 0x37E68000, + 0x37E70000, 0x37E78000, 0x37E80000, 0x37E88000, 0x37E90000, 0x37E98000, 0x37EA0000, + 0x37EA8000, 0x37EB0000, 0x37EB8000, 0x37EC0000, 0x37EC8000, 0x37ED0000, 0x37ED8000, + 0x37EE0000, 0x37EE8000, 0x37EF0000, 0x37EF8000, 0x37F00000, 0x37F08000, 0x37F10000, + 0x37F18000, 0x37F20000, 0x37F28000, 0x37F30000, 0x37F38000, 0x37F40000, 0x37F48000, + 0x37F50000, 0x37F58000, 0x37F60000, 0x37F68000, 0x37F70000, 0x37F78000, 0x37F80000, + 0x37F88000, 0x37F90000, 0x37F98000, 0x37FA0000, 0x37FA8000, 0x37FB0000, 0x37FB8000, + 0x37FC0000, 0x37FC8000, 0x37FD0000, 0x37FD8000, 0x37FE0000, 0x37FE8000, 0x37FF0000, + 0x37FF8000, 0x38000000, 0x38004000, 0x38008000, 0x3800C000, 0x38010000, 0x38014000, + 0x38018000, 0x3801C000, 0x38020000, 0x38024000, 0x38028000, 0x3802C000, 0x38030000, + 0x38034000, 0x38038000, 0x3803C000, 0x38040000, 0x38044000, 0x38048000, 0x3804C000, + 0x38050000, 0x38054000, 0x38058000, 0x3805C000, 0x38060000, 0x38064000, 0x38068000, + 0x3806C000, 0x38070000, 0x38074000, 0x38078000, 0x3807C000, 0x38080000, 0x38084000, + 0x38088000, 0x3808C000, 0x38090000, 0x38094000, 0x38098000, 0x3809C000, 0x380A0000, + 0x380A4000, 0x380A8000, 0x380AC000, 0x380B0000, 0x380B4000, 0x380B8000, 0x380BC000, + 0x380C0000, 0x380C4000, 0x380C8000, 0x380CC000, 0x380D0000, 0x380D4000, 0x380D8000, + 0x380DC000, 0x380E0000, 0x380E4000, 0x380E8000, 0x380EC000, 0x380F0000, 0x380F4000, + 0x380F8000, 0x380FC000, 0x38100000, 0x38104000, 0x38108000, 0x3810C000, 0x38110000, + 0x38114000, 0x38118000, 0x3811C000, 0x38120000, 0x38124000, 0x38128000, 0x3812C000, + 0x38130000, 0x38134000, 0x38138000, 0x3813C000, 0x38140000, 0x38144000, 0x38148000, + 0x3814C000, 0x38150000, 0x38154000, 0x38158000, 0x3815C000, 0x38160000, 0x38164000, + 0x38168000, 0x3816C000, 0x38170000, 0x38174000, 0x38178000, 0x3817C000, 0x38180000, + 0x38184000, 0x38188000, 0x3818C000, 0x38190000, 0x38194000, 0x38198000, 0x3819C000, + 0x381A0000, 0x381A4000, 0x381A8000, 0x381AC000, 0x381B0000, 0x381B4000, 0x381B8000, + 0x381BC000, 0x381C0000, 0x381C4000, 0x381C8000, 0x381CC000, 0x381D0000, 0x381D4000, + 0x381D8000, 0x381DC000, 0x381E0000, 0x381E4000, 0x381E8000, 0x381EC000, 0x381F0000, + 0x381F4000, 0x381F8000, 0x381FC000, 0x38200000, 0x38204000, 0x38208000, 0x3820C000, + 0x38210000, 0x38214000, 0x38218000, 0x3821C000, 0x38220000, 0x38224000, 0x38228000, + 0x3822C000, 0x38230000, 0x38234000, 0x38238000, 0x3823C000, 0x38240000, 0x38244000, + 0x38248000, 0x3824C000, 0x38250000, 0x38254000, 0x38258000, 0x3825C000, 0x38260000, + 0x38264000, 0x38268000, 0x3826C000, 0x38270000, 0x38274000, 0x38278000, 0x3827C000, + 0x38280000, 0x38284000, 0x38288000, 0x3828C000, 0x38290000, 0x38294000, 0x38298000, + 0x3829C000, 0x382A0000, 0x382A4000, 0x382A8000, 0x382AC000, 0x382B0000, 0x382B4000, + 0x382B8000, 0x382BC000, 0x382C0000, 0x382C4000, 0x382C8000, 0x382CC000, 0x382D0000, + 0x382D4000, 0x382D8000, 0x382DC000, 0x382E0000, 0x382E4000, 0x382E8000, 0x382EC000, + 0x382F0000, 0x382F4000, 0x382F8000, 0x382FC000, 0x38300000, 0x38304000, 0x38308000, + 0x3830C000, 0x38310000, 0x38314000, 0x38318000, 0x3831C000, 0x38320000, 0x38324000, + 0x38328000, 0x3832C000, 0x38330000, 0x38334000, 0x38338000, 0x3833C000, 0x38340000, + 0x38344000, 0x38348000, 0x3834C000, 0x38350000, 0x38354000, 0x38358000, 0x3835C000, + 0x38360000, 0x38364000, 0x38368000, 0x3836C000, 0x38370000, 0x38374000, 0x38378000, + 0x3837C000, 0x38380000, 0x38384000, 0x38388000, 0x3838C000, 0x38390000, 0x38394000, + 0x38398000, 0x3839C000, 0x383A0000, 0x383A4000, 
0x383A8000, 0x383AC000, 0x383B0000, + 0x383B4000, 0x383B8000, 0x383BC000, 0x383C0000, 0x383C4000, 0x383C8000, 0x383CC000, + 0x383D0000, 0x383D4000, 0x383D8000, 0x383DC000, 0x383E0000, 0x383E4000, 0x383E8000, + 0x383EC000, 0x383F0000, 0x383F4000, 0x383F8000, 0x383FC000, 0x38400000, 0x38404000, + 0x38408000, 0x3840C000, 0x38410000, 0x38414000, 0x38418000, 0x3841C000, 0x38420000, + 0x38424000, 0x38428000, 0x3842C000, 0x38430000, 0x38434000, 0x38438000, 0x3843C000, + 0x38440000, 0x38444000, 0x38448000, 0x3844C000, 0x38450000, 0x38454000, 0x38458000, + 0x3845C000, 0x38460000, 0x38464000, 0x38468000, 0x3846C000, 0x38470000, 0x38474000, + 0x38478000, 0x3847C000, 0x38480000, 0x38484000, 0x38488000, 0x3848C000, 0x38490000, + 0x38494000, 0x38498000, 0x3849C000, 0x384A0000, 0x384A4000, 0x384A8000, 0x384AC000, + 0x384B0000, 0x384B4000, 0x384B8000, 0x384BC000, 0x384C0000, 0x384C4000, 0x384C8000, + 0x384CC000, 0x384D0000, 0x384D4000, 0x384D8000, 0x384DC000, 0x384E0000, 0x384E4000, + 0x384E8000, 0x384EC000, 0x384F0000, 0x384F4000, 0x384F8000, 0x384FC000, 0x38500000, + 0x38504000, 0x38508000, 0x3850C000, 0x38510000, 0x38514000, 0x38518000, 0x3851C000, + 0x38520000, 0x38524000, 0x38528000, 0x3852C000, 0x38530000, 0x38534000, 0x38538000, + 0x3853C000, 0x38540000, 0x38544000, 0x38548000, 0x3854C000, 0x38550000, 0x38554000, + 0x38558000, 0x3855C000, 0x38560000, 0x38564000, 0x38568000, 0x3856C000, 0x38570000, + 0x38574000, 0x38578000, 0x3857C000, 0x38580000, 0x38584000, 0x38588000, 0x3858C000, + 0x38590000, 0x38594000, 0x38598000, 0x3859C000, 0x385A0000, 0x385A4000, 0x385A8000, + 0x385AC000, 0x385B0000, 0x385B4000, 0x385B8000, 0x385BC000, 0x385C0000, 0x385C4000, + 0x385C8000, 0x385CC000, 0x385D0000, 0x385D4000, 0x385D8000, 0x385DC000, 0x385E0000, + 0x385E4000, 0x385E8000, 0x385EC000, 0x385F0000, 0x385F4000, 0x385F8000, 0x385FC000, + 0x38600000, 0x38604000, 0x38608000, 0x3860C000, 0x38610000, 0x38614000, 0x38618000, + 0x3861C000, 0x38620000, 0x38624000, 0x38628000, 0x3862C000, 0x38630000, 0x38634000, + 0x38638000, 0x3863C000, 0x38640000, 0x38644000, 0x38648000, 0x3864C000, 0x38650000, + 0x38654000, 0x38658000, 0x3865C000, 0x38660000, 0x38664000, 0x38668000, 0x3866C000, + 0x38670000, 0x38674000, 0x38678000, 0x3867C000, 0x38680000, 0x38684000, 0x38688000, + 0x3868C000, 0x38690000, 0x38694000, 0x38698000, 0x3869C000, 0x386A0000, 0x386A4000, + 0x386A8000, 0x386AC000, 0x386B0000, 0x386B4000, 0x386B8000, 0x386BC000, 0x386C0000, + 0x386C4000, 0x386C8000, 0x386CC000, 0x386D0000, 0x386D4000, 0x386D8000, 0x386DC000, + 0x386E0000, 0x386E4000, 0x386E8000, 0x386EC000, 0x386F0000, 0x386F4000, 0x386F8000, + 0x386FC000, 0x38700000, 0x38704000, 0x38708000, 0x3870C000, 0x38710000, 0x38714000, + 0x38718000, 0x3871C000, 0x38720000, 0x38724000, 0x38728000, 0x3872C000, 0x38730000, + 0x38734000, 0x38738000, 0x3873C000, 0x38740000, 0x38744000, 0x38748000, 0x3874C000, + 0x38750000, 0x38754000, 0x38758000, 0x3875C000, 0x38760000, 0x38764000, 0x38768000, + 0x3876C000, 0x38770000, 0x38774000, 0x38778000, 0x3877C000, 0x38780000, 0x38784000, + 0x38788000, 0x3878C000, 0x38790000, 0x38794000, 0x38798000, 0x3879C000, 0x387A0000, + 0x387A4000, 0x387A8000, 0x387AC000, 0x387B0000, 0x387B4000, 0x387B8000, 0x387BC000, + 0x387C0000, 0x387C4000, 0x387C8000, 0x387CC000, 0x387D0000, 0x387D4000, 0x387D8000, + 0x387DC000, 0x387E0000, 0x387E4000, 0x387E8000, 0x387EC000, 0x387F0000, 0x387F4000, + 0x387F8000, 0x387FC000, 0x38000000, 0x38002000, 0x38004000, 0x38006000, 0x38008000, + 0x3800A000, 0x3800C000, 0x3800E000, 0x38010000, 0x38012000, 0x38014000, 
0x38016000, + 0x38018000, 0x3801A000, 0x3801C000, 0x3801E000, 0x38020000, 0x38022000, 0x38024000, + 0x38026000, 0x38028000, 0x3802A000, 0x3802C000, 0x3802E000, 0x38030000, 0x38032000, + 0x38034000, 0x38036000, 0x38038000, 0x3803A000, 0x3803C000, 0x3803E000, 0x38040000, + 0x38042000, 0x38044000, 0x38046000, 0x38048000, 0x3804A000, 0x3804C000, 0x3804E000, + 0x38050000, 0x38052000, 0x38054000, 0x38056000, 0x38058000, 0x3805A000, 0x3805C000, + 0x3805E000, 0x38060000, 0x38062000, 0x38064000, 0x38066000, 0x38068000, 0x3806A000, + 0x3806C000, 0x3806E000, 0x38070000, 0x38072000, 0x38074000, 0x38076000, 0x38078000, + 0x3807A000, 0x3807C000, 0x3807E000, 0x38080000, 0x38082000, 0x38084000, 0x38086000, + 0x38088000, 0x3808A000, 0x3808C000, 0x3808E000, 0x38090000, 0x38092000, 0x38094000, + 0x38096000, 0x38098000, 0x3809A000, 0x3809C000, 0x3809E000, 0x380A0000, 0x380A2000, + 0x380A4000, 0x380A6000, 0x380A8000, 0x380AA000, 0x380AC000, 0x380AE000, 0x380B0000, + 0x380B2000, 0x380B4000, 0x380B6000, 0x380B8000, 0x380BA000, 0x380BC000, 0x380BE000, + 0x380C0000, 0x380C2000, 0x380C4000, 0x380C6000, 0x380C8000, 0x380CA000, 0x380CC000, + 0x380CE000, 0x380D0000, 0x380D2000, 0x380D4000, 0x380D6000, 0x380D8000, 0x380DA000, + 0x380DC000, 0x380DE000, 0x380E0000, 0x380E2000, 0x380E4000, 0x380E6000, 0x380E8000, + 0x380EA000, 0x380EC000, 0x380EE000, 0x380F0000, 0x380F2000, 0x380F4000, 0x380F6000, + 0x380F8000, 0x380FA000, 0x380FC000, 0x380FE000, 0x38100000, 0x38102000, 0x38104000, + 0x38106000, 0x38108000, 0x3810A000, 0x3810C000, 0x3810E000, 0x38110000, 0x38112000, + 0x38114000, 0x38116000, 0x38118000, 0x3811A000, 0x3811C000, 0x3811E000, 0x38120000, + 0x38122000, 0x38124000, 0x38126000, 0x38128000, 0x3812A000, 0x3812C000, 0x3812E000, + 0x38130000, 0x38132000, 0x38134000, 0x38136000, 0x38138000, 0x3813A000, 0x3813C000, + 0x3813E000, 0x38140000, 0x38142000, 0x38144000, 0x38146000, 0x38148000, 0x3814A000, + 0x3814C000, 0x3814E000, 0x38150000, 0x38152000, 0x38154000, 0x38156000, 0x38158000, + 0x3815A000, 0x3815C000, 0x3815E000, 0x38160000, 0x38162000, 0x38164000, 0x38166000, + 0x38168000, 0x3816A000, 0x3816C000, 0x3816E000, 0x38170000, 0x38172000, 0x38174000, + 0x38176000, 0x38178000, 0x3817A000, 0x3817C000, 0x3817E000, 0x38180000, 0x38182000, + 0x38184000, 0x38186000, 0x38188000, 0x3818A000, 0x3818C000, 0x3818E000, 0x38190000, + 0x38192000, 0x38194000, 0x38196000, 0x38198000, 0x3819A000, 0x3819C000, 0x3819E000, + 0x381A0000, 0x381A2000, 0x381A4000, 0x381A6000, 0x381A8000, 0x381AA000, 0x381AC000, + 0x381AE000, 0x381B0000, 0x381B2000, 0x381B4000, 0x381B6000, 0x381B8000, 0x381BA000, + 0x381BC000, 0x381BE000, 0x381C0000, 0x381C2000, 0x381C4000, 0x381C6000, 0x381C8000, + 0x381CA000, 0x381CC000, 0x381CE000, 0x381D0000, 0x381D2000, 0x381D4000, 0x381D6000, + 0x381D8000, 0x381DA000, 0x381DC000, 0x381DE000, 0x381E0000, 0x381E2000, 0x381E4000, + 0x381E6000, 0x381E8000, 0x381EA000, 0x381EC000, 0x381EE000, 0x381F0000, 0x381F2000, + 0x381F4000, 0x381F6000, 0x381F8000, 0x381FA000, 0x381FC000, 0x381FE000, 0x38200000, + 0x38202000, 0x38204000, 0x38206000, 0x38208000, 0x3820A000, 0x3820C000, 0x3820E000, + 0x38210000, 0x38212000, 0x38214000, 0x38216000, 0x38218000, 0x3821A000, 0x3821C000, + 0x3821E000, 0x38220000, 0x38222000, 0x38224000, 0x38226000, 0x38228000, 0x3822A000, + 0x3822C000, 0x3822E000, 0x38230000, 0x38232000, 0x38234000, 0x38236000, 0x38238000, + 0x3823A000, 0x3823C000, 0x3823E000, 0x38240000, 0x38242000, 0x38244000, 0x38246000, + 0x38248000, 0x3824A000, 0x3824C000, 0x3824E000, 0x38250000, 0x38252000, 0x38254000, + 0x38256000, 
0x38258000, 0x3825A000, 0x3825C000, 0x3825E000, 0x38260000, 0x38262000, + 0x38264000, 0x38266000, 0x38268000, 0x3826A000, 0x3826C000, 0x3826E000, 0x38270000, + 0x38272000, 0x38274000, 0x38276000, 0x38278000, 0x3827A000, 0x3827C000, 0x3827E000, + 0x38280000, 0x38282000, 0x38284000, 0x38286000, 0x38288000, 0x3828A000, 0x3828C000, + 0x3828E000, 0x38290000, 0x38292000, 0x38294000, 0x38296000, 0x38298000, 0x3829A000, + 0x3829C000, 0x3829E000, 0x382A0000, 0x382A2000, 0x382A4000, 0x382A6000, 0x382A8000, + 0x382AA000, 0x382AC000, 0x382AE000, 0x382B0000, 0x382B2000, 0x382B4000, 0x382B6000, + 0x382B8000, 0x382BA000, 0x382BC000, 0x382BE000, 0x382C0000, 0x382C2000, 0x382C4000, + 0x382C6000, 0x382C8000, 0x382CA000, 0x382CC000, 0x382CE000, 0x382D0000, 0x382D2000, + 0x382D4000, 0x382D6000, 0x382D8000, 0x382DA000, 0x382DC000, 0x382DE000, 0x382E0000, + 0x382E2000, 0x382E4000, 0x382E6000, 0x382E8000, 0x382EA000, 0x382EC000, 0x382EE000, + 0x382F0000, 0x382F2000, 0x382F4000, 0x382F6000, 0x382F8000, 0x382FA000, 0x382FC000, + 0x382FE000, 0x38300000, 0x38302000, 0x38304000, 0x38306000, 0x38308000, 0x3830A000, + 0x3830C000, 0x3830E000, 0x38310000, 0x38312000, 0x38314000, 0x38316000, 0x38318000, + 0x3831A000, 0x3831C000, 0x3831E000, 0x38320000, 0x38322000, 0x38324000, 0x38326000, + 0x38328000, 0x3832A000, 0x3832C000, 0x3832E000, 0x38330000, 0x38332000, 0x38334000, + 0x38336000, 0x38338000, 0x3833A000, 0x3833C000, 0x3833E000, 0x38340000, 0x38342000, + 0x38344000, 0x38346000, 0x38348000, 0x3834A000, 0x3834C000, 0x3834E000, 0x38350000, + 0x38352000, 0x38354000, 0x38356000, 0x38358000, 0x3835A000, 0x3835C000, 0x3835E000, + 0x38360000, 0x38362000, 0x38364000, 0x38366000, 0x38368000, 0x3836A000, 0x3836C000, + 0x3836E000, 0x38370000, 0x38372000, 0x38374000, 0x38376000, 0x38378000, 0x3837A000, + 0x3837C000, 0x3837E000, 0x38380000, 0x38382000, 0x38384000, 0x38386000, 0x38388000, + 0x3838A000, 0x3838C000, 0x3838E000, 0x38390000, 0x38392000, 0x38394000, 0x38396000, + 0x38398000, 0x3839A000, 0x3839C000, 0x3839E000, 0x383A0000, 0x383A2000, 0x383A4000, + 0x383A6000, 0x383A8000, 0x383AA000, 0x383AC000, 0x383AE000, 0x383B0000, 0x383B2000, + 0x383B4000, 0x383B6000, 0x383B8000, 0x383BA000, 0x383BC000, 0x383BE000, 0x383C0000, + 0x383C2000, 0x383C4000, 0x383C6000, 0x383C8000, 0x383CA000, 0x383CC000, 0x383CE000, + 0x383D0000, 0x383D2000, 0x383D4000, 0x383D6000, 0x383D8000, 0x383DA000, 0x383DC000, + 0x383DE000, 0x383E0000, 0x383E2000, 0x383E4000, 0x383E6000, 0x383E8000, 0x383EA000, + 0x383EC000, 0x383EE000, 0x383F0000, 0x383F2000, 0x383F4000, 0x383F6000, 0x383F8000, + 0x383FA000, 0x383FC000, 0x383FE000, 0x38400000, 0x38402000, 0x38404000, 0x38406000, + 0x38408000, 0x3840A000, 0x3840C000, 0x3840E000, 0x38410000, 0x38412000, 0x38414000, + 0x38416000, 0x38418000, 0x3841A000, 0x3841C000, 0x3841E000, 0x38420000, 0x38422000, + 0x38424000, 0x38426000, 0x38428000, 0x3842A000, 0x3842C000, 0x3842E000, 0x38430000, + 0x38432000, 0x38434000, 0x38436000, 0x38438000, 0x3843A000, 0x3843C000, 0x3843E000, + 0x38440000, 0x38442000, 0x38444000, 0x38446000, 0x38448000, 0x3844A000, 0x3844C000, + 0x3844E000, 0x38450000, 0x38452000, 0x38454000, 0x38456000, 0x38458000, 0x3845A000, + 0x3845C000, 0x3845E000, 0x38460000, 0x38462000, 0x38464000, 0x38466000, 0x38468000, + 0x3846A000, 0x3846C000, 0x3846E000, 0x38470000, 0x38472000, 0x38474000, 0x38476000, + 0x38478000, 0x3847A000, 0x3847C000, 0x3847E000, 0x38480000, 0x38482000, 0x38484000, + 0x38486000, 0x38488000, 0x3848A000, 0x3848C000, 0x3848E000, 0x38490000, 0x38492000, + 0x38494000, 0x38496000, 0x38498000, 
0x3849A000, 0x3849C000, 0x3849E000, 0x384A0000, + 0x384A2000, 0x384A4000, 0x384A6000, 0x384A8000, 0x384AA000, 0x384AC000, 0x384AE000, + 0x384B0000, 0x384B2000, 0x384B4000, 0x384B6000, 0x384B8000, 0x384BA000, 0x384BC000, + 0x384BE000, 0x384C0000, 0x384C2000, 0x384C4000, 0x384C6000, 0x384C8000, 0x384CA000, + 0x384CC000, 0x384CE000, 0x384D0000, 0x384D2000, 0x384D4000, 0x384D6000, 0x384D8000, + 0x384DA000, 0x384DC000, 0x384DE000, 0x384E0000, 0x384E2000, 0x384E4000, 0x384E6000, + 0x384E8000, 0x384EA000, 0x384EC000, 0x384EE000, 0x384F0000, 0x384F2000, 0x384F4000, + 0x384F6000, 0x384F8000, 0x384FA000, 0x384FC000, 0x384FE000, 0x38500000, 0x38502000, + 0x38504000, 0x38506000, 0x38508000, 0x3850A000, 0x3850C000, 0x3850E000, 0x38510000, + 0x38512000, 0x38514000, 0x38516000, 0x38518000, 0x3851A000, 0x3851C000, 0x3851E000, + 0x38520000, 0x38522000, 0x38524000, 0x38526000, 0x38528000, 0x3852A000, 0x3852C000, + 0x3852E000, 0x38530000, 0x38532000, 0x38534000, 0x38536000, 0x38538000, 0x3853A000, + 0x3853C000, 0x3853E000, 0x38540000, 0x38542000, 0x38544000, 0x38546000, 0x38548000, + 0x3854A000, 0x3854C000, 0x3854E000, 0x38550000, 0x38552000, 0x38554000, 0x38556000, + 0x38558000, 0x3855A000, 0x3855C000, 0x3855E000, 0x38560000, 0x38562000, 0x38564000, + 0x38566000, 0x38568000, 0x3856A000, 0x3856C000, 0x3856E000, 0x38570000, 0x38572000, + 0x38574000, 0x38576000, 0x38578000, 0x3857A000, 0x3857C000, 0x3857E000, 0x38580000, + 0x38582000, 0x38584000, 0x38586000, 0x38588000, 0x3858A000, 0x3858C000, 0x3858E000, + 0x38590000, 0x38592000, 0x38594000, 0x38596000, 0x38598000, 0x3859A000, 0x3859C000, + 0x3859E000, 0x385A0000, 0x385A2000, 0x385A4000, 0x385A6000, 0x385A8000, 0x385AA000, + 0x385AC000, 0x385AE000, 0x385B0000, 0x385B2000, 0x385B4000, 0x385B6000, 0x385B8000, + 0x385BA000, 0x385BC000, 0x385BE000, 0x385C0000, 0x385C2000, 0x385C4000, 0x385C6000, + 0x385C8000, 0x385CA000, 0x385CC000, 0x385CE000, 0x385D0000, 0x385D2000, 0x385D4000, + 0x385D6000, 0x385D8000, 0x385DA000, 0x385DC000, 0x385DE000, 0x385E0000, 0x385E2000, + 0x385E4000, 0x385E6000, 0x385E8000, 0x385EA000, 0x385EC000, 0x385EE000, 0x385F0000, + 0x385F2000, 0x385F4000, 0x385F6000, 0x385F8000, 0x385FA000, 0x385FC000, 0x385FE000, + 0x38600000, 0x38602000, 0x38604000, 0x38606000, 0x38608000, 0x3860A000, 0x3860C000, + 0x3860E000, 0x38610000, 0x38612000, 0x38614000, 0x38616000, 0x38618000, 0x3861A000, + 0x3861C000, 0x3861E000, 0x38620000, 0x38622000, 0x38624000, 0x38626000, 0x38628000, + 0x3862A000, 0x3862C000, 0x3862E000, 0x38630000, 0x38632000, 0x38634000, 0x38636000, + 0x38638000, 0x3863A000, 0x3863C000, 0x3863E000, 0x38640000, 0x38642000, 0x38644000, + 0x38646000, 0x38648000, 0x3864A000, 0x3864C000, 0x3864E000, 0x38650000, 0x38652000, + 0x38654000, 0x38656000, 0x38658000, 0x3865A000, 0x3865C000, 0x3865E000, 0x38660000, + 0x38662000, 0x38664000, 0x38666000, 0x38668000, 0x3866A000, 0x3866C000, 0x3866E000, + 0x38670000, 0x38672000, 0x38674000, 0x38676000, 0x38678000, 0x3867A000, 0x3867C000, + 0x3867E000, 0x38680000, 0x38682000, 0x38684000, 0x38686000, 0x38688000, 0x3868A000, + 0x3868C000, 0x3868E000, 0x38690000, 0x38692000, 0x38694000, 0x38696000, 0x38698000, + 0x3869A000, 0x3869C000, 0x3869E000, 0x386A0000, 0x386A2000, 0x386A4000, 0x386A6000, + 0x386A8000, 0x386AA000, 0x386AC000, 0x386AE000, 0x386B0000, 0x386B2000, 0x386B4000, + 0x386B6000, 0x386B8000, 0x386BA000, 0x386BC000, 0x386BE000, 0x386C0000, 0x386C2000, + 0x386C4000, 0x386C6000, 0x386C8000, 0x386CA000, 0x386CC000, 0x386CE000, 0x386D0000, + 0x386D2000, 0x386D4000, 0x386D6000, 0x386D8000, 0x386DA000, 
0x386DC000, 0x386DE000, + 0x386E0000, 0x386E2000, 0x386E4000, 0x386E6000, 0x386E8000, 0x386EA000, 0x386EC000, + 0x386EE000, 0x386F0000, 0x386F2000, 0x386F4000, 0x386F6000, 0x386F8000, 0x386FA000, + 0x386FC000, 0x386FE000, 0x38700000, 0x38702000, 0x38704000, 0x38706000, 0x38708000, + 0x3870A000, 0x3870C000, 0x3870E000, 0x38710000, 0x38712000, 0x38714000, 0x38716000, + 0x38718000, 0x3871A000, 0x3871C000, 0x3871E000, 0x38720000, 0x38722000, 0x38724000, + 0x38726000, 0x38728000, 0x3872A000, 0x3872C000, 0x3872E000, 0x38730000, 0x38732000, + 0x38734000, 0x38736000, 0x38738000, 0x3873A000, 0x3873C000, 0x3873E000, 0x38740000, + 0x38742000, 0x38744000, 0x38746000, 0x38748000, 0x3874A000, 0x3874C000, 0x3874E000, + 0x38750000, 0x38752000, 0x38754000, 0x38756000, 0x38758000, 0x3875A000, 0x3875C000, + 0x3875E000, 0x38760000, 0x38762000, 0x38764000, 0x38766000, 0x38768000, 0x3876A000, + 0x3876C000, 0x3876E000, 0x38770000, 0x38772000, 0x38774000, 0x38776000, 0x38778000, + 0x3877A000, 0x3877C000, 0x3877E000, 0x38780000, 0x38782000, 0x38784000, 0x38786000, + 0x38788000, 0x3878A000, 0x3878C000, 0x3878E000, 0x38790000, 0x38792000, 0x38794000, + 0x38796000, 0x38798000, 0x3879A000, 0x3879C000, 0x3879E000, 0x387A0000, 0x387A2000, + 0x387A4000, 0x387A6000, 0x387A8000, 0x387AA000, 0x387AC000, 0x387AE000, 0x387B0000, + 0x387B2000, 0x387B4000, 0x387B6000, 0x387B8000, 0x387BA000, 0x387BC000, 0x387BE000, + 0x387C0000, 0x387C2000, 0x387C4000, 0x387C6000, 0x387C8000, 0x387CA000, 0x387CC000, + 0x387CE000, 0x387D0000, 0x387D2000, 0x387D4000, 0x387D6000, 0x387D8000, 0x387DA000, + 0x387DC000, 0x387DE000, 0x387E0000, 0x387E2000, 0x387E4000, 0x387E6000, 0x387E8000, + 0x387EA000, 0x387EC000, 0x387EE000, 0x387F0000, 0x387F2000, 0x387F4000, 0x387F6000, + 0x387F8000, 0x387FA000, 0x387FC000, 0x387FE000}; + static const bits::type exponent_table[64] = { + 0x00000000, 0x00800000, 0x01000000, 0x01800000, 0x02000000, 0x02800000, 0x03000000, + 0x03800000, 0x04000000, 0x04800000, 0x05000000, 0x05800000, 0x06000000, 0x06800000, + 0x07000000, 0x07800000, 0x08000000, 0x08800000, 0x09000000, 0x09800000, 0x0A000000, + 0x0A800000, 0x0B000000, 0x0B800000, 0x0C000000, 0x0C800000, 0x0D000000, 0x0D800000, + 0x0E000000, 0x0E800000, 0x0F000000, 0x47800000, 0x80000000, 0x80800000, 0x81000000, + 0x81800000, 0x82000000, 0x82800000, 0x83000000, 0x83800000, 0x84000000, 0x84800000, + 0x85000000, 0x85800000, 0x86000000, 0x86800000, 0x87000000, 0x87800000, 0x88000000, + 0x88800000, 0x89000000, 0x89800000, 0x8A000000, 0x8A800000, 0x8B000000, 0x8B800000, + 0x8C000000, 0x8C800000, 0x8D000000, 0x8D800000, 0x8E000000, 0x8E800000, 0x8F000000, + 0xC7800000}; + static const unsigned short offset_table[64] = { + 0, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 0, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, + 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024}; + bits::type fbits = + mantissa_table[offset_table[value >> 10] + (value & 0x3FF)] + exponent_table[value >> 10]; +#endif + float out; + std::memcpy(&out, &fbits, sizeof(float)); + return out; +#endif +} + +/// Convert half-precision to IEEE double-precision. 
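The lookup-table path above is the van der Zijp method; the disabled #if 0 branch shows the equivalent shift-based conversion. Below is a self-contained sketch of that shift-based path, assuming an IEEE single-precision float (f16_to_f32 is a hypothetical name); the table method trades the normalisation loop for two lookups:

#include <cstdint>
#include <cstdio>
#include <cstring>

static float f16_to_f32(std::uint16_t h)
{
    std::uint32_t f = static_cast<std::uint32_t>(h & 0x8000) << 16; // sign
    std::uint32_t abs = h & 0x7FFF;
    if(abs)
    {
        f |= 0x38000000 << static_cast<unsigned>(abs >= 0x7C00); // rebias; Inf/NaN get exponent 255
        for(; abs < 0x400; abs <<= 1, f -= 0x800000) // renormalise subnormal halves
            ;
        f += static_cast<std::uint32_t>(abs) << 13;
    }
    float out;
    std::memcpy(&out, &f, sizeof(out));
    return out;
}

int main()
{
    std::printf("%g %g %g\n", f16_to_f32(0x3C00), f16_to_f32(0x7BFF), f16_to_f32(0x0001));
    // 1 65504 5.96046e-08
}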
+/// \param value half-precision value to convert +/// \return double-precision value +inline double half2float_impl(unsigned int value, double, true_type) +{ +#if HALF_ENABLE_F16C_INTRINSICS + return _mm_cvtsd_f64(_mm_cvtps_pd(_mm_cvtph_ps(_mm_cvtsi32_si128(value)))); +#else + uint32 hi = static_cast(value & 0x8000) << 16; + unsigned int abs = value & 0x7FFF; + if(abs) + { + hi |= 0x3F000000 << static_cast(abs >= 0x7C00); + for(; abs < 0x400; abs <<= 1, hi -= 0x100000) + ; + hi += static_cast(abs) << 10; + } + bits::type dbits = static_cast::type>(hi) << 32; + double out; + std::memcpy(&out, &dbits, sizeof(double)); + return out; +#endif +} + +/// Convert half-precision to non-IEEE floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float_impl(unsigned int value, T, ...) +{ + T out; + unsigned int abs = value & 0x7FFF; + if(abs > 0x7C00) + out = + (std::numeric_limits::has_signaling_NaN && !(abs & 0x200)) + ? std::numeric_limits::signaling_NaN() + : std::numeric_limits::has_quiet_NaN ? std::numeric_limits::quiet_NaN() : T(); + else if(abs == 0x7C00) + out = std::numeric_limits::has_infinity ? std::numeric_limits::infinity() + : std::numeric_limits::max(); + else if(abs > 0x3FF) + out = std::ldexp(static_cast((abs & 0x3FF) | 0x400), (abs >> 10) - 25); + else + out = std::ldexp(static_cast(abs), -24); + return (value & 0x8000) ? -out : out; +} + +/// Convert half-precision to floating-point. +/// \tparam T type to convert to (builtin integer type) +/// \param value half-precision value to convert +/// \return floating-point value +template +T half2float(unsigned int value) +{ + return half2float_impl(value, + T(), + bool_type < std::numeric_limits::is_iec559 && + sizeof(typename bits::type) == sizeof(T) > ()); +} + +/// Convert half-precision floating-point to integer. +/// \tparam R rounding mode to use +/// \tparam E `true` for round to even, `false` for round away from zero +/// \tparam I `true` to raise INEXACT exception (if inexact), `false` to never raise it +/// \tparam T type to convert to (buitlin integer type with at least 16 bits precision, excluding +/// any implicit sign bits) +/// \param value half-precision value to convert +/// \return rounded integer value +/// \exception FE_INVALID if value is not representable in type \a T +/// \exception FE_INEXACT if value had to be rounded and \a I is `true` +template +T half2int(unsigned int value) +{ + unsigned int abs = value & 0x7FFF; + if(abs >= 0x7C00) + { + raise(FE_INVALID); + return (value & 0x8000) ? std::numeric_limits::min() : std::numeric_limits::max(); + } + if(abs < 0x3800) + { + raise(FE_INEXACT, I); + return (R == std::round_toward_infinity) + ? T(~(value >> 15) & (abs != 0)) + : (R == std::round_toward_neg_infinity) ? -T(value > 0x8000) : T(); + } + int exp = 25 - (abs >> 10); + unsigned int m = (value & 0x3FF) | 0x400; + int32 i = static_cast( + (exp <= 0) + ? (m << -exp) + : ((m + ((R == std::round_to_nearest) ? ((1 << (exp - 1)) - (~(m >> exp) & E)) + : (R == std::round_toward_infinity) + ? (((1 << exp) - 1) & ((value >> 15) - 1)) + : (R == std::round_toward_neg_infinity) + ? (((1 << exp) - 1) & -(value >> 15)) + : 0)) >> + exp)); + if((!std::numeric_limits::is_signed && (value & 0x8000)) || + (std::numeric_limits::digits < 16 && + ((value & 0x8000) ? 
(-i < std::numeric_limits::min()) + : (i > std::numeric_limits::max())))) + raise(FE_INVALID); + else if(I && exp > 0 && (m & ((1 << exp) - 1))) + raise(FE_INEXACT); + return static_cast((value & 0x8000) ? -i : i); +} + +/// \} +/// \name Mathematics +/// \{ + +/// upper part of 64-bit multiplication. +/// \tparam R rounding mode to use +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y +template +uint32 mulhi(uint32 x, uint32 y) +{ + uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), + c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); + return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + + ((R == std::round_to_nearest) ? ((c >> 15) & 1) : (R == std::round_toward_infinity) + ? ((c & 0xFFFF) != 0) + : 0); +} + +/// 64-bit multiplication. +/// \param x first factor +/// \param y second factor +/// \return upper 32 bit of \a x * \a y rounded to nearest +inline uint32 multiply64(uint32 x, uint32 y) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + return static_cast( + (static_cast(x) * static_cast(y) + 0x80000000) >> + 32); +#else + return mulhi(x, y); +#endif +} + +/// 64-bit division. +/// \param x upper 32 bit of dividend +/// \param y divisor +/// \param s variable to store sticky bit for rounding +/// \return (\a x << 32) / \a y +inline uint32 divide64(uint32 x, uint32 y, int& s) +{ +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long xx = static_cast(x) << 32; + return s = (xx % y != 0), static_cast(xx / y); +#else + y >>= 1; + uint32 rem = x, div = 0; + for(unsigned int i = 0; i < 32; ++i) + { + div <<= 1; + if(rem >= y) + { + rem -= y; + div |= 1; + } + rem <<= 1; + } + return s = rem > 1, div; +#endif +} + +/// Half precision positive modulus. +/// \tparam Q `true` to compute full quotient, `false` else +/// \tparam R `true` to compute signed remainder, `false` for positive remainder +/// \param x first operand as positive finite half-precision value +/// \param y second operand as positive finite half-precision value +/// \param quo adress to store quotient at, `nullptr` if \a Q `false` +/// \return modulus of \a x / \a y +template +unsigned int mod(unsigned int x, unsigned int y, int* quo = NULL) +{ + unsigned int q = 0; + if(x > y) + { + int absx = x, absy = y, expx = 0, expy = 0; + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + expx += absx >> 10; + expy += absy >> 10; + int mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + for(int d = expx - expy; d; --d) + { + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + q += Q; + } + mx <<= 1; + q <<= static_cast(Q); + } + if(!Q && mx == my) + return 0; + if(mx >= my) + { + mx -= my; + ++q; + } + if(Q) + { + q &= (1 << (std::numeric_limits::digits - 1)) - 1; + if(!mx) + return *quo = q, 0; + } + for(; mx < 0x400; mx <<= 1, --expy) + ; + x = (expy > 0) ? ((expy << 10) | (mx & 0x3FF)) : (mx >> (1 - expy)); + } + if(R) + { + unsigned int a, b; + if(y < 0x800) + { + a = (x < 0x400) ? (x << 1) : (x + 0x400); + b = y; + } + else + { + a = x; + b = y - 0x400; + } + if(a > b || (a == b && (q & 1))) + { + int exp = (y >> 10) + (y <= 0x3FF), d = exp - (x >> 10) - (x <= 0x3FF); + int m = (((y & 0x3FF) | ((y > 0x3FF) << 10)) << 1) - + (((x & 0x3FF) | ((x > 0x3FF) << 10)) << (1 - d)); + for(; m < 0x800 && exp > 1; m <<= 1, --exp) + ; + x = 0x8000 + ((exp - 1) << 10) + (m >> 1); + q += Q; + } + } + if(Q) + *quo = q; + return x; +} + +/// Fixed point square root. 
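The portable mulhi above can be sanity-checked against native 64-bit arithmetic; the sketch below is the truncating variant, i.e. mulhi with the rounding term selected by R dropped (mulhi32 is a hypothetical name):

#include <cstdint>
#include <cstdio>

// Truncating 32x32 -> upper-32 multiply: four 16x16 partial products, with
// 'c' collecting every carry that can reach bit 32.
static std::uint32_t mulhi32(std::uint32_t x, std::uint32_t y)
{
    std::uint32_t xy = (x >> 16) * (y & 0xFFFF);
    std::uint32_t yx = (x & 0xFFFF) * (y >> 16);
    std::uint32_t c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16);
    return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16);
}

int main()
{
    std::uint32_t x = 0xC90FDAA2, y = 0xB8AA3B29; // pi/2 and log2(e) as Q1.31
    std::uint32_t ref = static_cast<std::uint32_t>((static_cast<std::uint64_t>(x) * y) >> 32);
    std::printf("%08X %08X\n", mulhi32(x, y), ref); // both print the same upper word
    return mulhi32(x, y) != ref;
}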
+/// \tparam F number of fractional bits
+/// \param r radicand in Q1.F fixed point format
+/// \param exp exponent
+/// \return square root as Q1.F/2
+template<unsigned int F>
+uint32 sqrt(uint32& r, int& exp)
+{
+    int i = exp & 1;
+    r <<= i;
+    exp = (exp - i) / 2;
+    uint32 m = 0;
+    for(uint32 bit = static_cast<uint32>(1) << F; bit; bit >>= 2)
+    {
+        if(r < m + bit)
+            m >>= 1;
+        else
+        {
+            r -= m + bit;
+            m = (m >> 1) + bit;
+        }
+    }
+    return m;
+}
+
+/// Fixed point binary exponential.
+/// This uses the BKM algorithm in E-mode.
+/// \param m exponent in [0,1) as Q0.31
+/// \param n number of iterations (at most 32)
+/// \return 2 ^ \a m as Q1.31
+inline uint32 exp2(uint32 m, unsigned int n = 32)
+{
+    static const uint32 logs[] = {
+        0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1,
+        0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B,
+        0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B,
+        0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017,
+        0x0000000C, 0x00000006, 0x00000003, 0x00000001};
+    if(!m)
+        return 0x80000000;
+    uint32 mx = 0x80000000, my = 0;
+    for(unsigned int i = 1; i < n; ++i)
+    {
+        uint32 mz = my + logs[i];
+        if(mz <= m)
+        {
+            my = mz;
+            mx += mx >> i;
+        }
+    }
+    return mx;
+}
+
+/// Fixed point binary logarithm.
+/// This uses the BKM algorithm in L-mode.
+/// \param m mantissa in [1,2) as Q1.30
+/// \param n number of iterations (at most 32)
+/// \return log2(\a m) as Q0.31
+inline uint32 log2(uint32 m, unsigned int n = 32)
+{
+    static const uint32 logs[] = {
+        0x80000000, 0x4AE00D1D, 0x2934F098, 0x15C01A3A, 0x0B31FB7D, 0x05AEB4DD, 0x02DCF2D1,
+        0x016FE50B, 0x00B84E23, 0x005C3E10, 0x002E24CA, 0x001713D6, 0x000B8A47, 0x0005C53B,
+        0x0002E2A3, 0x00017153, 0x0000B8AA, 0x00005C55, 0x00002E2B, 0x00001715, 0x00000B8B,
+        0x000005C5, 0x000002E3, 0x00000171, 0x000000B9, 0x0000005C, 0x0000002E, 0x00000017,
+        0x0000000C, 0x00000006, 0x00000003, 0x00000001};
+    if(m == 0x40000000)
+        return 0;
+    uint32 mx = 0x40000000, my = 0;
+    for(unsigned int i = 1; i < n; ++i)
+    {
+        uint32 mz = mx + (mx >> i);
+        if(mz <= m)
+        {
+            mx = mz;
+            my += logs[i];
+        }
+    }
+    return my;
+}
+
+/// Fixed point sine and cosine.
+/// This uses the CORDIC algorithm in rotation mode.
+/// \param mz angle in [-pi/2,pi/2] as Q1.30
+/// \param n number of iterations (at most 31)
+/// \return sine and cosine of \a mz as Q1.30
+inline std::pair<uint32, uint32> sincos(uint32 mz, unsigned int n = 31)
+{
+    static const uint32 angles[] = {
+        0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB,
+        0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000,
+        0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400,
+        0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008,
+        0x00000004, 0x00000002, 0x00000001};
+    uint32 mx = 0x26DD3B6A, my = 0;
+    for(unsigned int i = 0; i < n; ++i)
+    {
+        uint32 sign = sign_mask(mz);
+        uint32 tx = mx - (arithmetic_shift(my, i) ^ sign) + sign;
+        uint32 ty = my + (arithmetic_shift(mx, i) ^ sign) - sign;
+        mx = tx;
+        my = ty;
+        mz -= (angles[i] ^ sign) - sign;
+    }
+    return std::make_pair(my, mx);
+}
+
+/// Fixed point arc tangent.
+/// This uses the CORDIC algorithm in vectoring mode.
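A quick numeric check of the BKM E-mode recurrence used by exp2 above: each accepted step i multiplies the accumulator by (1 + 2^-i) with a shift-and-add while log2(1 + 2^-i), in Q0.31, is charged against the target exponent. The constants are recomputed in double here rather than copied from logs[], so treat this as a sketch of the idea, not the library's exact table (bkm_exp2 is a hypothetical name):

#include <cstdint>
#include <cmath>
#include <cstdio>

static std::uint32_t bkm_exp2(std::uint32_t m) // m: Q0.31 in [0,1); returns 2^m as Q1.31
{
    std::uint32_t x = 0x80000000u, y = 0;
    for(int i = 1; i < 32; ++i)
    {
        // log2(1 + 2^-i) as Q0.31, recomputed instead of read from logs[]
        std::uint32_t l = static_cast<std::uint32_t>(
            std::ldexp(std::log2(1.0 + std::ldexp(1.0, -i)), 31) + 0.5);
        if(y + l <= m)
        {
            y += l;
            x += x >> i;
        }
    }
    return x;
}

int main()
{
    double frac = 0.3141;
    std::uint32_t q = bkm_exp2(static_cast<std::uint32_t>(std::ldexp(frac, 31)));
    std::printf("%.8f vs %.8f\n", std::ldexp(static_cast<double>(q), -31), std::exp2(frac));
}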
+/// \param my y coordinate as Q0.30 +/// \param mx x coordinate as Q0.30 +/// \param n number of iterations (at most 31) +/// \return arc tangent of \a my / \a mx as Q1.30 +inline uint32 atan2(uint32 my, uint32 mx, unsigned int n = 31) +{ + static const uint32 angles[] = { + 0x3243F6A9, 0x1DAC6705, 0x0FADBAFD, 0x07F56EA7, 0x03FEAB77, 0x01FFD55C, 0x00FFFAAB, + 0x007FFF55, 0x003FFFEB, 0x001FFFFD, 0x00100000, 0x00080000, 0x00040000, 0x00020000, + 0x00010000, 0x00008000, 0x00004000, 0x00002000, 0x00001000, 0x00000800, 0x00000400, + 0x00000200, 0x00000100, 0x00000080, 0x00000040, 0x00000020, 0x00000010, 0x00000008, + 0x00000004, 0x00000002, 0x00000001}; + uint32 mz = 0; + for(unsigned int i = 0; i < n; ++i) + { + uint32 sign = sign_mask(my); + uint32 tx = mx + (arithmetic_shift(my, i) ^ sign) - sign; + uint32 ty = my - (arithmetic_shift(mx, i) ^ sign) + sign; + mx = tx; + my = ty; + mz += (angles[i] ^ sign) - sign; + } + return mz; +} + +/// Reduce argument for trigonometric functions. +/// \param abs half-precision floating-point value +/// \param k value to take quarter period +/// \return \a abs reduced to [-pi/4,pi/4] as Q0.30 +inline uint32 angle_arg(unsigned int abs, int& k) +{ + uint32 m = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + int exp = (abs >> 10) + (abs <= 0x3FF) - 15; + if(abs < 0x3A48) + return k = 0, m << (exp + 20); +#if HALF_ENABLE_CPP11_LONG_LONG + unsigned long long y = m * 0xA2F9836E4E442, mask = (1ULL << (62 - exp)) - 1, + yi = (y + (mask >> 1)) & ~mask, f = y - yi; + uint32 sign = -static_cast(f >> 63); + k = static_cast(yi >> (62 - exp)); + return (multiply64(static_cast((sign ? -f : f) >> (31 - exp)), 0xC90FDAA2) ^ sign) - + sign; +#else + uint32 yh = m * 0xA2F98 + mulhi(m, 0x36E4E442), + yl = (m * 0x36E4E442) & 0xFFFFFFFF; + uint32 mask = (static_cast(1) << (30 - exp)) - 1, yi = (yh + (mask >> 1)) & ~mask, + sign = -static_cast(yi > yh); + k = static_cast(yi >> (30 - exp)); + uint32 fh = (yh ^ sign) + (yi ^ ~sign) - ~sign, fl = (yl ^ sign) - sign; + return (multiply64((exp > -1) + ? (((fh << (1 + exp)) & 0xFFFFFFFF) | ((fl & 0xFFFFFFFF) >> (31 - exp))) + : fh, + 0xC90FDAA2) ^ + sign) - + sign; +#endif +} + +/// Get arguments for atan2 function. +/// \param abs half-precision floating-point value +/// \return \a abs and sqrt(1 - \a abs^2) as Q0.30 +inline std::pair atan2_args(unsigned int abs) +{ + int exp = -15; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + uint32 my = ((abs & 0x3FF) | 0x400) << 5, r = my * my; + int rexp = 2 * exp; + r = 0x40000000 - + ((rexp > -31) ? ((r >> -rexp) | ((r & ((static_cast(1) << -rexp) - 1)) != 0)) : 1); + for(rexp = 0; r < 0x40000000; r <<= 1, --rexp) + ; + uint32 mx = sqrt<30>(r, rexp); + int d = exp - rexp; + if(d < 0) + return std::make_pair((d < -14) ? ((my >> (-d - 14)) + ((my >> (-d - 15)) & 1)) + : (my << (14 + d)), + (mx << 14) + (r << 13) / mx); + if(d > 0) + return std::make_pair(my << 14, + (d > 14) + ? ((mx >> (d - 14)) + ((mx >> (d - 15)) & 1)) + : ((d == 14) ? 
mx : ((mx << (14 - d)) + (r << (13 - d)) / mx))); + return std::make_pair(my << 13, (mx << 13) + (r << 12) / mx); +} + +/// Get exponentials for hyperbolic computation +/// \param abs half-precision floating-point value +/// \param exp variable to take unbiased exponent of larger result +/// \param n number of BKM iterations (at most 32) +/// \return exp(abs) and exp(-\a abs) as Q1.31 with same exponent +inline std::pair hyperbolic_args(unsigned int abs, int& exp, unsigned int n = 32) +{ + uint32 mx = detail::multiply64(static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, + 0xB8AA3B29), + my; + int e = (abs >> 10) + (abs <= 0x3FF); + if(e < 14) + { + exp = 0; + mx >>= 14 - e; + } + else + { + exp = mx >> (45 - e); + mx = (mx << (e - 14)) & 0x7FFFFFFF; + } + mx = exp2(mx, n); + int d = exp << 1, s; + if(mx > 0x80000000) + { + my = divide64(0x80000000, mx, s); + my |= s; + ++d; + } + else + my = mx; + return std::make_pair( + mx, (d < 31) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1); +} + +/// Postprocessing for binary exponential. +/// \tparam R rounding mode to use +/// \tparam I `true` to always raise INEXACT exception, `false` to raise only for rounded results +/// \param m mantissa as Q1.31 +/// \param exp absolute value of unbiased exponent +/// \param esign sign of actual exponent +/// \param sign sign bit of result +/// \return value converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded or \a I is `true` +template +unsigned int exp2_post(uint32 m, int exp, bool esign, unsigned int sign = 0) +{ + int s = 0; + if(esign) + { + if(m > 0x80000000) + { + m = divide64(0x80000000, m, s); + ++exp; + } + if(exp > 25) + return underflow(sign); + else if(exp == 25) + return rounded(sign, 1, (m & 0x7FFFFFFF) != 0); + exp = -exp; + } + else if(exp > 15) + return overflow(sign); + return fixed2half(m, exp + 14, sign, s); +} + +/// Postprocessing for binary logarithm. +/// \tparam R rounding mode to use +/// \tparam L logarithm for base transformation as Q1.31 +/// \param m fractional part of logarithm as Q0.31 +/// \param ilog signed integer part of logarithm +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return value base-transformed and converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int log2_post(uint32 m, int ilog, int exp, unsigned int sign = 0) +{ + uint32 msign = sign_mask(ilog); + m = (((static_cast(ilog) << 27) + (m >> 4)) ^ msign) - msign; + if(!m) + return 0; + for(; m < 0x80000000; m <<= 1, --exp) + ; + int i = m >= L, s; + exp += i; + m >>= 1 + i; + sign ^= msign & 0x8000; + if(exp < -11) + return underflow(sign); + m = divide64(m, L, s); + return fixed2half(m, exp, sign, 1); +} + +/// Hypotenuse square root and postprocessing. 
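The sqrt<F> helper used by hypot_post below (and by atan2_args above) is the classic binary digit-by-digit square root; restated with F = 0 it reduces to an integer square root that is easy to test exhaustively (isqrt32 is a hypothetical name):

#include <cassert>
#include <cstdint>
#include <cstdio>

// Binary digit-by-digit square root, as in sqrt<F> above but with F = 0:
// each iteration consumes two radicand bits and appends one root bit.
static std::uint32_t isqrt32(std::uint32_t r)
{
    std::uint32_t m = 0;
    for(std::uint32_t bit = 1u << 30; bit; bit >>= 2)
    {
        if(r < m + bit)
            m >>= 1; // next root bit is 0
        else
        {
            r -= m + bit; // next root bit is 1
            m = (m >> 1) + bit;
        }
    }
    return m; // floor(sqrt(input)); r now holds the remainder
}

int main()
{
    for(std::uint32_t v = 0; v < 65536; ++v)
        assert(isqrt32(v * v) == v);
    std::printf("%u\n", isqrt32(2147483647u)); // 46340
}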
+/// \tparam R rounding mode to use +/// \param r mantissa as Q2.30 +/// \param exp unbiased exponent +/// \return square root converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if value had to be rounded +template +unsigned int hypot_post(uint32 r, int exp) +{ + int i = r >> 31; + if((exp += i) > 46) + return overflow(); + if(exp < -34) + return underflow(); + r = (r >> i) | (r & i); + uint32 m = sqrt<30>(r, exp += 15); + return fixed2half(m, exp - 1, 0, r != 0); +} + +/// Division and postprocessing for tangents. +/// \tparam R rounding mode to use +/// \param my dividend as Q1.31 +/// \param mx divisor as Q1.31 +/// \param exp biased exponent of result +/// \param sign sign bit of result +/// \return quotient converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int tangent_post(uint32 my, uint32 mx, int exp, unsigned int sign = 0) +{ + int i = my >= mx, s; + exp += i; + if(exp > 29) + return overflow(sign); + if(exp < -11) + return underflow(sign); + uint32 m = divide64(my >> (i + 1), mx, s); + return fixed2half(m, exp, sign, s); +} + +/// Area function and postprocessing. +/// This computes the value directly in Q2.30 using the representation `asinh|acosh(x) = +/// log(x+sqrt(x^2+|-1))`. +/// \tparam R rounding mode to use +/// \tparam S `true` for asinh, `false` for acosh +/// \param arg half-precision argument +/// \return asinh|acosh(\a arg) converted to half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int area(unsigned int arg) +{ + int abs = arg & 0x7FFF, expx = (abs >> 10) + (abs <= 0x3FF) - 15, expy = -15, ilog, i; + uint32 mx = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) << 20, my, r; + for(; abs < 0x400; abs <<= 1, --expy) + ; + expy += abs >> 10; + r = ((abs & 0x3FF) | 0x400) << 5; + r *= r; + i = r >> 31; + expy = 2 * expy + i; + r >>= i; + if(S) + { + if(expy < 0) + { + r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | + ((r & ((static_cast(1) << -expy) - 1)) != 0)) + : 1); + expy = 0; + } + else + { + r += 0x40000000 >> expy; + i = r >> 31; + r = (r >> i) | (r & i); + expy += i; + } + } + else + { + r -= 0x40000000 >> expy; + for(; r < 0x40000000; r <<= 1, --expy) + ; + } + my = sqrt<30>(r, expy); + my = (my << 15) + (r << 14) / my; + if(S) + { + mx >>= expy - expx; + ilog = expy; + } + else + { + my >>= expx - expy; + ilog = expx; + } + my += mx; + i = my >> 31; + static const int G = S && (R == std::round_to_nearest); + return log2_post( + log2(my >> i, 26 + S + G) + (G << 3), ilog + i, 17, arg & (static_cast(S) << 15)); +} + +/// Class for 1.31 unsigned floating-point computation +struct f31 +{ + /// Constructor. + /// \param mant mantissa as 1.31 + /// \param e exponent + HALF_CONSTEXPR f31(uint32 mant, int e) : m(mant), exp(e) {} + + /// Constructor. + /// \param abs unsigned half-precision value + f31(unsigned int abs) : exp(-15) + { + for(; abs < 0x400; abs <<= 1, --exp) + ; + m = static_cast((abs & 0x3FF) | 0x400) << 21; + exp += (abs >> 10); + } + + /// Addition operator. + /// \param a first operand + /// \param b second operand + /// \return \a a + \a b + friend f31 operator+(f31 a, f31 b) + { + if(b.exp > a.exp) + std::swap(a, b); + int d = a.exp - b.exp; + uint32 m = a.m + ((d < 32) ? 
(b.m >> d) : 0); + int i = (m & 0xFFFFFFFF) < a.m; + return f31(((m + i) >> i) | 0x80000000, a.exp + i); + } + + /// Subtraction operator. + /// \param a first operand + /// \param b second operand + /// \return \a a - \a b + friend f31 operator-(f31 a, f31 b) + { + int d = a.exp - b.exp, exp = a.exp; + uint32 m = a.m - ((d < 32) ? (b.m >> d) : 0); + if(!m) + return f31(0, -32); + for(; m < 0x80000000; m <<= 1, --exp) + ; + return f31(m, exp); + } + + /// Multiplication operator. + /// \param a first operand + /// \param b second operand + /// \return \a a * \a b + friend f31 operator*(f31 a, f31 b) + { + uint32 m = multiply64(a.m, b.m); + int i = m >> 31; + return f31(m << (1 - i), a.exp + b.exp + i); + } + + /// Division operator. + /// \param a first operand + /// \param b second operand + /// \return \a a / \a b + friend f31 operator/(f31 a, f31 b) + { + int i = a.m >= b.m, s; + uint32 m = divide64((a.m + i) >> i, b.m, s); + return f31(m, a.exp - b.exp + i - 1); + } + + uint32 m; ///< mantissa as 1.31. + int exp; ///< exponent. +}; + +/// Error function and postprocessing. +/// This computes the value directly in Q1.31 using the approximations given +/// [here](https://en.wikipedia.org/wiki/Error_function#Approximation_with_elementary_functions). +/// \tparam R rounding mode to use +/// \tparam C `true` for comlementary error function, `false` else +/// \param arg half-precision function argument +/// \return approximated value of error function in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if no other exception occurred +template +unsigned int erf(unsigned int arg) +{ + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + f31 x(abs), x2 = x * x * f31(0xB8AA3B29, 0), + t = f31(0x80000000, 0) / (f31(0x80000000, 0) + f31(0xA7BA054A, -2) * x), t2 = t * t; + f31 e = ((f31(0x87DC2213, 0) * t2 + f31(0xB5F0E2AE, 0)) * t2 + f31(0x82790637, -2) - + (f31(0xBA00E2B8, 0) * t2 + f31(0x91A98E62, -2)) * t) * + t / + ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) + : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); + return (!C || sign) ? fixed2half( + 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) + : (e.exp < -25) ? underflow() : fixed2half( + e.m >> 1, e.exp + 14, 0, e.m & 1); +} + +/// Gamma function and postprocessing. +/// This approximates the value of either the gamma function or its logarithm directly in Q1.31. +/// \tparam R rounding mode to use +/// \tparam L `true` for lograithm of gamma function, `false` for gamma function +/// \param arg half-precision floating-point value +/// \return lgamma/tgamma(\a arg) in half-precision +/// \exception FE_OVERFLOW on overflows +/// \exception FE_UNDERFLOW on underflows +/// \exception FE_INEXACT if \a arg is not a positive integer +template +unsigned int gamma(unsigned int arg) +{ + /* static const double p[] ={ 2.50662827563479526904, 225.525584619175212544, -268.295973841304927459, 80.9030806934622512966, -5.00757863970517583837, 0.0114684895434781459556 }; + double t = arg + 4.65, s = p[0]; + for(unsigned int i=0; i<5; ++i) + s += p[i+1] / (arg+i); + return std::log(s) + (arg-0.5)*std::log(t) - t; +*/ static const f31 + pi(0xC90FDAA2, 1), + lbe(0xB8AA3B29, 0); + unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; + bool bsign = sign != 0; + f31 z(abs), x = sign ? 
+
+/// Half-precision floating-point type.
+/// This class implements an IEEE-conformant half-precision floating-point type with the usual
+/// arithmetic operators and conversions. It is implicitly convertible to single-precision
+/// floating-point, which causes arithmetic expressions and functions with mixed-type operands
+/// to be evaluated in the most precise operand type.
+///
+/// According to the C++98/03 definition, the half type is not a POD type. But according to C++11's
+/// less strict and extended definitions it is both a standard layout type and a trivially copyable
+/// type (even if not a POD type), which means it can be standard-conformantly copied using raw
+/// binary copies. In this context a few words about the actual size of the type are in order.
+/// Although the half represents an IEEE 16-bit type, it does not necessarily have to be exactly
+/// 16 bits in size. But on any reasonable implementation the actual binary representation of this
+/// type will most probably not involve any additional "magic" or padding beyond the simple binary
+/// representation of the underlying 16-bit IEEE number, even if not strictly guaranteed by the
+/// standard.
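// A minimal sketch of what the paragraph above implies, assuming a platform where the
// underlying detail::uint16 is exactly 16 bits wide; half_bits is a hypothetical helper,
// not part of the library:
#include <cstring>
#include <cstdint>
static std::uint16_t half_bits(half_float::half h)
{
    static_assert(sizeof(half_float::half) == 2, "assumes no padding beyond the 16-bit payload");
    std::uint16_t bits;
    std::memcpy(&bits, &h, sizeof(bits)); // raw binary copy is valid: half is trivially copyable
    return bits;
}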
But even then it only has an +/// actual size of 16 bits if +/// your C++ implementation supports an unsigned integer type of exactly 16 bits width. But this +/// should be the case on +/// nearly any reasonable platform. +/// +/// So if your C++ implementation is not totally exotic or imposes special alignment requirements, +/// it is a reasonable +/// assumption that the data of a half is just comprised of the 2 bytes of the underlying IEEE +/// representation. +class half +{ + public: + /// \name Construction and assignment + /// \{ + + /// Default constructor. + /// This initializes the half to 0. Although this does not match the builtin types' + /// default-initialization semantics + /// and may be less efficient than no initialization, it is needed to provide proper + /// value-initialization semantics. + HALF_CONSTEXPR half() HALF_NOEXCEPT : data_() {} + + /// Conversion constructor. + /// \param rhs float to convert + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + explicit half(float rhs) + : data_(static_cast(detail::float2half(rhs))) + { + } + + /// Conversion to single-precision. + /// \return single precision value representing expression value + operator float() const { return detail::half2float(data_); } + + /// Assignment operator. + /// \param rhs single-precision value to copy from + /// \return reference to this half + /// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding + half& operator=(float rhs) + { + data_ = static_cast(detail::float2half(rhs)); + return *this; + } + + /// \} + /// \name Arithmetic updates + /// \{ + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to add + /// \return reference to this half + /// \exception FE_... according to operator+(half,half) + half& operator+=(half rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to subtract + /// \return reference to this half + /// \exception FE_... according to operator-(half,half) + half& operator-=(half rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to multiply with + /// \return reference to this half + /// \exception FE_... according to operator*(half,half) + half& operator*=(half rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. + /// \tparam T type of concrete half expression + /// \param rhs half expression to divide by + /// \return reference to this half + /// \exception FE_... according to operator/(half,half) + half& operator/=(half rhs) { return *this = *this / rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to add + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator+=(float rhs) { return *this = *this + rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to subtract + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator-=(float rhs) { return *this = *this - rhs; } + + /// Arithmetic assignment. + /// \param rhs single-precision value to multiply with + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator*=(float rhs) { return *this = *this * rhs; } + + /// Arithmetic assignment. 
+ /// \param rhs single-precision value to divide by + /// \return reference to this half + /// \exception FE_... according to operator=() + half& operator/=(float rhs) { return *this = *this / rhs; } + + /// \} + /// \name Increment and decrement + /// \{ + + /// Prefix increment. + /// \return incremented half value + /// \exception FE_... according to operator+(half,half) + half& operator++() { return *this = *this + half(detail::binary, 0x3C00); } + + /// Prefix decrement. + /// \return decremented half value + /// \exception FE_... according to operator-(half,half) + half& operator--() { return *this = *this + half(detail::binary, 0xBC00); } + + /// Postfix increment. + /// \return non-incremented half value + /// \exception FE_... according to operator+(half,half) + half operator++(int) + { + half out(*this); + ++*this; + return out; + } + + /// Postfix decrement. + /// \return non-decremented half value + /// \exception FE_... according to operator-(half,half) + half operator--(int) + { + half out(*this); + --*this; + return out; + } + /// \} + + private: + /// Rounding mode to use + static const std::float_round_style round_style = (std::float_round_style)(HALF_ROUND_STYLE); + + /// Constructor. + /// \param bits binary representation to set half to + HALF_CONSTEXPR half(detail::binary_t, unsigned int bits) HALF_NOEXCEPT + : data_(static_cast(bits)) + { + } + + /// Internal binary representation + detail::uint16 data_; + +#ifndef HALF_DOXYGEN_ONLY + friend HALF_CONSTEXPR_NOERR bool operator==(half, half); + friend HALF_CONSTEXPR_NOERR bool operator!=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>(half, half); + friend HALF_CONSTEXPR_NOERR bool operator<=(half, half); + friend HALF_CONSTEXPR_NOERR bool operator>=(half, half); + friend HALF_CONSTEXPR half operator-(half); + friend half operator+(half, half); + friend half operator-(half, half); + friend half operator*(half, half); + friend half operator/(half, half); + template + friend std::basic_ostream& operator<<(std::basic_ostream&, half); + template + friend std::basic_istream& operator>>(std::basic_istream&, half&); + friend HALF_CONSTEXPR half fabs(half); + friend half fmod(half, half); + friend half remainder(half, half); + friend half remquo(half, half, int*); + friend half fma(half, half, half); + friend HALF_CONSTEXPR_NOERR half fmax(half, half); + friend HALF_CONSTEXPR_NOERR half fmin(half, half); + friend half fdim(half, half); + friend half nanh(const char*); + friend half exp(half); + friend half exp2(half); + friend half expm1(half); + friend half log(half); + friend half log10(half); + friend half log2(half); + friend half log1p(half); + friend half sqrt(half); + friend half cbrt(half); + friend half hypot(half, half); + friend half hypot(half, half, half); + friend half pow(half, half); + friend void sincos(half, half*, half*); + friend half sin(half); + friend half cos(half); + friend half tan(half); + friend half asin(half); + friend half acos(half); + friend half atan(half); + friend half atan2(half, half); + friend half sinh(half); + friend half cosh(half); + friend half tanh(half); + friend half asinh(half); + friend half acosh(half); + friend half atanh(half); + friend half erf(half); + friend half erfc(half); + friend half lgamma(half); + friend half tgamma(half); + friend half ceil(half); + friend half floor(half); + friend half trunc(half); + friend half round(half); + friend long lround(half); + friend half rint(half); + friend long 
lrint(half);
+    friend half nearbyint(half);
+#ifdef HALF_ENABLE_CPP11_LONG_LONG
+    friend long long llround(half);
+    friend long long llrint(half);
+#endif
+    friend half frexp(half, int*);
+    friend half scalbln(half, long);
+    friend half modf(half, half*);
+    friend int ilogb(half);
+    friend half logb(half);
+    friend half nextafter(half, half);
+    friend half nexttoward(half, long double);
+    friend HALF_CONSTEXPR half copysign(half, half);
+    friend HALF_CONSTEXPR int fpclassify(half);
+    friend HALF_CONSTEXPR bool isfinite(half);
+    friend HALF_CONSTEXPR bool isinf(half);
+    friend HALF_CONSTEXPR bool isnan(half);
+    friend HALF_CONSTEXPR bool isnormal(half);
+    friend HALF_CONSTEXPR bool signbit(half);
+    friend HALF_CONSTEXPR bool isgreater(half, half);
+    friend HALF_CONSTEXPR bool isgreaterequal(half, half);
+    friend HALF_CONSTEXPR bool isless(half, half);
+    friend HALF_CONSTEXPR bool islessequal(half, half);
+    friend HALF_CONSTEXPR bool islessgreater(half, half);
+    template <typename, typename, std::float_round_style>
+    friend struct detail::half_caster;
+    friend class std::numeric_limits<half>;
+#if HALF_ENABLE_CPP11_HASH
+    friend struct std::hash<half>;
+#endif
+#if HALF_ENABLE_CPP11_USER_LITERALS
+    friend half literal::operator"" _h(long double);
+#endif
+#endif
+};
+
+#if HALF_ENABLE_CPP11_USER_LITERALS
+namespace literal {
+/// Half literal.
+/// While this returns a properly rounded half-precision value, half literals can unfortunately not
+/// be constant expressions due to rather involved conversions. So don't expect this to be a
+/// literal literal without involving conversion operations at runtime. It is a convenience
+/// feature, not a performance optimization.
+/// \param value literal value
+/// \return half with the given value (possibly rounded)
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half operator"" _h(long double value)
+{
+    return half(detail::binary, detail::float2half<half::round_style>(value));
+}
+}
+#endif
+
+namespace detail {
+/// Helper class for half casts.
+/// This class template has to be specialized for all valid cast arguments to define an appropriate
+/// static `cast` member function and a corresponding `type` member denoting its return type.
+/// \tparam T destination type
+/// \tparam U source type
+/// \tparam R rounding mode to use
+template <typename T, typename U, std::float_round_style R>
+struct half_caster
+{
+};
+template <typename U, std::float_round_style R>
+struct half_caster<half, U, R>
+{
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+    static_assert(std::is_arithmetic<U>::value, "half_cast from non-arithmetic type unsupported");
+#endif
+
+    static half cast(U arg) { return cast_impl(arg, is_float<U>()); };
+
+    private:
+    static half cast_impl(U arg, true_type) { return half(binary, float2half<R>(arg)); }
+    static half cast_impl(U arg, false_type) { return half(binary, int2half<R>(arg)); }
+};
+template <typename T, std::float_round_style R>
+struct half_caster<T, half, R>
+{
+#if HALF_ENABLE_CPP11_STATIC_ASSERT && HALF_ENABLE_CPP11_TYPE_TRAITS
+    static_assert(std::is_arithmetic<T>::value, "half_cast to non-arithmetic type unsupported");
+#endif
+
+    static T cast(half arg) { return cast_impl(arg, is_float<T>()); }
+
+    private:
+    static T cast_impl(half arg, true_type) { return half2float<T>(arg.data_); }
+    static T cast_impl(half arg, false_type) { return half2int(arg.data_); }
+};
+template <std::float_round_style R>
+struct half_caster<half, half, R>
+{
+    static half cast(half arg) { return arg; }
+};
+}
+}
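// The half_caster specializations above are the machinery behind the library's public
// half_cast<>() interface (declared later in the full header); a brief usage sketch:
static void half_cast_demo()
{
    using half_float::half;
    half h  = half_float::half_cast<half>(3.14159);                    // double -> half
    int i   = half_float::half_cast<int>(h);                           // half -> int
    float f = half_float::half_cast<float, std::round_toward_zero>(h); // explicit rounding mode
    (void)i;
    (void)f;
}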
+
+/// Extensions to the C++ standard library.
+namespace std {
+/// Numeric limits for half-precision floats.
+/// **See also:** Documentation for
+/// [std::numeric_limits](https://en.cppreference.com/w/cpp/types/numeric_limits)
+template <>
+class numeric_limits<half_float::half>
+{
+    public:
+    /// Is template specialization.
+    static HALF_CONSTEXPR_CONST bool is_specialized = true;
+
+    /// Supports signed values.
+    static HALF_CONSTEXPR_CONST bool is_signed = true;
+
+    /// Is not an integer type.
+    static HALF_CONSTEXPR_CONST bool is_integer = false;
+
+    /// Is not exact.
+    static HALF_CONSTEXPR_CONST bool is_exact = false;
+
+    /// Doesn't provide modulo arithmetic.
+    static HALF_CONSTEXPR_CONST bool is_modulo = false;
+
+    /// Has a finite set of values.
+    static HALF_CONSTEXPR_CONST bool is_bounded = true;
+
+    /// IEEE conformant.
+    static HALF_CONSTEXPR_CONST bool is_iec559 = true;
+
+    /// Supports infinity.
+    static HALF_CONSTEXPR_CONST bool has_infinity = true;
+
+    /// Supports quiet NaNs.
+    static HALF_CONSTEXPR_CONST bool has_quiet_NaN = true;
+
+    /// Supports signaling NaNs.
+    static HALF_CONSTEXPR_CONST bool has_signaling_NaN = true;
+
+    /// Supports subnormal values.
+    static HALF_CONSTEXPR_CONST float_denorm_style has_denorm = denorm_present;
+
+    /// Does not support denormalization loss detection.
+    static HALF_CONSTEXPR_CONST bool has_denorm_loss = false;
+
+#if HALF_ERRHANDLING_THROWS
+    static HALF_CONSTEXPR_CONST bool traps = true;
+#else
+    /// Traps only if [HALF_ERRHANDLING_THROW_...](\ref HALF_ERRHANDLING_THROW_INVALID) is
+    /// activated.
+    static HALF_CONSTEXPR_CONST bool traps = false;
+#endif
+
+    /// Does not support pre-rounding underflow detection.
+    static HALF_CONSTEXPR_CONST bool tinyness_before = false;
+
+    /// Rounding mode.
+    static HALF_CONSTEXPR_CONST float_round_style round_style = half_float::half::round_style;
+
+    /// Significant digits.
+    static HALF_CONSTEXPR_CONST int digits = 11;
+
+    /// Significant decimal digits.
+    static HALF_CONSTEXPR_CONST int digits10 = 3;
+
+    /// Required decimal digits to represent all possible values.
+    static HALF_CONSTEXPR_CONST int max_digits10 = 5;
+
+    /// Number base.
+    static HALF_CONSTEXPR_CONST int radix = 2;
+
+    /// One more than smallest exponent.
+    static HALF_CONSTEXPR_CONST int min_exponent = -13;
+
+    /// Smallest normalized representable power of 10.
+    static HALF_CONSTEXPR_CONST int min_exponent10 = -4;
+
+    /// One more than largest exponent.
+    static HALF_CONSTEXPR_CONST int max_exponent = 16;
+
+    /// Largest finitely representable power of 10.
+    static HALF_CONSTEXPR_CONST int max_exponent10 = 4;
+
+    /// Smallest positive normal value.
+    static HALF_CONSTEXPR half_float::half min() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x0400);
+    }
+
+    /// Smallest finite value.
+    static HALF_CONSTEXPR half_float::half lowest() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0xFBFF);
+    }
+
+    /// Largest finite value.
+    static HALF_CONSTEXPR half_float::half max() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x7BFF);
+    }
+
+    /// Difference between 1 and next representable value.
+    static HALF_CONSTEXPR half_float::half epsilon() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary, 0x1400);
+    }
+
+    /// Maximum rounding error in ULP (units in the last place).
+    static HALF_CONSTEXPR half_float::half round_error() HALF_NOTHROW
+    {
+        return half_float::half(half_float::detail::binary,
+                                (round_style == std::round_to_nearest) ? 0x3800 : 0x3C00);
+    }
+
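// With this specialization, standard numeric queries work for half; the key constants
// follow directly from the IEEE binary16 format (a sketch, decimal values approximate):
//   std::numeric_limits<half_float::half>::min()        ~= 6.10352e-05 (2^-14, bits 0x0400)
//   std::numeric_limits<half_float::half>::max()        == 65504       (bits 0x7BFF)
//   std::numeric_limits<half_float::half>::epsilon()    ~= 9.76563e-04 (2^-10, bits 0x1400)
//   std::numeric_limits<half_float::half>::denorm_min() ~= 5.96046e-08 (2^-24, bits 0x0001)

+    /// Positive infinity.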
+ static HALF_CONSTEXPR half_float::half infinity() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7C00); + } + + /// Quiet NaN. + static HALF_CONSTEXPR half_float::half quiet_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7FFF); + } + + /// Signaling NaN. + static HALF_CONSTEXPR half_float::half signaling_NaN() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x7DFF); + } + + /// Smallest positive subnormal value. + static HALF_CONSTEXPR half_float::half denorm_min() HALF_NOTHROW + { + return half_float::half(half_float::detail::binary, 0x0001); + } +}; + +#if HALF_ENABLE_CPP11_HASH +/// Hash function for half-precision floats. +/// This is only defined if C++11 `std::hash` is supported and enabled. +/// +/// **See also:** Documentation for [std::hash](https://en.cppreference.com/w/cpp/utility/hash) +template <> +struct hash +{ + /// Type of function argument. + typedef half_float::half argument_type; + + /// Function return type. + typedef size_t result_type; + + /// Compute hash function. + /// \param arg half to hash + /// \return hash value + result_type operator()(argument_type arg) const + { + return hash()(arg.data_ & + -static_cast(arg.data_ != 0x8000)); + } +}; +#endif +} + +namespace half_float { +/// \anchor compop +/// \name Comparison operators +/// \{ + +/// Comparison for equality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator==(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + (x.data_ == y.data_ || !((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for inequality. +/// \param x first operand +/// \param y second operand +/// \retval true if operands not equal +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator!=(half x, half y) +{ + return detail::compsignal(x.data_, y.data_) || + (x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF)); +} + +/// Comparison for less than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater than. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for less equal. 
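// The comparison operators here all reuse one mapping: the sign-magnitude half encoding
// is turned into an integer whose natural ordering matches the numeric ordering (NaNs are
// already filtered out via compsignal). A standalone sketch of that key, with a
// hypothetical helper name:
static int half_ordering_key(unsigned bits) // bits = raw binary16 encoding
{
    return static_cast<int>((bits ^ (0x8000 | (0x8000 - (bits >> 15)))) + (bits >> 15));
}
// e.g. half_ordering_key(0x3C00 /* 1.0 */)  >  half_ordering_key(0x0000 /* +0 */)
//      half_ordering_key(0x8000 /* -0 */)   == half_ordering_key(0x0000 /* +0 */)
//      half_ordering_key(0xBC00 /* -1.0 */) <  half_ordering_key(0x8000 /* -0 */)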
+/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator<=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// Comparison for greater equal. +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +/// \exception FE_INVALID if \a x or \a y is NaN +inline HALF_CONSTEXPR_NOERR bool operator>=(half x, half y) +{ + return !detail::compsignal(x.data_, y.data_) && + ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)); +} + +/// \} +/// \anchor arithmetics +/// \name Arithmetic operators +/// \{ + +/// Identity. +/// \param arg operand +/// \return unchanged operand +inline HALF_CONSTEXPR half operator+(half arg) { return arg; } + +/// Negation. +/// \param arg operand +/// \return negated operand +inline HALF_CONSTEXPR half operator-(half arg) { return half(detail::binary, arg.data_ ^ 0x8000); } + +/// Addition. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return sum of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with different signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator+(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) + + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF; + bool sub = ((x.data_ ^ y.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absy != 0x7C00) ? x.data_ : (sub && absx == 0x7C00) ? detail::invalid() + : y.data_); + if(!absx) + return absy ? y : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); + if(!absy) + return x; + unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; + if(absy > absx) + std::swap(absx, absy); + int exp = (absx >> 10) + (absx <= 0x3FF), d = exp - (absy >> 10) - (absy <= 0x3FF), + mx = ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << 3, my; + if(d < 13) + { + my = ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << 3; + my = (my >> d) | ((my & ((1 << d) - 1)) != 0); + } + else + my = 1; + if(sub) + { + if(!(mx -= my)) + return half(detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; mx < 0x2000 && exp > 1; mx <<= 1, --exp) + ; + } + else + { + mx += my; + int i = mx >> 14; + if((exp += i) > 30) + return half(detail::binary, detail::overflow(sign)); + mx = (mx >> i) | (mx & i); + } + return half(detail::binary, + detail::rounded( + sign + ((exp - 1) << 10) + (mx >> 3), (mx >> 2) & 1, (mx & 0x3) != 0)); +#endif +} + +/// Subtraction. +/// This operation is exact to rounding for all rounding modes. 
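// operator+ above carries three extra mantissa bits through the alignment and finishes
// with detail::rounded(value, guard, sticky). A minimal sketch of that final step for
// round-to-nearest-even, mirroring the library's behavior rather than its exact code:
static unsigned round_nearest_even(unsigned value, int guard, int sticky)
{
    // round up iff the guard bit is set and either a lower bit (sticky) or the LSB of
    // value is set; exact ties thus go to the even representation
    return value + (guard & (sticky | (value & 1)));
}
// e.g. round_nearest_even(0x3C00, 1, 0) == 0x3C00 (tie -> already even)
//      round_nearest_even(0x3C01, 1, 0) == 0x3C02 (tie -> round up to even)
//      round_nearest_even(0x3C01, 1, 1) == 0x3C02 (guard + sticky -> round up)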
+/// \param x left operand +/// \param y right operand +/// \return difference of half expressions +/// \exception FE_INVALID if \a x and \a y are infinities with equal signs or signaling NaNs +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator-(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) - + detail::half2float(y.data_))); +#else + return x + -y; +#endif +} + +/// Multiplication. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return product of half expressions +/// \exception FE_INVALID if multiplying 0 with infinity or if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator*(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) * + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -16; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : ((absx == 0x7C00 && !absy) || (absy == 0x7C00 && !absx)) + ? detail::invalid() + : (sign | 0x7C00)); + if(!absx || !absy) + return half(detail::binary, sign); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21, s = m & i; + exp += (absx >> 10) + (absy >> 10) + i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + return half( + detail::binary, + detail::fixed2half(m >> i, exp, sign, s)); +#endif +} + +/// Division. +/// This operation is exact to rounding for all rounding modes. +/// \param x left operand +/// \param y right operand +/// \return quotient of half expressions +/// \exception FE_INVALID if dividing 0s or infinities with each other or if \a x or \a y is +/// signaling NaN +/// \exception FE_DIVBYZERO if dividing finite value by 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half operator/(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half( + detail::binary, + detail::float2half(detail::half2float(x.data_) / + detail::half2float(y.data_))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = 14; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == absy) ? detail::invalid() + : (sign | ((absx == 0x7C00) ? 0x7C00 : 0))); + if(!absx) + return half(detail::binary, absy ? 
sign : detail::invalid()); + if(!absy) + return half(detail::binary, detail::pole(sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, ++exp) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + int i = mx < my; + exp += (absx >> 10) - (absy >> 10) - i; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -11) + return half(detail::binary, detail::underflow(sign)); + mx <<= 12 + i; + my <<= 1; + return half(detail::binary, + detail::fixed2half( + mx / my, exp, sign, mx % my != 0)); +#endif +} + +/// \} +/// \anchor streaming +/// \name Input and output +/// \{ + +/// Output operator. +/// This uses the built-in functionality for streaming out floating-point numbers. +/// \param out output stream to write into +/// \param arg half expression to write +/// \return reference to output stream +template +std::basic_ostream& operator<<(std::basic_ostream& out, half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return out << detail::half2float(arg.data_); +#else + return out << detail::half2float(arg.data_); +#endif +} + +/// Input operator. +/// This uses the built-in functionality for streaming in floating-point numbers, specifically +/// double precision floating +/// point numbers (unless overridden with [HALF_ARITHMETIC_TYPE](\ref HALF_ARITHMETIC_TYPE)). So the +/// input string is first +/// rounded to double precision using the underlying platform's current floating-point rounding mode +/// before being rounded +/// to half-precision using the library's half-precision rounding mode. +/// \param in input stream to read from +/// \param arg half to read into +/// \return reference to input stream +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template +std::basic_istream& operator>>(std::basic_istream& in, half& arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f; +#else + double f; +#endif + if(in >> f) + arg.data_ = detail::float2half(f); + return in; +} + +/// \} +/// \anchor basic +/// \name Basic mathematical operations +/// \{ + +/// Absolute value. +/// **See also:** Documentation for +/// [std::fabs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half fabs(half arg) { return half(detail::binary, arg.data_ & 0x7FFF); } + +/// Absolute value. +/// **See also:** Documentation for [std::abs](https://en.cppreference.com/w/cpp/numeric/math/fabs). +/// \param arg operand +/// \return absolute value of \a arg +inline HALF_CONSTEXPR half abs(half arg) { return fabs(arg); } + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::fmod](https://en.cppreference.com/w/cpp/numeric/math/fmod). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half fmod(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(!absx) + return x; + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign | detail::mod(absx, absy)); +} + +/// Remainder of division. 
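// fmod above never rounds: detail::mod works on the magnitudes and the sign of x is
// reattached verbatim, matching the C semantics. A short usage sketch:
static void fmod_demo()
{
    using half_float::half;
    half r1 = half_float::fmod(half(5.5f), half(2.0f));  // 1.5  (sign of x)
    half r2 = half_float::fmod(half(-5.5f), half(2.0f)); // -1.5 (sign of x)
    (void)r1;
    (void)r2;
}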
+/// **See also:** Documentation for +/// [std::remainder](https://en.cppreference.com/w/cpp/numeric/math/remainder). +/// \param x first operand +/// \param y second operand +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remainder(half x, half y) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, sign = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : x.data_); + if(!absy) + return half(detail::binary, detail::invalid()); + if(absx == absy) + return half(detail::binary, sign); + return half(detail::binary, sign ^ detail::mod(absx, absy)); +} + +/// Remainder of division. +/// **See also:** Documentation for +/// [std::remquo](https://en.cppreference.com/w/cpp/numeric/math/remquo). +/// \param x first operand +/// \param y second operand +/// \param quo address to store some bits of quotient at +/// \return remainder of floating-point division. +/// \exception FE_INVALID if \a x is infinite or \a y is 0 or if \a x or \a y is signaling NaN +inline half remquo(half x, half y, int* quo) +{ + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, value = x.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx > 0x7C00 || absy > 0x7C00) + ? detail::signal(x.data_, y.data_) + : (absx == 0x7C00) ? detail::invalid() : (*quo = 0, x.data_)); + if(!absy) + return half(detail::binary, detail::invalid()); + bool qsign = ((value ^ y.data_) & 0x8000) != 0; + int q = 1; + if(absx != absy) + value ^= detail::mod(absx, absy, &q); + return *quo = qsign ? -q : q, half(detail::binary, value); +} + +/// Fused multiply add. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::fma](https://en.cppreference.com/w/cpp/numeric/math/fma). +/// \param x first operand +/// \param y second operand +/// \param z third operand +/// \return ( \a x * \a y ) + \a z rounded as one operation. +/// \exception FE_INVALID according to operator*() and operator+() unless any argument is a quiet +/// NaN and no argument is a signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding the final addition +inline half fma(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); +#if HALF_ENABLE_CPP11_CMATH && FP_FAST_FMA + return half(detail::binary, detail::float2half(std::fma(fx, fy, fz))); +#else + return half(detail::binary, detail::float2half(fx * fy + fz)); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, exp = -15; + unsigned int sign = (x.data_ ^ y.data_) & 0x8000; + bool sub = ((sign ^ z.data_) & 0x8000) != 0; + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return (absx > 0x7C00 || absy > 0x7C00 || absz > 0x7C00) + ? half(detail::binary, detail::signal(x.data_, y.data_, z.data_)) + : (absx == 0x7C00) ? half(detail::binary, + (!absy || (sub && absz == 0x7C00)) ? detail::invalid() + : (sign | 0x7C00)) + : (absy == 0x7C00) ? half(detail::binary, + (!absx || (sub && absz == 0x7C00)) + ? detail::invalid() + : (sign | 0x7C00)) + : z; + if(!absx || !absy) + return absz ? 
z : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (z.data_ | sign) + : (z.data_ & sign)); + for(; absx < 0x400; absx <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + detail::uint32 m = static_cast((absx & 0x3FF) | 0x400) * + static_cast((absy & 0x3FF) | 0x400); + int i = m >> 21; + exp += (absx >> 10) + (absy >> 10) + i; + m <<= 3 - i; + if(absz) + { + int expz = 0; + for(; absz < 0x400; absz <<= 1, --expz) + ; + expz += absz >> 10; + detail::uint32 mz = static_cast((absz & 0x3FF) | 0x400) << 13; + if(expz > exp || (expz == exp && mz > m)) + { + std::swap(m, mz); + std::swap(exp, expz); + if(sub) + sign = z.data_ & 0x8000; + } + int d = exp - expz; + mz = (d < 23) ? ((mz >> d) | ((mz & ((static_cast(1) << d) - 1)) != 0)) : 1; + if(sub) + { + m = m - mz; + if(!m) + return half( + detail::binary, + static_cast(half::round_style == std::round_toward_neg_infinity) + << 15); + for(; m < 0x800000; m <<= 1, --exp) + ; + } + else + { + m += mz; + i = m >> 24; + m = (m >> i) | (m & i); + exp += i; + } + } + if(exp > 30) + return half(detail::binary, detail::overflow(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp - 1, sign)); +#endif +} + +/// Maximum of half expressions. +/// **See also:** Documentation for +/// [std::fmax](https://en.cppreference.com/w/cpp/numeric/math/fmax). +/// \param x first operand +/// \param y second operand +/// \return maximum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || + (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Minimum of half expressions. +/// **See also:** Documentation for +/// [std::fmin](https://en.cppreference.com/w/cpp/numeric/math/fmin). +/// \param x first operand +/// \param y second operand +/// \return minimum of operands, ignoring quiet NaNs +/// \exception FE_INVALID if \a x or \a y is signaling NaN +inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) +{ + return half(detail::binary, + (!isnan(y) && (isnan(x) || + (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + ? detail::select(y.data_, x.data_) + : detail::select(x.data_, y.data_)); +} + +/// Positive difference. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::fdim](https://en.cppreference.com/w/cpp/numeric/math/fdim). +/// \param x first operand +/// \param y second operand +/// \return \a x - \a y or 0 if difference negative +/// \exception FE_... according to operator-(half,half) +inline half fdim(half x, half y) +{ + if(isnan(x) || isnan(y)) + return half(detail::binary, detail::signal(x.data_, y.data_)); + return (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) <= + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + ? half(detail::binary, 0) + : (x - y); +} + +/// Get NaN value. +/// **See also:** Documentation for [std::nan](https://en.cppreference.com/w/cpp/numeric/math/nan). 
+/// \param arg string code +/// \return quiet NaN +inline half nanh(const char* arg) +{ + unsigned int value = 0x7FFF; + while(*arg) + value ^= static_cast(*arg++) & 0xFF; + return half(detail::binary, value); +} + +/// \} +/// \anchor exponential +/// \name Exponential functions +/// \{ + +/// Exponential function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::exp](https://en.cppreference.com/w/cpp/numeric/math/exp). +/// \param arg function argument +/// \return e raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::exp(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4C80) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(m, 26), exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Binary exponential. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::exp2](https://en.cppreference.com/w/cpp/numeric/math/exp2). +/// \param arg function argument +/// \return 2 raised to \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half exp2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::exp2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 & ((arg.data_ >> 15) - 1U)) + : detail::signal(arg.data_)); + if(abs >= 0x4E40) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::underflow() + : detail::overflow()); + int e = (abs >> 10) + (abs <= 0x3FF), exp = (abs & 0x3FF) + ((abs > 0x3FF) << 10); + detail::uint32 m = detail::exp2((static_cast(exp) << (6 + e)) & 0x7FFFFFFF, 28); + exp >>= 25 - e; + if(m == 0x80000000) + { + if(arg.data_ & 0x8000) + exp = -exp; + else if(exp > 15) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp + 14)); + } + return half(detail::binary, + detail::exp2_post(m, exp, (arg.data_ & 0x8000) != 0)); +#endif +} + +/// Exponential minus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in <1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::expm1](https://en.cppreference.com/w/cpp/numeric/math/expm1). 
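// The constant 0xB8AA3B29 used by exp() above and expm1() below is log2(e) in 1.31 fixed
// point: multiply64(m, 0xB8AA3B29) rewrites e^x as 2^(x*log2(e)), so one exp2 core serves
// both functions. A quick standalone check of the constant (a sketch):
#include <cstdint>
static bool check_log2e_constant()
{
    const double log2e = 1.4426950408889634; // log2(e) = 1/ln(2)
    return static_cast<std::uint32_t>(log2e * 2147483648.0 + 0.5) == 0xB8AA3B29u; // * 2^31
}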
+/// \param arg function argument +/// \return e raised to \a arg and subtracted by 1 +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half expm1(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::expm1(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? (0x7C00 + (sign >> 1)) : detail::signal(arg.data_)); + if(abs >= 0x4A00) + return half(detail::binary, + (arg.data_ & 0x8000) ? detail::rounded(0xBBFF, 1, 1) + : detail::overflow()); + detail::uint32 m = detail::multiply64( + static_cast((abs & 0x3FF) + ((abs > 0x3FF) << 10)) << 21, 0xB8AA3B29); + int e = (abs >> 10) + (abs <= 0x3FF), exp; + if(e < 14) + { + exp = 0; + m >>= 14 - e; + } + else + { + exp = m >> (45 - e); + m = (m << (e - 14)) & 0x7FFFFFFF; + } + m = detail::exp2(m); + if(sign) + { + int s = 0; + if(m > 0x80000000) + { + ++exp; + m = detail::divide64(0x80000000, m, s); + } + m = 0x80000000 - + ((m >> exp) | ((m & ((static_cast(1) << exp) - 1)) != 0) | s); + exp = 0; + } + else + m -= (exp < 31) ? (0x80000000 >> exp) : 1; + for(exp += 14; m < 0x80000000 && exp; m <<= 1, --exp) + ; + if(exp > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::rounded( + sign + (exp << 10) + (m >> 21), (m >> 20) & 1, (m & 0xFFFFF) != 0)); +#endif +} + +/// Natural logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::log](https://en.cppreference.com/w/cpp/numeric/math/log). +/// \param arg function argument +/// \return logarithm of \a arg to base e +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 17)); +#endif +} + +/// Common logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log10](https://en.cppreference.com/w/cpp/numeric/math/log10). +/// \param arg function argument +/// \return logarithm of \a arg to base 10 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log10(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::log10(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? 
detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + switch(abs) + { + case 0x4900: return half(detail::binary, 0x3C00); + case 0x5640: return half(detail::binary, 0x4000); + case 0x63D0: return half(detail::binary, 0x4200); + case 0x70E2: return half(detail::binary, 0x4400); + } + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + return half(detail::binary, + detail::log2_post( + detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, 27) + 8, + exp, + 16)); +#endif +} + +/// Binary logarithm. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::log2](https://en.cppreference.com/w/cpp/numeric/math/log2). +/// \param arg function argument +/// \return logarithm of \a arg to base 2 +/// \exception FE_INVALID for signaling NaN or negative argument +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log2(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log2(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15, s = 0; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(arg.data_ & 0x8000) + return half(detail::binary, + (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs >= 0x7C00) + return (abs == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += (abs >> 10); + if(!(abs & 0x3FF)) + { + unsigned int value = static_cast(exp < 0) << 15, m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, value + (exp << 10) + m); + } + detail::uint32 ilog = exp, sign = detail::sign_mask(ilog), + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 28) >> + 4)) ^ + sign) - + sign; + if(!m) + return half(detail::binary, 0); + for(exp = 14; m < 0x8000000 && exp; m <<= 1, --exp) + ; + for(; m > 0xFFFFFFF; m >>= 1, ++exp) + s |= m & 1; + return half( + detail::binary, + detail::fixed2half(m, exp, sign & 0x8000, s)); +#endif +} + +/// Natural logarithm plus one. +/// This function may be 1 ULP off the correctly rounded exact result in <0.05% of inputs for +/// `std::round_to_nearest` +/// and in ~1% of inputs for any other rounding mode. +/// +/// **See also:** Documentation for +/// [std::log1p](https://en.cppreference.com/w/cpp/numeric/math/log1p). +/// \param arg function argument +/// \return logarithm of \a arg plus 1 to base e +/// \exception FE_INVALID for signaling NaN or argument <-1 +/// \exception FE_DIVBYZERO for -1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half log1p(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::log1p(detail::half2float(arg.data_)))); +#else + if(arg.data_ >= 0xBC00) + return half(detail::binary, + (arg.data_ == 0xBC00) ? detail::pole(0x8000) : (arg.data_ <= 0xFC00) + ? detail::invalid() + : detail::signal(arg.data_)); + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + detail::uint32 m = static_cast((abs & 0x3FF) | 0x400) << 20; + if(arg.data_ & 0x8000) + { + m = 0x40000000 - (m >> -exp); + for(exp = 0; m < 0x40000000; m <<= 1, --exp) + ; + } + else + { + if(exp < 0) + { + m = 0x40000000 + (m >> -exp); + exp = 0; + } + else + { + m += 0x40000000 >> exp; + int i = m >> 31; + m >>= i; + exp += i; + } + } + return half(detail::binary, + detail::log2_post(detail::log2(m), exp, 17)); +#endif +} + +/// \} +/// \anchor power +/// \name Power functions +/// \{ + +/// Square root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sqrt](https://en.cppreference.com/w/cpp/numeric/math/sqrt). +/// \param arg function argument +/// \return square root of \a arg +/// \exception FE_INVALID for signaling NaN and negative arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sqrt(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sqrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 15; + if(!abs || arg.data_ >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) + : (arg.data_ > 0x8000) ? detail::invalid() : arg.data_); + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 r = static_cast((abs & 0x3FF) | 0x400) << 10, + m = detail::sqrt<20>(r, exp += abs >> 10); + return half( + detail::binary, + detail::rounded((exp << 10) + (m & 0x3FF), r > m, r != 0)); +#endif +} + +/// Cubic root. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cbrt](https://en.cppreference.com/w/cpp/numeric/math/cbrt). +/// \param arg function argument +/// \return cubic root of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cbrt(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::cbrt(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = -15; + if(!abs || abs == 0x3C00 || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + detail::uint32 ilog = exp + (abs >> 10), sign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + (detail::log2(static_cast((abs & 0x3FF) | 0x400) << 20, + 24) >> + 4)) ^ + sign) - + sign; + for(exp = 2; m < 0x80000000; m <<= 1, --exp) + ; + m = detail::multiply64(m, 0xAAAAAAAB); + int i = m >> 31, s; + exp += i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + m = detail::exp2(f, (half::round_style == std::round_to_nearest) ? 29 : 26); + if(sign) + { + if(m > 0x80000000) + { + m = detail::divide64(0x80000000, m, s); + ++exp; + } + exp = -exp; + } + return half(detail::binary, + (half::round_style == std::round_to_nearest) + ? detail::fixed2half( + m, exp + 14, arg.data_ & 0x8000) + : detail::fixed2half( + (m + 0x80) >> 8, exp + 14, arg.data_ & 0x8000)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). 
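// The scaled fixed-point path in hypot below exists so that the intermediate squares
// cannot overflow even though x*x would in half precision (max finite value 65504).
// A usage sketch:
static void hypot_demo()
{
    using half_float::half;
    half x(300.0f), y(400.0f);        // x * x == 90000 would overflow to +inf in half
    half h = half_float::hypot(x, y); // == 500, computed without intermediate overflow
    (void)h;
}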
+/// \param x first argument +/// \param y second argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_); +#if HALF_ENABLE_CPP11_CMATH + return half(detail::binary, detail::float2half(std::hypot(fx, fy))); +#else + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy))); +#endif +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, expx = 0, expy = 0; + if(absx >= 0x7C00 || absy >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) ? detail::select(0x7C00, y.data_) + : (absy == 0x7C00) ? detail::select(0x7C00, x.data_) + : detail::signal(x.data_, y.data_)); + if(!absx) + return half(detail::binary, absy ? detail::check_underflow(absy) : 0); + if(!absy) + return half(detail::binary, detail::check_underflow(absx)); + if(absy > absx) + std::swap(absx, absy); + for(; absx < 0x400; absx <<= 1, --expx) + ; + for(; absy < 0x400; absy <<= 1, --expy) + ; + detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400; + mx *= mx; + my *= my; + int ix = mx >> 21, iy = my >> 21; + expx = 2 * (expx + (absx >> 10)) - 15 + ix; + expy = 2 * (expy + (absy >> 10)) - 15 + iy; + mx <<= 10 - ix; + my <<= 10 - iy; + int d = expx - expy; + my = (d < 30) ? ((my >> d) | ((my & ((static_cast(1) << d) - 1)) != 0)) : 1; + return half(detail::binary, detail::hypot_post(mx + my, expx)); +#endif +} + +/// Hypotenuse function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::hypot](https://en.cppreference.com/w/cpp/numeric/math/hypot). +/// \param x first argument +/// \param y second argument +/// \param z third argument +/// \return square root of sum of squares without internal over- or underflows +/// \exception FE_INVALID if \a x, \a y or \a z is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding of the final square root +inline half hypot(half x, half y, half z) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t fx = detail::half2float(x.data_), + fy = detail::half2float(y.data_), + fz = detail::half2float(z.data_); + return half(detail::binary, + detail::float2half(std::sqrt(fx * fx + fy * fy + fz * fz))); +#else + int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, absz = z.data_ & 0x7FFF, expx = 0, + expy = 0, expz = 0; + if(!absx) + return hypot(y, z); + if(!absy) + return hypot(x, z); + if(!absz) + return hypot(x, y); + if(absx >= 0x7C00 || absy >= 0x7C00 || absz >= 0x7C00) + return half(detail::binary, + (absx == 0x7C00) + ? detail::select(0x7C00, detail::select(y.data_, z.data_)) + : (absy == 0x7C00) + ? detail::select(0x7C00, detail::select(x.data_, z.data_)) + : (absz == 0x7C00) + ? 
detail::select(0x7C00, detail::select(x.data_, y.data_))
+                                  : detail::signal(x.data_, y.data_, z.data_));
+    if(absz > absy)
+        std::swap(absy, absz);
+    if(absy > absx)
+        std::swap(absx, absy);
+    if(absz > absy)
+        std::swap(absy, absz);
+    for(; absx < 0x400; absx <<= 1, --expx)
+        ;
+    for(; absy < 0x400; absy <<= 1, --expy)
+        ;
+    for(; absz < 0x400; absz <<= 1, --expz)
+        ;
+    detail::uint32 mx = (absx & 0x3FF) | 0x400, my = (absy & 0x3FF) | 0x400,
+                   mz = (absz & 0x3FF) | 0x400;
+    mx *= mx;
+    my *= my;
+    mz *= mz;
+    int ix = mx >> 21, iy = my >> 21, iz = mz >> 21;
+    expx = 2 * (expx + (absx >> 10)) - 15 + ix;
+    expy = 2 * (expy + (absy >> 10)) - 15 + iy;
+    expz = 2 * (expz + (absz >> 10)) - 15 + iz;
+    mx <<= 10 - ix;
+    my <<= 10 - iy;
+    mz <<= 10 - iz;
+    int d = expy - expz;
+    mz = (d < 30) ? ((mz >> d) | ((mz & ((static_cast<detail::uint32>(1) << d) - 1)) != 0)) : 1;
+    my += mz;
+    if(my & 0x80000000)
+    {
+        my = (my >> 1) | (my & 1);
+        if(++expy > expx)
+        {
+            std::swap(mx, my);
+            std::swap(expx, expy);
+        }
+    }
+    d = expx - expy;
+    my = (d < 30) ? ((my >> d) | ((my & ((static_cast<detail::uint32>(1) << d) - 1)) != 0)) : 1;
+    return half(detail::binary, detail::hypot_post(mx + my, expx));
+#endif
+}
+
+/// Power function.
+/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in
+/// ~0.00025% of inputs.
+///
+/// **See also:** Documentation for [std::pow](https://en.cppreference.com/w/cpp/numeric/math/pow).
+/// \param x base
+/// \param y exponent
+/// \return \a x raised to \a y
+/// \exception FE_INVALID if \a x or \a y is signaling NaN or if \a x is finite and negative and
+/// \a y is finite and not integral
+/// \exception FE_DIVBYZERO if \a x is 0 and \a y is negative
+/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding
+inline half pow(half x, half y)
+{
+#ifdef HALF_ARITHMETIC_TYPE
+    return half(detail::binary,
+                detail::float2half(
+                    std::pow(detail::half2float(x.data_),
+                             detail::half2float(y.data_))));
+#else
+    int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, exp = -15;
+    if(!absy || x.data_ == 0x3C00)
+        return half(detail::binary,
+                    detail::select(0x3C00, (x.data_ == 0x3C00) ? y.data_ : x.data_));
+    bool is_int = absy >= 0x6400 || (absy >= 0x3C00 && !(absy & ((1 << (25 - (absy >> 10))) - 1)));
+    unsigned int sign =
+        x.data_ &
+        (static_cast<unsigned>((absy < 0x6800) && is_int && ((absy >> (25 - (absy >> 10))) & 1))
+         << 15);
+    if(absx >= 0x7C00 || absy >= 0x7C00)
+        return half(detail::binary,
+                    (absx > 0x7C00 || absy > 0x7C00)
+                        ? detail::signal(x.data_, y.data_)
+                        : (absy == 0x7C00)
+                              ? ((absx == 0x3C00)
+                                     ? 0x3C00
+                                     : (!absx && y.data_ == 0xFC00)
+                                           ? detail::pole()
+                                           : (0x7C00 & -((y.data_ >> 15) ^ (absx > 0x3C00))))
+                              : (sign | (0x7C00 & ((y.data_ >> 15) - 1U))));
+    if(!absx)
+        return half(detail::binary, (y.data_ & 0x8000) ?
detail::pole(sign) : sign); + if((x.data_ & 0x8000) && !is_int) + return half(detail::binary, detail::invalid()); + if(x.data_ == 0xBC00) + return half(detail::binary, sign | 0x3C00); + if(y.data_ == 0x3800) + return sqrt(x); + if(y.data_ == 0x3C00) + return half(detail::binary, detail::check_underflow(x.data_)); + if(y.data_ == 0x4000) + return x * x; + for(; absx < 0x400; absx <<= 1, --exp) + ; + detail::uint32 ilog = exp + (absx >> 10), msign = detail::sign_mask(ilog), f, + m = (((ilog << 27) + + ((detail::log2(static_cast((absx & 0x3FF) | 0x400) << 20) + + 8) >> + 4)) ^ + msign) - + msign; + for(exp = -11; m < 0x80000000; m <<= 1, --exp) + ; + for(; absy < 0x400; absy <<= 1, --exp) + ; + m = detail::multiply64(m, static_cast((absy & 0x3FF) | 0x400) << 21); + int i = m >> 31; + exp += (absy >> 10) + i; + m <<= 1 - i; + if(exp < 0) + { + f = m >> -exp; + exp = 0; + } + else + { + f = (m << exp) & 0x7FFFFFFF; + exp = m >> (31 - exp); + } + return half(detail::binary, + detail::exp2_post( + detail::exp2(f), exp, ((msign & 1) ^ (y.data_ >> 15)) != 0, sign)); +#endif +} + +/// \} +/// \anchor trigonometric +/// \name Trigonometric functions +/// \{ + +/// Compute sine and cosine simultaneously. +/// This returns the same results as sin() and cos() but is faster than calling each function +/// individually. +/// +/// This function is exact to rounding for all rounding modes. +/// \param arg function argument +/// \param sin variable to take sine of \a arg +/// \param cos variable to take cosine of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline void sincos(half arg, half* sin, half* cos) +{ +#ifdef HALF_ARITHMETIC_TYPE + detail::internal_t f = detail::half2float(arg.data_); + *sin = half(detail::binary, detail::float2half(std::sin(f))); + *cos = half(detail::binary, detail::float2half(std::cos(f))); +#else + int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15, k; + if(abs >= 0x7C00) + *sin = *cos = + half(detail::binary, (abs == 0x7C00) ? 
detail::invalid() : detail::signal(arg.data_)); + else if(!abs) + { + *sin = arg; + *cos = half(detail::binary, 0x3C00); + } + else if(abs < 0x2500) + { + *sin = half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + } + else + { + if(half::round_style != std::round_to_nearest) + { + switch(abs) + { + case 0x48B7: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + *cos = half(detail::binary, detail::rounded(0xBBFF, 1, 1)); + return; + case 0x598C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x80FC, 1, 1)); + return; + case 0x6A64: + *sin = half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x27FF, 1, 1)); + return; + case 0x6D8C: + *sin = half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + *cos = half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + return; + } + } + std::pair sc = + detail::sincos(detail::angle_arg(abs, k), 28); + switch(k & 3) + { + case 1: sc = std::make_pair(sc.second, -sc.first); break; + case 2: sc = std::make_pair(-sc.first, -sc.second); break; + case 3: sc = std::make_pair(-sc.second, sc.first); break; + } + *sin = half(detail::binary, + detail::fixed2half( + (sc.first ^ -static_cast(sign)) + sign)); + *cos = half(detail::binary, + detail::fixed2half(sc.second)); + } +#endif +} + +/// Sine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::sin](https://en.cppreference.com/w/cpp/numeric/math/sin). +/// \param arg function argument +/// \return sine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sin(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x48B7: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x1D07, 1, 1)); + case 0x6A64: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x3BFE, 1, 1)); + case 0x6D8C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x0FE6, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) & 1) ^ (arg.data_ >> 15)); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.second : sc.first) ^ sign) - sign)); +#endif +} + +/// Cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::cos](https://en.cppreference.com/w/cpp/numeric/math/cos). 
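As a quick orientation for the trigonometric functions in this section, a minimal caller sketch follows; the include path and the main() harness are illustrative assumptions (the header is vendored under external/half/include), not part of this patch:

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half x(0.5f), s, c;
    half_float::sincos(x, &s, &c); // one argument reduction for both results
    // sincos() is documented to return the same values as the separate overloads:
    std::printf("sin(0.5) = %g, cos(0.5) = %g, consistent = %d\n",
                float(s), float(c),
                s == half_float::sin(x) && c == half_float::cos(x));
}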
+/// \param arg function argument +/// \return cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cos(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, k; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2500) + return half(detail::binary, detail::rounded(0x3BFF, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x598C) + return half(detail::binary, detail::rounded(0x80FC, 1, 1)); + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + return half(detail::binary, + detail::fixed2half( + (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); +#endif +} + +/// Tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for [std::tan](https://en.cppreference.com/w/cpp/numeric/math/tan). +/// \param arg function argument +/// \return tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or infinity +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tan(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 13, k; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x658C: + return half( + detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x07E6, 1, 1)); + case 0x7330: + return half( + detail::binary, + detail::rounded((~arg.data_ & 0x8000) | 0x4B62, 1, 1)); + } + std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); + if(k & 1) + sc = std::make_pair(-sc.second, sc.first); + detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); + detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; + for(; my < 0x80000000; my <<= 1, --exp) + ; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + return half( + detail::binary, + detail::tangent_post(my, mx, exp, (signy ^ signx ^ arg.data_) & 0x8000)); +#endif +} + +/// Arc sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asin](https://en.cppreference.com/w/cpp/numeric/math/asin). +/// \param arg function argument +/// \return arc sine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asin(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::asin(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? 
detail::invalid() + : detail::rounded(sign | 0x3E48, 0, 1)); + if(abs < 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + if(half::round_style != std::round_to_nearest && (abs == 0x2B44 || abs == 0x2DC3)) + return half(detail::binary, detail::rounded(arg.data_ + 1, 1, 1)); + std::pair sc = detail::atan2_args(abs); + detail::uint32 m = + detail::atan2(sc.first, sc.second, (half::round_style == std::round_to_nearest) ? 27 : 26); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc cosine function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acos](https://en.cppreference.com/w/cpp/numeric/math/acos). +/// \param arg function argument +/// \return arc cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acos(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::acos(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ >> 15; + if(!abs) + return half(detail::binary, detail::rounded(0x3E48, 0, 1)); + if(abs >= 0x3C00) + return half(detail::binary, + (abs > 0x7C00) + ? detail::signal(arg.data_) + : (abs > 0x3C00) + ? detail::invalid() + : sign ? detail::rounded(0x4248, 0, 1) : 0); + std::pair cs = detail::atan2_args(abs); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + return half(detail::binary, + detail::fixed2half( + sign ? (0xC90FDAA2 - m) : m, 15, 0, sign)); +#endif +} + +/// Arc tangent function. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atan](https://en.cppreference.com/w/cpp/numeric/math/atan). +/// \param arg function argument +/// \return arc tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan(detail::half2float(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs == 0x7C00) ? detail::rounded(sign | 0x3E48, 0, 1) + : detail::signal(arg.data_)); + if(abs <= 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + int exp = (abs >> 10) + (abs <= 0x3FF); + detail::uint32 my = (abs & 0x3FF) | ((abs > 0x3FF) << 10); + detail::uint32 m = (exp > 15) + ? detail::atan2(my << 19, + 0x20000000 >> (exp - 15), + (half::round_style == std::round_to_nearest) ? 26 : 24) + : detail::atan2(my << (exp + 4), + 0x20000000, + (half::round_style == std::round_to_nearest) ? 30 : 28); + return half(detail::binary, + detail::fixed2half(m, 14, sign)); +#endif +} + +/// Arc tangent function. +/// This function may be 1 ULP off the correctly rounded exact result in ~0.005% of inputs for +/// `std::round_to_nearest`, +/// in ~0.1% of inputs for `std::round_toward_zero` and in ~0.02% of inputs for any other rounding +/// mode. +/// +/// **See also:** Documentation for +/// [std::atan2](https://en.cppreference.com/w/cpp/numeric/math/atan2). 
+/// \param y numerator +/// \param x denominator +/// \return arc tangent value +/// \exception FE_INVALID if \a x or \a y is signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atan2(half y, half x) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::atan2(detail::half2float(y.data_), + detail::half2float(x.data_)))); +#else + unsigned int absx = x.data_ & 0x7FFF, absy = y.data_ & 0x7FFF, signx = x.data_ >> 15, + signy = y.data_ & 0x8000; + if(absx >= 0x7C00 || absy >= 0x7C00) + { + if(absx > 0x7C00 || absy > 0x7C00) + return half(detail::binary, detail::signal(x.data_, y.data_)); + if(absy == 0x7C00) + return half(detail::binary, + (absx < 0x7C00) + ? detail::rounded(signy | 0x3E48, 0, 1) + : signx + ? detail::rounded(signy | 0x40B6, 0, 1) + : detail::rounded(signy | 0x3A48, 0, 1)); + return (x.data_ == 0x7C00) + ? half(detail::binary, signy) + : half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)); + } + if(!absy) + return signx ? half(detail::binary, + detail::rounded(signy | 0x4248, 0, 1)) + : y; + if(!absx) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + int d = (absy >> 10) + (absy <= 0x3FF) - (absx >> 10) - (absx <= 0x3FF); + if(d > (signx ? 18 : 12)) + return half(detail::binary, detail::rounded(signy | 0x3E48, 0, 1)); + if(signx && d < -11) + return half(detail::binary, detail::rounded(signy | 0x4248, 0, 1)); + if(!signx && d < ((half::round_style == std::round_toward_zero) ? -15 : -9)) + { + for(; absy < 0x400; absy <<= 1, --d) + ; + detail::uint32 mx = ((absx << 1) & 0x7FF) | 0x800, my = ((absy << 1) & 0x7FF) | 0x800; + int i = my < mx; + d -= i; + if(d < -25) + return half(detail::binary, detail::underflow(signy)); + my <<= 11 + i; + return half(detail::binary, + detail::fixed2half( + my / mx, d + 14, signy, my % mx != 0)); + } + detail::uint32 m = detail::atan2( + ((absy & 0x3FF) | ((absy > 0x3FF) << 10)) << (19 + ((d < 0) ? d : (d > 0) ? 0 : -1)), + ((absx & 0x3FF) | ((absx > 0x3FF) << 10)) << (19 - ((d > 0) ? d : (d < 0) ? 0 : 1))); + return half(detail::binary, + detail::fixed2half( + signx ? (0xC90FDAA2 - m) : m, 15, signy, signx)); +#endif +} + +/// \} +/// \anchor hyperbolic +/// \name Hyperbolic functions +/// \{ + +/// Hyperbolic sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::sinh](https://en.cppreference.com/w/cpp/numeric/math/sinh). +/// \param arg function argument +/// \return hyperbolic sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half sinh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::sinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 29 : 27); + detail::uint32 m = mm.first - mm.second; + for(exp += 13; m < 0x80000000 && exp; m <<= 1, --exp) + ; + unsigned int sign = arg.data_ & 0x8000; + if(exp > 29) + return half(detail::binary, detail::overflow(sign)); + return half(detail::binary, + detail::fixed2half(m, exp, sign)); +#endif +} + +/// Hyperbolic cosine. 
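A small sketch of the quadrant handling that distinguishes atan2() above from a plain atan() quotient (illustrative only; the include path is an assumption):

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    // atan2 keeps the quadrant information that atan(y/x) loses:
    half q = half_float::atan2(half(1.0f), half(-1.0f)); // ~ 3*pi/4 (second quadrant)
    half a = half_float::atan(half(1.0f) / half(-1.0f)); // atan(-1) ~ -pi/4
    std::printf("atan2 = %g, atan = %g\n", float(q), float(a));
}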
+/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::cosh](https://en.cppreference.com/w/cpp/numeric/math/cosh). +/// \param arg function argument +/// \return hyperbolic cosine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half cosh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::cosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x7C00) + return half(detail::binary, (abs > 0x7C00) ? detail::signal(arg.data_) : 0x7C00); + std::pair mm = + detail::hyperbolic_args(abs, exp, (half::round_style == std::round_to_nearest) ? 23 : 26); + detail::uint32 m = mm.first + mm.second, i = (~m & 0xFFFFFFFF) >> 31; + m = (m >> i) | (m & i) | 0x80000000; + if((exp += 13 + i) > 29) + return half(detail::binary, detail::overflow()); + return half(detail::binary, + detail::fixed2half(m, exp)); +#endif +} + +/// Hyperbolic tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::tanh](https://en.cppreference.com/w/cpp/numeric/math/tanh). +/// \param arg function argument +/// \return hyperbolic tangent value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tanh(half arg) +{ +#ifdef HALF_ARITHMETIC_TYPE + return half(detail::binary, + detail::float2half( + std::tanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return arg; + if(abs >= 0x7C00) + return half(detail::binary, + (abs > 0x7C00) ? detail::signal(arg.data_) : (arg.data_ - 0x4000)); + if(abs >= 0x4500) + return half(detail::binary, + detail::rounded((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest && abs == 0x2D3F) + return half(detail::binary, detail::rounded(arg.data_ - 3, 0, 1)); + std::pair mm = detail::hyperbolic_args(abs, exp, 27); + detail::uint32 my = mm.first - mm.second - (half::round_style != std::round_to_nearest), + mx = mm.first + mm.second, i = (~mx & 0xFFFFFFFF) >> 31; + for(exp = 13; my < 0x80000000; my <<= 1, --exp) + ; + mx = (mx >> i) | 0x80000000; + return half(detail::binary, + detail::tangent_post(my, mx, exp - i, arg.data_ & 0x8000)); +#endif +} + +/// Hyperbolic area sine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::asinh](https://en.cppreference.com/w/cpp/numeric/math/asinh). +/// \param arg function argument +/// \return area sine value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half asinh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::asinh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs > 0x7C00) ? 
half(detail::binary, detail::signal(arg.data_)) : arg; + if(abs <= 0x2900) + return half(detail::binary, detail::rounded(arg.data_ - 1, 1, 1)); + if(half::round_style != std::round_to_nearest) + switch(abs) + { + case 0x32D4: + return half(detail::binary, + detail::rounded(arg.data_ - 13, 1, 1)); + case 0x3B5B: + return half(detail::binary, + detail::rounded(arg.data_ - 197, 1, 1)); + } + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area cosine. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::acosh](https://en.cppreference.com/w/cpp/numeric/math/acosh). +/// \param arg function argument +/// \return area cosine value of \a arg +/// \exception FE_INVALID for signaling NaN or arguments <1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half acosh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::acosh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if((arg.data_ & 0x8000) || abs < 0x3C00) + return half(detail::binary, + (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs == 0x3C00) + return half(detail::binary, 0); + if(arg.data_ >= 0x7C00) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + return half(detail::binary, detail::area(arg.data_)); +#endif +} + +/// Hyperbolic area tangent. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::atanh](https://en.cppreference.com/w/cpp/numeric/math/atanh). +/// \param arg function argument +/// \return area tangent value of \a arg +/// \exception FE_INVALID for signaling NaN or if abs(\a arg) > 1 +/// \exception FE_DIVBYZERO for +/-1 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half atanh(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half( + std::atanh(detail::half2float(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF, exp = 0; + if(!abs) + return arg; + if(abs >= 0x3C00) + return half(detail::binary, + (abs == 0x3C00) + ? detail::pole(arg.data_ & 0x8000) + : (abs <= 0x7C00) ? detail::invalid() : detail::signal(arg.data_)); + if(abs < 0x2700) + return half(detail::binary, detail::rounded(arg.data_, 0, 1)); + detail::uint32 m = static_cast((abs & 0x3FF) | ((abs > 0x3FF) << 10)) + << ((abs >> 10) + (abs <= 0x3FF) + 6), + my = 0x80000000 + m, mx = 0x80000000 - m; + for(; mx < 0x80000000; mx <<= 1, ++exp) + ; + int i = my >= mx, s; + return half(detail::binary, + detail::log2_post( + detail::log2((detail::divide64(my >> i, mx, s) + 1) >> 1, 27) + 0x10, + exp + i - 1, + 16, + arg.data_ & 0x8000)); +#endif +} + +/// \} +/// \anchor special +/// \name Error and gamma functions +/// \{ + +/// Error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for [std::erf](https://en.cppreference.com/w/cpp/numeric/math/erf). 
+/// \param arg function argument +/// \return error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erf(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half<half::round_style>( + std::erf(detail::half2float<detail::internal_t>(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs || abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, + (abs == 0x7C00) ? (arg.data_ - 0x4000) : detail::signal(arg.data_)) + : arg; + if(abs >= 0x4200) + return half(detail::binary, + detail::rounded<half::round_style, true>((arg.data_ & 0x8000) | 0x3BFF, 1, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Complementary error function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in <0.5% +/// of inputs. +/// +/// **See also:** Documentation for +/// [std::erfc](https://en.cppreference.com/w/cpp/numeric/math/erfc). +/// \param arg function argument +/// \return 1 minus error function value of \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half erfc(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half<half::round_style>( + std::erfc(detail::half2float<detail::internal_t>(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00) + return (abs >= 0x7C00) + ? half(detail::binary, (abs == 0x7C00) ? (sign >> 1) : detail::signal(arg.data_)) + : arg; + if(!abs) + return half(detail::binary, 0x3C00); + if(abs >= 0x4400) + return half( + detail::binary, + detail::rounded<half::round_style, true>((sign >> 1) - (sign >> 15), sign >> 15, 1)); + return half(detail::binary, detail::erf(arg.data_)); +#endif +} + +/// Natural logarithm of gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// ~0.025% of inputs. +/// +/// **See also:** Documentation for +/// [std::lgamma](https://en.cppreference.com/w/cpp/numeric/math/lgamma). +/// \param arg function argument +/// \return natural logarithm of gamma function for \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 or negative integer arguments +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half lgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half<half::round_style>( + std::lgamma(detail::half2float<detail::internal_t>(arg.data_)))); +#else + int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + if(!abs || arg.data_ >= 0xE400 || + (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::pole()); + if(arg.data_ == 0x3C00 || arg.data_ == 0x4000) + return half(detail::binary, 0); + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// Gamma function. +/// This function may be 1 ULP off the correctly rounded exact result for any rounding mode in +/// <0.25% of inputs. +/// +/// **See also:** Documentation for +/// [std::tgamma](https://en.cppreference.com/w/cpp/numeric/math/tgamma). 
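The factorial identity gives a quick sanity check for the tgamma()/lgamma() pair above; a sketch (include path is an assumption):

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half g = half_float::tgamma(half(5.0f)); // gamma(5) = 4! = 24, exactly representable
    half l = half_float::lgamma(half(5.0f)); // log(24) ~ 3.178
    std::printf("tgamma(5) = %g, exp(lgamma(5)) = %g\n",
                float(g), float(half_float::exp(l)));
}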
+/// \param arg function argument +/// \return gamma function value of \a arg +/// \exception FE_INVALID for signaling NaN, negative infinity or negative integer arguments +/// \exception FE_DIVBYZERO for 0 +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half tgamma(half arg) +{ +#if defined(HALF_ARITHMETIC_TYPE) && HALF_ENABLE_CPP11_CMATH + return half(detail::binary, + detail::float2half<half::round_style>( + std::tgamma(detail::half2float<detail::internal_t>(arg.data_)))); +#else + unsigned int abs = arg.data_ & 0x7FFF; + if(!abs) + return half(detail::binary, detail::pole(arg.data_)); + if(abs >= 0x7C00) + return (arg.data_ == 0x7C00) ? arg : half(detail::binary, detail::signal(arg.data_)); + if(arg.data_ >= 0xE400 || (arg.data_ >= 0xBC00 && !(abs & ((1 << (25 - (abs >> 10))) - 1)))) + return half(detail::binary, detail::invalid()); + if(arg.data_ >= 0xCA80) + return half( + detail::binary, + detail::underflow<half::round_style>((1 - ((abs >> (25 - (abs >> 10))) & 1)) << 15)); + if(arg.data_ <= 0x100 || (arg.data_ >= 0x4900 && arg.data_ < 0x8000)) + return half(detail::binary, detail::overflow<half::round_style>()); + if(arg.data_ == 0x3C00) + return arg; + return half(detail::binary, detail::gamma(arg.data_)); +#endif +} + +/// \} +/// \anchor rounding +/// \name Rounding +/// \{ + +/// Nearest integer not less than half value. +/// **See also:** Documentation for +/// [std::ceil](https://en.cppreference.com/w/cpp/numeric/math/ceil). +/// \param arg half to round +/// \return nearest integer not less than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half ceil(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater than half value. +/// **See also:** Documentation for +/// [std::floor](https://en.cppreference.com/w/cpp/numeric/math/floor). +/// \param arg half to round +/// \return nearest integer not greater than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half floor(half arg) +{ + return half(detail::binary, + detail::integral(arg.data_)); +} + +/// Nearest integer not greater in magnitude than half value. +/// **See also:** Documentation for +/// [std::trunc](https://en.cppreference.com/w/cpp/numeric/math/trunc). +/// \param arg half to round +/// \return nearest integer not greater in magnitude than \a arg +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half trunc(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::round](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half round(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer. +/// **See also:** Documentation for +/// [std::lround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long` +inline long lround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. 
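The rounding family above (and rint()/lrint()/nearbyint() below) differs only in the direction applied to fractional and half-way values; a sketch (include path is an assumption):

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half x(-2.5f);
    std::printf("ceil %g floor %g trunc %g round %g lround %ld\n",
                float(half_float::ceil(x)),  // -2: toward +infinity
                float(half_float::floor(x)), // -3: toward -infinity
                float(half_float::trunc(x)), // -2: toward zero
                float(half_float::round(x)), // -3: half-way cases go away from zero
                half_float::lround(x));      // -3, as a long
}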
+/// **See also:** Documentation for +/// [std::rint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +/// \exception FE_INEXACT if value had to be rounded +inline half rint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::lrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long` +/// \exception FE_INEXACT if value had to be rounded +inline long lrint(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::nearbyint](https://en.cppreference.com/w/cpp/numeric/math/nearbyint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID for signaling NaN +inline half nearbyint(half arg) +{ + return half(detail::binary, detail::integral(arg.data_)); +} +#if HALF_ENABLE_CPP11_LONG_LONG +/// Nearest integer. +/// **See also:** Documentation for +/// [std::llround](https://en.cppreference.com/w/cpp/numeric/math/round). +/// \param arg half to round +/// \return nearest integer, rounded away from zero in half-way cases +/// \exception FE_INVALID if value is not representable as `long long` +inline long long llround(half arg) +{ + return detail::half2int(arg.data_); +} + +/// Nearest integer using half's internal rounding mode. +/// **See also:** Documentation for +/// [std::llrint](https://en.cppreference.com/w/cpp/numeric/math/rint). +/// \param arg half expression to round +/// \return nearest integer using default rounding mode +/// \exception FE_INVALID if value is not representable as `long long` +/// \exception FE_INEXACT if value had to be rounded +inline long long llrint(half arg) +{ + return detail::half2int(arg.data_); +} +#endif + +/// \} +/// \anchor float +/// \name Floating point manipulation +/// \{ + +/// Decompress floating-point number. +/// **See also:** Documentation for +/// [std::frexp](https://en.cppreference.com/w/cpp/numeric/math/frexp). +/// \param arg number to decompress +/// \param exp address to store exponent at +/// \return significand in range [0.5, 1) +/// \exception FE_INVALID for signaling NaN +inline half frexp(half arg, int* exp) +{ + *exp = 0; + unsigned int abs = arg.data_ & 0x7FFF; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --*exp) + ; + *exp += (abs >> 10) - 14; + return half(detail::binary, (arg.data_ & 0x8000) | 0x3800 | (abs & 0x3FF)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbln](https://en.cppreference.com/w/cpp/numeric/math/scalbn). 
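frexp() above and the scalbn() family here are exact inverses, since only the exponent changes; a round-trip sketch (include path is an assumption):

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    int e = 0;
    half x(48.0f);
    half m = half_float::frexp(x, &e); // 48 = 0.75 * 2^6, so m = 0.75, e = 6
    half y = half_float::scalbn(m, e); // reassembles x without any rounding
    std::printf("%g = %g * 2^%d (round trip: %g)\n", float(x), float(m), e, float(y));
}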
+/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multiplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbln(half arg, long exp) +{ + unsigned int abs = arg.data_ & 0x7FFF, sign = arg.data_ & 0x8000; + if(abs >= 0x7C00 || !abs) + return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; + for(; abs < 0x400; abs <<= 1, --exp) + ; + exp += abs >> 10; + if(exp > 30) + return half(detail::binary, detail::overflow<half::round_style>(sign)); + else if(exp < -10) + return half(detail::binary, detail::underflow<half::round_style>(sign)); + else if(exp > 0) + return half(detail::binary, sign | (exp << 10) | (abs & 0x3FF)); + unsigned int m = (abs & 0x3FF) | 0x400; + return half(detail::binary, + detail::rounded( + sign | (m >> (1 - exp)), (m >> -exp) & 1, (m & ((1 << -exp) - 1)) != 0)); +} + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::scalbn](https://en.cppreference.com/w/cpp/numeric/math/scalbn). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multiplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half scalbn(half arg, int exp) { return scalbln(arg, exp); } + +/// Multiply by power of two. +/// This function is exact to rounding for all rounding modes. +/// +/// **See also:** Documentation for +/// [std::ldexp](https://en.cppreference.com/w/cpp/numeric/math/ldexp). +/// \param arg number to modify +/// \param exp power of two to multiply with +/// \return \a arg multiplied by 2 raised to \a exp +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +inline half ldexp(half arg, int exp) { return scalbln(arg, exp); } + +/// Extract integer and fractional parts. +/// **See also:** Documentation for +/// [std::modf](https://en.cppreference.com/w/cpp/numeric/math/modf). +/// \param arg number to decompress +/// \param iptr address to store integer part at +/// \return fractional part +/// \exception FE_INVALID for signaling NaN +inline half modf(half arg, half* iptr) +{ + unsigned int abs = arg.data_ & 0x7FFF; + if(abs > 0x7C00) + { + arg = half(detail::binary, detail::signal(arg.data_)); + return *iptr = arg, arg; + } + if(abs >= 0x6400) + return *iptr = arg, half(detail::binary, arg.data_ & 0x8000); + if(abs < 0x3C00) + return iptr->data_ = arg.data_ & 0x8000, arg; + unsigned int exp = abs >> 10, mask = (1 << (25 - exp)) - 1, m = arg.data_ & mask; + iptr->data_ = arg.data_ & ~mask; + if(!m) + return half(detail::binary, arg.data_ & 0x8000); + for(; m < 0x400; m <<= 1, --exp) + ; + return half(detail::binary, (arg.data_ & 0x8000) | (exp << 10) | (m & 0x3FF)); +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::ilogb](https://en.cppreference.com/w/cpp/numeric/math/ilogb). +/// \param arg number to query +/// \return floating-point exponent +/// \retval FP_ILOGB0 for zero +/// \retval FP_ILOGBNAN for NaN +/// \retval INT_MAX for infinity +/// \exception FE_INVALID for 0 or infinite values +inline int ilogb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs || abs >= 0x7C00) + { + detail::raise(FE_INVALID); + return !abs ? FP_ILOGB0 : (abs == 0x7C00) ? 
INT_MAX : FP_ILOGBNAN; + } + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + return exp; +} + +/// Extract exponent. +/// **See also:** Documentation for +/// [std::logb](https://en.cppreference.com/w/cpp/numeric/math/logb). +/// \param arg number to query +/// \return floating-point exponent +/// \exception FE_INVALID for signaling NaN +/// \exception FE_DIVBYZERO for 0 +inline half logb(half arg) +{ + int abs = arg.data_ & 0x7FFF, exp; + if(!abs) + return half(detail::binary, detail::pole(0x8000)); + if(abs >= 0x7C00) + return half(detail::binary, (abs == 0x7C00) ? 0x7C00 : detail::signal(arg.data_)); + for(exp = (abs >> 10) - 15; abs < 0x200; abs <<= 1, --exp) + ; + unsigned int value = static_cast<unsigned>(exp < 0) << 15; + if(exp) + { + unsigned int m = std::abs(exp) << 6; + for(exp = 18; m < 0x400; m <<= 1, --exp) + ; + value |= (exp << 10) + m; + } + return half(detail::binary, value); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nextafter](https://en.cppreference.com/w/cpp/numeric/math/nextafter). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nextafter(half from, half to) +{ + int fabs = from.data_ & 0x7FFF, tabs = to.data_ & 0x7FFF; + if(fabs > 0x7C00 || tabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_, to.data_)); + if(from.data_ == to.data_ || !(fabs | tabs)) + return to; + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (to.data_ & 0x8000) + 1); + } + unsigned int out = + from.data_ + + (((from.data_ >> 15) ^ + static_cast<unsigned>((from.data_ ^ (0x8000 | (0x8000 - (from.data_ >> 15)))) < + (to.data_ ^ (0x8000 | (0x8000 - (to.data_ >> 15)))))) + << 1) - + 1; + detail::raise(FE_OVERFLOW, fabs < 0x7C00 && (out & 0x7C00) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7C00) < 0x400); + return half(detail::binary, out); +} + +/// Next representable value. +/// **See also:** Documentation for +/// [std::nexttoward](https://en.cppreference.com/w/cpp/numeric/math/nexttoward). +/// \param from value to compute next representable value for +/// \param to direction towards which to compute next value +/// \return next representable value after \a from in direction towards \a to +/// \exception FE_INVALID for signaling NaN +/// \exception FE_OVERFLOW for infinite result from finite argument +/// \exception FE_UNDERFLOW for subnormal result +inline half nexttoward(half from, long double to) +{ + int fabs = from.data_ & 0x7FFF; + if(fabs > 0x7C00) + return half(detail::binary, detail::signal(from.data_)); + long double lfrom = static_cast<long double>(from); + if(detail::builtin_isnan(to) || lfrom == to) + return half(static_cast<float>(to)); + if(!fabs) + { + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT); + return half(detail::binary, (static_cast<unsigned>(detail::builtin_signbit(to)) << 15) + 1); + } + unsigned int out = + from.data_ + (((from.data_ >> 15) ^ static_cast<unsigned>(lfrom < to)) << 1) - 1; + detail::raise(FE_OVERFLOW, (out & 0x7FFF) == 0x7C00); + detail::raise(FE_UNDERFLOW, !HALF_ERRHANDLING_UNDERFLOW_TO_INEXACT && (out & 0x7FFF) < 0x400); + return half(detail::binary, out); +} + +/// Take sign. 
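nextafter() above steps by exactly one representable half; near 1.0 that step is 2^-10, which a short sketch can confirm (include path is an assumption):

#include <cstdio>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half one(1.0f);
    half up = half_float::nextafter(one, half(2.0f)); // one ULP above 1.0
    std::printf("gap at 1.0 = %.10f\n", float(up) - 1.0f); // 0.0009765625 = 2^-10
}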
+/// **See also:** Documentation for +/// [std::copysign](https://en.cppreference.com/w/cpp/numeric/math/copysign). +/// \param x value to change sign for +/// \param y value to take sign from +/// \return value equal to \a x in magnitude and to \a y in sign +inline HALF_CONSTEXPR half copysign(half x, half y) +{ + return half(detail::binary, x.data_ ^ ((x.data_ ^ y.data_) & 0x8000)); +} + +/// \} +/// \anchor classification +/// \name Floating point classification +/// \{ + +/// Classify floating-point value. +/// **See also:** Documentation for +/// [std::fpclassify](https://en.cppreference.com/w/cpp/numeric/math/fpclassify). +/// \param arg number to classify +/// \retval FP_ZERO for positive and negative zero +/// \retval FP_SUBNORMAL for subnormal numbers +/// \retval FP_INFINITE for positive and negative infinity +/// \retval FP_NAN for NaNs +/// \retval FP_NORMAL for all other (normal) values +inline HALF_CONSTEXPR int fpclassify(half arg) +{ + return !(arg.data_ & 0x7FFF) ? FP_ZERO : ((arg.data_ & 0x7FFF) < 0x400) + ? FP_SUBNORMAL + : ((arg.data_ & 0x7FFF) < 0x7C00) + ? FP_NORMAL + : ((arg.data_ & 0x7FFF) == 0x7C00) + ? FP_INFINITE + : FP_NAN; +} + +/// Check if finite number. +/// **See also:** Documentation for +/// [std::isfinite](https://en.cppreference.com/w/cpp/numeric/math/isfinite). +/// \param arg number to check +/// \retval true if neither infinity nor NaN +/// \retval false else +inline HALF_CONSTEXPR bool isfinite(half arg) { return (arg.data_ & 0x7C00) != 0x7C00; } + +/// Check for infinity. +/// **See also:** Documentation for +/// [std::isinf](https://en.cppreference.com/w/cpp/numeric/math/isinf). +/// \param arg number to check +/// \retval true for positive or negative infinity +/// \retval false else +inline HALF_CONSTEXPR bool isinf(half arg) { return (arg.data_ & 0x7FFF) == 0x7C00; } + +/// Check for NaN. +/// **See also:** Documentation for +/// [std::isnan](https://en.cppreference.com/w/cpp/numeric/math/isnan). +/// \param arg number to check +/// \retval true for NaNs +/// \retval false else +inline HALF_CONSTEXPR bool isnan(half arg) { return (arg.data_ & 0x7FFF) > 0x7C00; } + +/// Check if normal number. +/// **See also:** Documentation for +/// [std::isnormal](https://en.cppreference.com/w/cpp/numeric/math/isnormal). +/// \param arg number to check +/// \retval true if normal number +/// \retval false if either subnormal, zero, infinity or NaN +inline HALF_CONSTEXPR bool isnormal(half arg) +{ + return ((arg.data_ & 0x7C00) != 0) & ((arg.data_ & 0x7C00) != 0x7C00); +} + +/// Check sign. +/// **See also:** Documentation for +/// [std::signbit](https://en.cppreference.com/w/cpp/numeric/math/signbit). +/// \param arg number to check +/// \retval true for negative number +/// \retval false for positive number +inline HALF_CONSTEXPR bool signbit(half arg) { return (arg.data_ & 0x8000) != 0; } + +/// \} +/// \anchor compfunc +/// \name Comparison +/// \{ + +/// Quiet comparison for greater than. +/// **See also:** Documentation for +/// [std::isgreater](https://en.cppreference.com/w/cpp/numeric/math/isgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreater(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) > + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for greater equal. 
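A sketch of the classification predicates above, applied to a subnormal input (include path is an assumption):

#include <cmath>
#include <cstdio>
#include <limits>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half sub(1.0e-6f); // below the smallest normal half (~6.1e-5), so subnormal
    half inf = std::numeric_limits<half>::infinity();
    std::printf("%d %d %d\n",
                half_float::fpclassify(sub) == FP_SUBNORMAL, // 1
                half_float::isnormal(sub),                   // 0
                half_float::isinf(inf));                     // 1
}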
+/// **See also:** Documentation for +/// [std::isgreaterequal](https://en.cppreference.com/w/cpp/numeric/math/isgreaterequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x greater equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool isgreaterequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) >= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less than. +/// **See also:** Documentation for +/// [std::isless](https://en.cppreference.com/w/cpp/numeric/math/isless). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less than \a y +/// \retval false else +inline HALF_CONSTEXPR bool isless(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) < + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less equal. +/// **See also:** Documentation for +/// [std::islessequal](https://en.cppreference.com/w/cpp/numeric/math/islessequal). +/// \param x first operand +/// \param y second operand +/// \retval true if \a x less equal \a y +/// \retval false else +inline HALF_CONSTEXPR bool islessequal(half x, half y) +{ + return ((x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) + (x.data_ >> 15)) <= + ((y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))) + (y.data_ >> 15)) && + !isnan(x) && !isnan(y); +} + +/// Quiet comparison for less or greater. +/// **See also:** Documentation for +/// [std::islessgreater](https://en.cppreference.com/w/cpp/numeric/math/islessgreater). +/// \param x first operand +/// \param y second operand +/// \retval true if either less or greater +/// \retval false else +inline HALF_CONSTEXPR bool islessgreater(half x, half y) +{ + return x.data_ != y.data_ && ((x.data_ | y.data_) & 0x7FFF) && !isnan(x) && !isnan(y); +} + +/// Quiet check if unordered. +/// **See also:** Documentation for +/// [std::isunordered](https://en.cppreference.com/w/cpp/numeric/math/isunordered). +/// \param x first operand +/// \param y second operand +/// \retval true if unordered (one or two NaN operands) +/// \retval false else +inline HALF_CONSTEXPR bool isunordered(half x, half y) { return isnan(x) || isnan(y); } + +/// \} +/// \anchor casting +/// \name Casting +/// \{ + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the default rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. 
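A sketch of both half_cast() overloads documented here, including a forced rounding mode (include path is an assumption; the default rounding mode depends on the library's configuration):

#include <cstdio>
#include <limits>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    // direct conversion, no intermediate round trip over float:
    half a = half_float::half_cast<half>(0.1);                         // default rounding mode
    half b = half_float::half_cast<half, std::round_toward_zero>(0.1); // forced truncation
    int  i = half_float::half_cast<int>(half(2.75f)); // 3, assuming the default round-to-nearest
    std::printf("%g %g %d\n", float(a), float(b), i);
}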
+/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template<typename T, typename U> +T half_cast(U arg) +{ + return detail::half_caster<T, U>::cast(arg); +} + +/// Cast to or from half-precision floating-point number. +/// This casts between [half](\ref half_float::half) and any built-in arithmetic type. The values +/// are converted +/// directly using the specified rounding mode, without any roundtrip over `float` that a +/// `static_cast` would otherwise do. +/// +/// Using this cast with neither of the two types being a [half](\ref half_float::half) or with any +/// of the two types +/// not being a built-in arithmetic type (apart from [half](\ref half_float::half), of course) +/// results in a compiler +/// error and casting between [half](\ref half_float::half)s returns the argument unmodified. +/// \tparam T destination type (half or built-in arithmetic type) +/// \tparam R rounding mode to use. +/// \tparam U source type (half or built-in arithmetic type) +/// \param arg value to cast +/// \return \a arg converted to destination type +/// \exception FE_INVALID if \a T is integer type and result is not representable as \a T +/// \exception FE_OVERFLOW, ...UNDERFLOW, ...INEXACT according to rounding +template<typename T, std::float_round_style R, typename U> +T half_cast(U arg) +{ + return detail::half_caster<T, U, R>::cast(arg); +} +/// \} + +/// \} +/// \anchor errors +/// \name Error handling +/// \{ + +/// Clear exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feclearexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feclearexcept). +/// \param excepts OR of exceptions to clear +/// \retval 0 all selected flags cleared successfully +inline int feclearexcept(int excepts) +{ + detail::errflags() &= ~excepts; + return 0; +} + +/// Test exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fetestexcept](https://en.cppreference.com/w/cpp/numeric/fenv/fetestexcept). +/// \param excepts OR of exceptions to test +/// \return OR of selected exceptions if raised +inline int fetestexcept(int excepts) { return detail::errflags() & excepts; } + +/// Raise exception flags. +/// This raises the specified floating point exceptions and also invokes any additional automatic +/// exception handling as +/// configured with the [HALF_ERRHANDLING_...](\ref HALF_ERRHANDLING_ERRNO) preprocessor symbols. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::feraiseexcept](https://en.cppreference.com/w/cpp/numeric/fenv/feraiseexcept). 
+/// \param excepts OR of exceptions to raise +/// \retval 0 all selected exceptions raised successfully +inline int feraiseexcept(int excepts) +{ + detail::errflags() |= excepts; + detail::raise(excepts); + return 0; +} + +/// Save exception flags. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fegetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp address to store flag state at +/// \param excepts OR of flags to save +/// \retval 0 for success +inline int fegetexceptflag(int* flagp, int excepts) +{ + *flagp = detail::errflags() & excepts; + return 0; +} + +/// Restore exception flags. +/// This only copies the specified exception state (including unset flags) without incurring any +/// additional exception handling. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. +/// +/// **See also:** Documentation for +/// [std::fesetexceptflag](https://en.cppreference.com/w/cpp/numeric/fenv/feexceptflag). +/// \param flagp address to take flag state from +/// \param excepts OR of flags to restore +/// \retval 0 for success +inline int fesetexceptflag(const int* flagp, int excepts) +{ + detail::errflags() = (detail::errflags() | (*flagp & excepts)) & (*flagp | ~excepts); + return 0; +} + +/// Throw C++ exceptions based on set exception flags. +/// This function manually throws a corresponding C++ exception if one of the specified flags is +/// set, +/// no matter if automatic throwing (via [HALF_ERRHANDLING_THROW_...](\ref +/// HALF_ERRHANDLING_THROW_INVALID)) is enabled or not. +/// This function works even if [automatic exception flag handling](\ref HALF_ERRHANDLING_FLAGS) is +/// disabled, +/// but in that case manual flag management is the only way to raise flags. 
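A sketch of the flag workflow these functions support; per the documentation above, HALF_ERRHANDLING_FLAGS must be defined before inclusion for operations to raise flags automatically (that macro usage and the include path are assumptions):

#define HALF_ERRHANDLING_FLAGS 1 // assumption: automatic flag raising is off by default
#include <cstdio>
#include <limits>
#include "half.hpp" // assumed include path; vendored under external/half/include

int main()
{
    using half_float::half;
    half_float::feclearexcept(FE_OVERFLOW);
    half big = std::numeric_limits<half>::max();
    half r = big * big; // overflows half precision
    std::printf("overflow flag set: %d\n", half_float::fetestexcept(FE_OVERFLOW) != 0);
    (void)r;
}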
+/// \param excepts OR of exceptions to test +/// \param msg error message to use for exception description +/// \throw std::domain_error if `FE_INVALID` or `FE_DIVBYZERO` is selected and set +/// \throw std::overflow_error if `FE_OVERFLOW` is selected and set +/// \throw std::underflow_error if `FE_UNDERFLOW` is selected and set +/// \throw std::range_error if `FE_INEXACT` is selected and set +inline void fethrowexcept(int excepts, const char* msg = "") +{ + excepts &= detail::errflags(); + if(excepts & (FE_INVALID | FE_DIVBYZERO)) + throw std::domain_error(msg); + if(excepts & FE_OVERFLOW) + throw std::overflow_error(msg); + if(excepts & FE_UNDERFLOW) + throw std::underflow_error(msg); + if(excepts & FE_INEXACT) + throw std::range_error(msg); +} +/// \} +} + +#undef HALF_UNUSED_NOERR +#undef HALF_CONSTEXPR +#undef HALF_CONSTEXPR_CONST +#undef HALF_CONSTEXPR_NOERR +#undef HALF_NOEXCEPT +#undef HALF_NOTHROW +#undef HALF_THREAD_LOCAL +#undef HALF_TWOS_COMPLEMENT_INT +#ifdef HALF_POP_WARNINGS +#pragma warning(pop) +#undef HALF_POP_WARNINGS +#endif + +#endif diff --git a/external/include/bfloat16_dev.hpp b/external/rocm/include/bfloat16_dev.hpp similarity index 100% rename from external/include/bfloat16_dev.hpp rename to external/rocm/include/bfloat16_dev.hpp diff --git a/script/cmake-cuda.sh b/script/cmake-cuda.sh index 759564b8ee..106035d70a 100755 --- a/script/cmake-cuda.sh +++ b/script/cmake-cuda.sh @@ -1,24 +1,20 @@ #!/bin/bash -rm -f CMakeCache.txt -rm -f *.cmake -rm -rf CMakeFiles - -MY_PROJECT_SOURCE=/home/chao/code/modular_convolution -MY_PROJECT_INSTALL=../install.dir +MY_PROJECT_SOURCE=../../../ + + export CUDA_ROOT=/usr/local/cuda + export CPATH=$CPATH:$CUDA_ROOT/include + export LIBRARY_PATH=$LIBRARY_PATH:$CUDA_ROOT/lib64 + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_ROOT/lib64 cmake \ --D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_CXX_COMPILER=clang++ \ -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D DEVICE_BACKEND=NVIDIA \ --D CUDA_COMMON_INCLUDE_DIR="/package/install/cuda/10.1/NVIDIA_CUDA-10.1_Samples/common/inc" \ -D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ ${MY_PROJECT_SOURCE} -#-D BOOST_ROOT="/package/install/boost_1.67.0" \ -#-D CMAKE_CUDA_COMPILER="/package/install/cuda_10.0/bin/nvcc" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ +#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70" \ +#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \ diff --git a/script/cmake-cuda_docker.sh b/script/cmake-cuda_docker.sh deleted file mode 100755 index d414bd873d..0000000000 --- a/script/cmake-cuda_docker.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash - -MY_PROJECT_SOURCE=../../../ -MY_PROJECT_INSTALL=../install.dir - -export CUDA_ROOT=/usr/local/cuda -export CPATH=$CPATH:$CUDA_ROOT/include -export LIBRARY_PATH=$LIBRARY_PATH:$CUDA_ROOT/lib64 -export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_ROOT/lib64 - -cmake \ --D 
CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D CMAKE_CXX_COMPILER=clang++-6.0 \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ --D DEVICE_BACKEND=NVIDIA \ --D CUDA_COMMON_INCLUDE_DIR="/root/NVIDIA_CUDA-10.1_Samples/common/inc" \ --D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ -${MY_PROJECT_SOURCE} - - -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ diff --git a/script/cmake-hip.sh b/script/cmake-rocm3.1.sh similarity index 84% rename from script/cmake-hip.sh rename to script/cmake-rocm3.1.sh index 959582ffcf..c7bdb4f1c6 100755 --- a/script/cmake-hip.sh +++ b/script/cmake-rocm3.1.sh @@ -1,5 +1,4 @@ #!/bin/bash - rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles @@ -11,9 +10,10 @@ cmake -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_BUILD_TYPE=Release \ -D DEVICE_BACKEND="AMD" \ --D HIP_HIPCC_FLAGS="${HIP_HIPCC_FLAGS} -gline-tables-only -v" \ --D CMAKE_CXX_FLAGS="-gline-tables-only --amdgpu-target=gfx906" \ +-D CMAKE_CXX_FLAGS="--amdgpu-target=gfx906" \ -D CMAKE_CXX_COMPILER=/opt/rocm/hip/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} + +#-D CMAKE_CXX_FLAGS="-gline-tables-only -v --amdgpu-target=gfx906" \ diff --git a/script/cmake-rocm3.5.sh b/script/cmake-rocm3.5.sh new file mode 100755 index 0000000000..d3a9b575ee --- /dev/null +++ b/script/cmake-rocm3.5.sh @@ -0,0 +1,21 @@ +#!/bin/bash +rm -f CMakeCache.txt +rm -f *.cmake +rm -rf CMakeFiles + +MY_PROJECT_SOURCE=../../../ +MY_PROJECT_INSTALL=../install.dir + +cmake \ +-D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ +-D CMAKE_BUILD_TYPE=Release \ +-D DEVICE_BACKEND="AMD" \ +-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ +-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ +-D CMAKE_PREFIX_PATH="/opt/rocm" \ +-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ +${MY_PROJECT_SOURCE} + +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps" \ diff --git a/script/compile-hip.sh b/script/compile-rocm3.1.sh similarity index 100% rename from script/compile-hip.sh rename to script/compile-rocm3.1.sh diff --git a/script/docker-cuda.sh b/script/docker-cuda.sh deleted file mode 100755 index 508774b657..0000000000 --- a/script/docker-cuda.sh +++ /dev/null @@ -1,3 +0,0 @@ -WORKSPACE=$1 -echo "workspace: " $WORKSPACE -sudo docker run -it -v $WORKSPACE:/root/workspace --group-add sudo --runtime=nvidia asroy/cuda:10.1-cudnn7-devel-ubuntu18.04-latest /bin/bash diff --git a/script/hack_isa.sh b/script/hack_isa.sh deleted file mode 100755 index 78793689db..0000000000 --- a/script/hack_isa.sh +++ /dev/null @@ -1,9 +0,0 @@ -# step 1: GET ISA DUMP -#cd /root/workspace/mlopen/modular_convolution/build/hipcc/build.dir/driver && KMDUMPISA=1 
/opt/rocm/hip/bin/hipcc -I/root/workspace/mlopen/modular_convolution/build/hipcc/build.dir/composable_kernel/include/utility -I/root/workspace/mlopen/modular_convolution/driver/include -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/kernel_algorithm -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/tensor_operation -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/tensor_description -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/utility -I/root/workspace/mlopen/modular_convolution/composable_kernel/include -gline-tables-only --amdgpu-target=gfx906 -fopenmp=libomp -O3 -DNDEBUG -std=c++14 -o CMakeFiles/driver.dir/src/driver.cpp.o -c /root/workspace/mlopen/modular_convolution/driver/src/driver.cpp -fno-gpu-rdc - -# step 2: HACK ISA -#cd /root/workspace/mlopen/modular_convolution/build/hipcc/build.dir/driver && KMHACKISA=1 /opt/rocm/hip/bin/hipcc -I/root/workspace/mlopen/modular_convolution/build/hipcc/build.dir/composable_kernel/include/utility -I/root/workspace/mlopen/modular_convolution/driver/include -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/kernel_algorithm -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/tensor_operation -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/tensor_description -I/root/workspace/mlopen/modular_convolution/composable_kernel/include/utility -I/root/workspace/mlopen/modular_convolution/composable_kernel/include -gline-tables-only --amdgpu-target=gfx906 -fopenmp=libomp -O3 -DNDEBUG -std=c++14 -o CMakeFiles/driver.dir/src/driver.cpp.o -c /root/workspace/mlopen/modular_convolution/driver/src/driver.cpp -fno-gpu-rdc - -# step 3: LINK -#/opt/rocm/hip/bin/hipcc -gline-tables-only --amdgpu-target=gfx906 -fopenmp=libomp -O3 -DNDEBUG CMakeFiles/driver.dir/src/driver.cpp.o -o driver -rdynamic libhost.so -Wl,-rpath,/root/workspace/mlopen/modular_convolution/build/hipcc/build.dir/driver - diff --git a/script/trace.sh b/script/trace.sh deleted file mode 100755 index 231a69de08..0000000000 --- a/script/trace.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -/root/workspace/rocprofiler_pkg/bin/rpl_run.sh --timestamp on -i /root/workspace/rocprofiler_pkg/input.xml -d ./trace ./driver/driver 0 10 diff --git a/script/tracer-hip.sh b/script/tracer-hip.sh deleted file mode 100755 index 231a69de08..0000000000 --- a/script/tracer-hip.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash - -/root/workspace/rocprofiler_pkg/bin/rpl_run.sh --timestamp on -i /root/workspace/rocprofiler_pkg/input.xml -d ./trace ./driver/driver 0 10