diff --git a/CMakeLists.txt b/CMakeLists.txt
index 962ae7f00d..51d57d016f 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -38,8 +38,6 @@ link_libraries(${OpenMP_pthread_LIBRARY})
 #GPU backend
 if(DEVICE_BACKEND STREQUAL "AMD")
     find_package(HIP REQUIRED)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    enable_language(CUDA)
 endif()
 
 #
@@ -64,13 +62,7 @@ endif()
 if(DEVICE_BACKEND STREQUAL "AMD")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
     configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.amd.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/config.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/config.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/float_type.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/float_type.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/in_memory_operation.hpp")
-    configure_file("${PROJECT_SOURCE_DIR}/composable_kernel/include/utility/synchronization.nvidia.hpp.in" "${PROJECT_BINARY_DIR}/composable_kernel/include/utility/synchronization.hpp")
 endif()
 
 add_subdirectory(driver)
@@ -80,26 +72,17 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
 message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}")
 
 if(DEVICE_BACKEND STREQUAL "AMD")
-    set(CONV_SOURCE driver/conv_driver.cpp)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp)
     set(CONV_V2_SOURCE driver/conv_driver_v2.cpp)
     set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp)
     set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp)
-elseif(DEVICE_BACKEND STREQUAL "NVIDIA")
-    set(CONV_SOURCE driver/conv_driver.cu)
-    set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu)
 endif()
 
-add_executable(conv_driver ${CONV_SOURCE})
-add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE})
 add_executable(conv_driver_v2 ${CONV_V2_SOURCE})
 add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE})
 add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE})
 
 target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/)
 
-target_link_libraries(conv_driver PRIVATE modConv)
-target_link_libraries(conv_bwd_data_driver PRIVATE modConv)
 target_link_libraries(conv_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv)
 target_link_libraries(conv_driver_v2_olc PRIVATE modConv)
diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
deleted file mode 100644
index 71a9bb6dc0..0000000000
--- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+++ /dev/null
@@ -1,172
+0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP -#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// GemmM = C * Y * X -// GemmN = N * Ho * Wo -// GemmK = K -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw -{ - __device__ void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) const - { - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - //\todo static_assert for global vector load/store - // statc_assert(); - - // weight tensor - constexpr auto wei_gemmk_gemmm_global_desc = - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3); - - // input tensor - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Pad, InLeftPads, InRightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - constexpr auto out_gemmk_gemmn_global_desc = - transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // GEMM - // \todo there are more combinations of Y, ConvDilationH and ConvStrideH that don't need - // atomic, find out all of them - constexpr bool not_need_atomic = (ConvStrideH >= ConvDilationH * (Y - 1) + 1) and - (ConvStrideW >= 
ConvDilationW * (X - 1) + 1); - - constexpr auto in_memory_op = - not_need_atomic ? InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd; - - constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1, - Sequence<0, 1>, - 1, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<0, 1, 2, 3>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index 05e4c54a61..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,450 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP -#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - const Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto True = integral_constant{}; - - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t C0 = GemmMPerThread; - constexpr index_t N0 = GemmNPerThread; - - static_assert(C % C0 == 0 && N % N0 == 0, "wrong!"); - - constexpr index_t C1 = C / C0; - constexpr index_t N1 = N / N0; - - constexpr index_t E = C1 * Y * X; - constexpr index_t B = N1 * Ho * Wo; - - // sanity-check for vectorized memory load - static_assert((Wo 
== 1 || (ConvStrideW == 1 || InThreadCopyDstDataPerWrite_B == 1)) && - (X == 1 || ConvDilationW % InThreadCopyDstDataPerWrite_B == 0), - "wrong! aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t EBlockWork = E / EPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock; - const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock; - - // output tensor - // global tensor in global memory, src of blockwise copy - constexpr auto out_n_k_howo_global_desc = - unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3); - - constexpr auto out_n0_n1_k_howo_global_desc = transform_tensor_descriptor( - out_n_k_howo_global_desc, - make_tuple(UnMerge>{}, PassThrough{}, PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); - - constexpr auto out_k_b_n0_global_desc = transform_tensor_descriptor( - out_n0_n1_k_howo_global_desc, - make_tuple(PassThrough{}, Merge>{}, PassThrough{}), - make_tuple(Sequence<2>{}, Sequence<1, 3>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - // block tensor in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto out_k_b_n0_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // output tensor blockwise copy - auto blockwise_out_copy = - BlockwiseGenericTensorSliceCopy_v4, - Sequence<0, 1, 2>, - Sequence<0, 1, 2>, - 1, - 2, - OutBlockCopySrcDataPerRead_B, - OutBlockCopyDstDataPerWrite_N0, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0)); - - // weight tensor - // global tensor in global memory, src of blockwise copy - constexpr auto wei_k_cyx_global_desc = - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3); - - constexpr auto wei_k_c0_e_global_desc = - transform_tensor_descriptor(wei_k_cyx_global_desc, - make_tuple(PassThrough{}, UnMerge>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{})); - - constexpr auto wei_k_e_c0_global_desc = reorder_tensor_descriptor_given_lower2upper( - wei_k_c0_e_global_desc, Sequence<0, 2, 1>{}); - - // block tensor in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_k_e_c0_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // weight tensor blockwise copy - auto blockwise_wei_copy = - BlockwiseGenericTensorSliceCopy_v4, - Sequence<0, 1, 2>, - Sequence<0, 1, 2>, - 1, - 2, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_C0, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, EPerBlock*C0] is in LDS - // b_mtx[KPerBlocl, BPerBlock*N0] is in LDS - // c_mtx[EPerBlock*C0, BPerBlock*N0] is distributed among threads, and saved in - // 
register - constexpr auto a_k_ec0_block_mtx_desc = make_ConstantMatrixDescriptor( - wei_k_e_c0_block_desc.GetLength(I0), - wei_k_e_c0_block_desc.GetLength(I1) * wei_k_e_c0_block_desc.GetLength(I2), - wei_k_e_c0_block_desc.GetStride(I0)); - constexpr auto b_k_bn0_block_mtx_desc = make_ConstantMatrixDescriptor( - out_k_b_n0_block_desc.GetLength(I0), - out_k_b_n0_block_desc.GetLength(I1) * out_k_b_n0_block_desc.GetLength(I2), - out_k_b_n0_block_desc.GetStride(I0)); - - // sanity check alignment - // TODO: this check is ad-hoc, should enforce it by enforcing alignment of - // wei_k_e_c0_block_desc and out_k_b_n0_block_desc - static_assert(a_k_ec0_block_mtx_desc.RowStride() % GemmDataPerReadB == 0, "wrong!"); - static_assert(b_k_bn0_block_mtx_desc.RowStride() % GemmDataPerReadA == 0, "wrong!"); - - // sanity check - static_assert(EPerBlock % (GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 && - BPerBlock % (GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = EPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster); - constexpr index_t GemmNRepeat = BPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_e0e1c0_b0b1n0_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_k_ec0_block_mtx_desc), - decltype(b_k_bn0_block_mtx_desc), - decltype(c_e0e1c0_b0b1n0_thread_mtx_desc), - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDstDataPerWrite_C0, - OutBlockCopyDstDataPerWrite_N0, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t out_block_space = - math::integer_least_multiple(out_k_b_n0_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_k_e_c0_block_desc.GetElementSpace(), max_lds_align); - - __shared__ Float p_out_block_double[2 * out_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - AccFloat p_in_thread[c_e0e1c0_b0b1n0_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_e0e1c0_b0b1n0_thread_mtx_desc, p_in_thread); - - // LDS double buffer: preload data into LDS - { - blockwise_out_copy.Run(p_out_global, p_out_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - // LDS double buffer: main body - for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; - k_block_data_begin += 2 * KPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_out_block_now = - even_loop ? p_out_block_double : p_out_block_double + out_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_out_block_next = - even_loop ? p_out_block_double + out_block_space : p_out_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_out_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread); - - // LDS double buffer: store next data to LDS - blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0); - - if(has_two_iteration_left) // if has 2 iteration left - { - Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_out_copy.MoveSrcSliceWindow(Sequence{}, True); - blockwise_wei_copy.MoveSrcSliceWindow(Sequence{}, True); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread); - - // LDS double buffer: store last data to LDS - blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, - p_out_block_double + out_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_out_block_double + out_block_space, - p_in_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread); - } - } - - { -#if 1 // debug - // input: register to global memory, atomic add - constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) - ? 
InMemoryDataOperation::Set - : InMemoryDataOperation::AtomicAdd; -#else - constexpr auto in_memory_op = InMemoryDataOperation::AtomicAdd; -#endif - - constexpr index_t E1 = GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t E0 = E / E1; - - constexpr index_t B1 = GemmNLevel0Cluster * GemmNLevel1Cluster; - constexpr index_t B0 = B / B1; - - // define input tensor descriptor for threadwise copy - // thread input tensor, src of threadwise copy - constexpr auto in_e0_e1_c0_b0_b1_n0_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - // global input tensor, dst of threadwise copy - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Pad, LeftPads, RightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr auto in_n0_n1_c0_c1_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(UnMerge>{}, - UnMerge>{}, - Embed, - Sequence>{}, - Embed, - Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); - - constexpr auto in_e_c0_b_n0_global_desc = transform_tensor_descriptor( - in_n0_n1_c0_c1_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, - PassThrough{}, - Merge>{}, - PassThrough{}), - make_tuple(Sequence<3, 4, 6>{}, Sequence<2>{}, Sequence<1, 5, 7>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - constexpr auto in_e0_e1_c0_b0_b1_n0_global_desc = transform_tensor_descriptor( - in_e_c0_b_n0_global_desc, - make_tuple(UnMerge>{}, - PassThrough{}, - UnMerge>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - // calculate origin of thread input tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t e_thread_data_on_global = - e_block_data_on_global + c_thread_mtx_on_block.row / GemmMPerThread; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / GemmNPerThread; - - ThreadwiseGenericTensorSliceCopy_v4r2< - decltype(in_e0_e1_c0_b0_b1_n0_thread_desc), - decltype(in_e0_e1_c0_b0_b1_n0_global_desc), - decltype(in_e0_e1_c0_b0_b1_n0_thread_desc.GetLengths()), - Sequence<0, 1, 2, 3, 4, 5>, - 4, - 1, - InThreadCopyDstDataPerWrite_B, - AddressSpace::Vgpr, - AddressSpace::Global, - in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0), - make_multi_index(e_thread_data_on_global / E1, - e_thread_data_on_global % E1, - 0, - b_thread_data_on_global / B1, - b_thread_data_on_global % B1, - 0)) - .Run(p_in_thread, p_in_global); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index e9266ca220..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,418 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP 
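Before the v4r1 kernel below, a note on the Set/AtomicAdd choice made by the two kernels above: in backward data, several filter taps can scatter into the same input pixel, so a plain store (InMemoryDataOperation::Set) is only safe when the stride covers the dilated filter footprint. v1r1 checks ConvStride >= ConvDilation * (FilterSize - 1) + 1 per dimension; v1r2 uses the more conservative Y <= ConvStrideH && X <= ConvStrideW. The v4r1/v5r1 kernels sidestep atomics entirely by splitting the problem into YTilda * XTilda independent GEMMs. A minimal standalone sketch of both rules (hypothetical helper names, not code from this patch):

    #include <cassert>
    #include <numeric>

    // A plain store is safe iff no two filter taps write the same input
    // pixel, i.e. the stride covers the dilated filter footprint.
    constexpr bool can_use_plain_store(int stride, int dilation, int filter_size)
    {
        return stride >= dilation * (filter_size - 1) + 1;
    }

    // Per-dimension GEMM count used by the v4r1/v5r1 split.
    constexpr int tilda(int stride, int dilation)
    {
        return stride / std::gcd(stride, dilation);
    }

    int main()
    {
        assert(!can_use_plain_store(1, 1, 3));  // 3x3 filter, stride 1: taps overlap -> AtomicAdd
        assert(can_use_plain_store(2, 1, 2));   // stride 2 covers a 2-tap filter -> Set
        assert(tilda(2, 1) * tilda(2, 1) == 4); // 2x2 stride, dilation 1 -> 4 independent GEMMs
    }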
-#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// Number of GEMMs: YTilda * XTilda -// GemmM = C -// GemmN = N * HTildaSlice * WTildaSlice -// GemmK = K * YDotSlice * XDotSlice -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw -{ - __host__ __device__ static constexpr index_t GetNumberOfGemm() - { - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - return YTilda * XTilda; - } - - __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda) - { - constexpr index_t N = InGlobalDesc::GetLengths()[0]; - constexpr index_t C = InGlobalDesc::GetLengths()[1]; - constexpr index_t Hi = InGlobalDesc::GetLengths()[2]; - constexpr index_t Wi = InGlobalDesc::GetLengths()[3]; - - constexpr index_t K = OutGlobalDesc::GetLengths()[1]; - constexpr index_t Ho = OutGlobalDesc::GetLengths()[2]; - constexpr index_t Wo = OutGlobalDesc::GetLengths()[3]; - - constexpr index_t Y = WeiGlobalDesc::GetLengths()[2]; - constexpr index_t X = WeiGlobalDesc::GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - constexpr index_t iHTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t iWTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t iHTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t iWTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; - constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; - - // GemmM and GemmN - constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - // GemmK is different for each GEMM - index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); - index_t XDotSlice = 
math::integer_divide_ceil(X - iXTilda, XTilda); - - index_t GemmK = K * YDotSlice * XDotSlice; - - return Array{GemmM, GemmN, GemmK}; - } - - __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) - { - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - index_t iYTilda = gemm_id / XTilda; - index_t iXTilda = gemm_id % XTilda; - - return GetGemmSizeImpl(iYTilda, iXTilda); - } - - template - __device__ static void RunImpl(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) - { - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); - constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - constexpr index_t iHTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t iWTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t iHTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t iWTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; - constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; - - // A matrix: weight - // weight out-of-bound check can be skipped - constexpr bool wei_skip_out_of_bound_check = true; - - constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor( - 
wei_k_c_y_x_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence, - wei_skip_out_of_bound_check>{}, - Embed, - Sequence, - wei_skip_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto wei_k_c_ydotslice_xdotslice_global_desc = transform_tensor_descriptor( - wei_k_c_ydot_ytilda_xdot_xtilda_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}, - Freeze, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<>{})); - - constexpr auto wei_gemmk_gemmm_global_desc = transform_tensor_descriptor( - wei_k_c_ydotslice_xdotslice_global_desc, - make_tuple(Merge>{}, PassThrough{}), - make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - -// B matrix: output tensor -// TODO sometimes output tensor out-of-bound check can be skipped, find out all such -// situations -#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK - constexpr bool out_skip_out_of_bound_check = false; -#else - constexpr bool out_skip_out_of_bound_check = true; -#endif - - constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, - out_skip_out_of_bound_check>{}, - Embed, - Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, - out_skip_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htilda_xdot_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, - Sequence, - Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<3, 5>{})); - - constexpr auto out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc = - transform_tensor_descriptor( - out_n_k_ydot_htildaslice_xdot_wtildaslice_global_desc, - make_tuple( - PassThrough{}, - PassThrough{}, - PassThrough{}, - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<2, 4>{})); - - constexpr auto out_gemmk_gemmn_global_desc = transform_tensor_descriptor( - out_n_k_ydotslice_htildaslice_xdotslice_wtildaslice_global_desc, - make_tuple(Merge>{}, - Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - -// C matrix: input tensor -// TODO sometimes input out-of-bound check can be skipped, find out all such situations -#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK - constexpr bool in_skip_out_of_bound_check = false; -#else - constexpr bool in_skip_out_of_bound_check = true; -#endif - - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, - 
PassThrough{}, - Pad, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n_c_ytilda_htilda_xtilda_wtilda_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, - Sequence, - in_skip_out_of_bound_check>{}, - Embed, - Sequence, - in_skip_out_of_bound_check>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_n_c_htildaslice_wtildaslice_global_desc = transform_tensor_descriptor( - in_n_c_ytilda_htilda_xtilda_wtilda_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Freeze, Sequence>{}, - Slice, - Sequence, - Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 4>{}, Sequence<3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<>{}, Sequence<2, 3>{})); - - constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_htildaslice_wtildaslice_global_desc, - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1, - Sequence<0, 1>, - 1, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<0, 1, 2, 3>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); - } - - template - __device__ static void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global, - Number) - { - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t iYTilda = GemmId / XTilda; - constexpr index_t iXTilda = GemmId % XTilda; - - static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! 
iYtilda, iXtilda"); - - RunImpl(p_in_global, p_wei_global, p_out_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index e47f2fce01..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,406 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP -#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V5R1_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// Number of GEMMs = YTilda * XTilda -// GemmM = C -// GemmN = N * HTildaSlice * WTildaSlice -// GemmK0 = YDotSlice -// GemmK1 = XDotSlice -// GemmK2 = K -template -struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk -{ - __host__ __device__ static constexpr index_t GetNumberOfGemm() - { - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - return YTilda * XTilda; - } - - __host__ __device__ static constexpr auto GetGemmSizeImpl(index_t iYTilda, index_t iXTilda) - { - constexpr index_t N = InGlobalDesc::GetLengths()[0]; - constexpr index_t Hi = InGlobalDesc::GetLengths()[1]; - constexpr index_t Wi = InGlobalDesc::GetLengths()[2]; - constexpr index_t C = InGlobalDesc::GetLengths()[3]; - - constexpr index_t Ho = OutGlobalDesc::GetLengths()[1]; - constexpr index_t Wo = OutGlobalDesc::GetLengths()[2]; - constexpr index_t K = OutGlobalDesc::GetLengths()[3]; - - constexpr index_t Y = WeiGlobalDesc::GetLengths()[1]; - constexpr index_t X = WeiGlobalDesc::GetLengths()[2]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - constexpr index_t iHTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t iWTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 
1)), ConvStrides{}[1]); - - constexpr index_t iHTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t iWTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; - constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; - - // GemmM and GemmN - constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - // GemmK is different for each GEMM - index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); - index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); - - index_t GemmK0 = YDotSlice; - index_t GemmK1 = XDotSlice; - index_t GemmK2 = K; - - return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2); - } - - __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) - { - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - index_t iYTilda = gemm_id / XTilda; - index_t iXTilda = gemm_id % XTilda; - - return GetGemmSizeImpl(iYTilda, iXTilda); - } - - template - __device__ static void RunImpl(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global) - { - constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{}; - constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[0]; - constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[1]; - constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[2]; - constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[3]; - - constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[1]; - constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[2]; - constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[1]; - constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[2]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); - constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); - - constexpr index_t HTilda = - Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = - Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - // only work on HTilda and WTilda that contribute to non-padding area of input tensor - constexpr index_t iHTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t 
iWTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t iHTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t iWTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = iHTildaRight - iHTildaLeft; - constexpr index_t WTildaSlice = iWTildaRight - iWTildaLeft; - - // A matrix: weight - // weight out-of-bound check can be skipped - constexpr bool wei_skip_out_of_bound_check = true; - - constexpr auto wei_k_ydot_ytilda_xdot_xtilda_c_global_desc = transform_tensor_descriptor( - wei_k_y_x_c_global_desc, - make_tuple(PassThrough{}, - Embed, - Sequence, - wei_skip_out_of_bound_check>{}, - Embed, - Sequence, - wei_skip_out_of_bound_check>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - constexpr auto wei_k_ydotslice_xdotslice_c_global_desc = transform_tensor_descriptor( - wei_k_ydot_ytilda_xdot_xtilda_c_global_desc, - make_tuple( - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}, - Freeze, Sequence>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<>{}, Sequence<3>{})); - - constexpr auto wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc = - reorder_tensor_descriptor_given_lower2upper(wei_k_ydotslice_xdotslice_c_global_desc, - Sequence<2, 0, 1, 3>{}); - -// B matrix: output tensor -// TODO sometimes output tensor out-of-bound check can be skipped, find out all such -// situations -#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK - constexpr bool out_skip_out_of_bound_check = false; -#else - constexpr bool out_skip_out_of_bound_check = true; -#endif - - constexpr auto out_n_ydot_htilda_xdot_wtilda_k_global_desc = transform_tensor_descriptor( - out_n_ho_wo_k_global_desc, - make_tuple(PassThrough{}, - Embed, - Sequence<-ConvDilationH / GcdStrideDilationH, 1, 0>, - out_skip_out_of_bound_check>{}, - Embed, - Sequence<-ConvDilationW / GcdStrideDilationW, 1, 0>, - out_skip_out_of_bound_check>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - constexpr auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc = - transform_tensor_descriptor( - out_n_ydot_htilda_xdot_wtilda_k_global_desc, - make_tuple( - PassThrough{}, - Slice, Sequence<0, 0>, Sequence>{}, - Slice, - Sequence, - Sequence>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{})); - - constexpr auto out_gemmk0_gemmk1_gemmk2_gemmn_global_desc = transform_tensor_descriptor( - out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - Merge>{}), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - -// C matrix: input tensor -// TODO sometimes input out-of-bound check can be skipped, find out all such situations -#if !CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK - 
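// Worked example for the sizes used below (illustrative numbers, not from the
// original file): HTilda = Ho + ceil(ConvDilationH * (Y - 1) / ConvStrideH) is
// the height of the enlarged "tilde" domain that can touch the input. For
// Ho = 28, Y = 3, dilation 1, stride 1: HTilda = 28 + ceil(2 / 1) = 30. The
// iHTildaLeft / iHTildaRight clamps then drop rows that fall entirely inside
// the padding, so the GEMM only writes HTildaSlice = iHTildaRight - iHTildaLeft
// rows of real input.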
constexpr bool in_skip_out_of_bound_check = false; -#else - constexpr bool in_skip_out_of_bound_check = true; -#endif - - constexpr auto in_n_hip_wip_c_global_desc = transform_tensor_descriptor( - in_n_hi_wi_c_global_desc, - make_tuple(PassThrough{}, - Pad, InLeftPads, InRightPads, in_skip_out_of_bound_check>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[1]; - constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[2]; - - constexpr auto in_n_ytilda_htilda_xtilda_wtilda_c_global_desc = transform_tensor_descriptor( - in_n_hip_wip_c_global_desc, - make_tuple(PassThrough{}, - Embed, - Sequence, - in_skip_out_of_bound_check>{}, - Embed, - Sequence, - in_skip_out_of_bound_check>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - constexpr auto in_n_htildaslice_wtildaslice_c_global_desc = transform_tensor_descriptor( - in_n_ytilda_htilda_xtilda_wtilda_c_global_desc, - make_tuple(PassThrough{}, - Freeze, Sequence>{}, - Slice, - Sequence, - Sequence>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor( - in_n_htildaslice_wtildaslice_c_global_desc, - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // call GEMM - constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalC_v2< - GridSize, - BlockSize, - Float, - AccFloat, - decltype(wei_gemmk0_gemmk1_gemmk2_gemmm_global_desc), - decltype(out_gemmk0_gemmk1_gemmk2_gemmn_global_desc), - decltype(in_gemmm_gemmn_global_desc), - InMemoryDataOperation::Set, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - ThreadGemmDataPerRead_GemmM, - ThreadGemmDataPerRead_GemmN, - GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM, - Sequence<0, 1, 2, 3>, - Sequence<0, 1, 2, 3>, - 3, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN, - Sequence<0, 1, 3, 2>, - Sequence<0, 1, 3, 2>, - 2, - GemmBBlockCopySrcDataPerRead_GemmK2, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_out_global, p_in_global); - } - - template - __device__ static void Run(Float* __restrict__ p_in_global, - const Float* __restrict__ p_wei_global, - const Float* __restrict__ p_out_global, - Number) - { - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - 
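// Host-side dispatch sketch (hypothetical driver loop, not part of this file):
// the YTilda * XTilda GEMMs are independent and some can be empty, so a driver
// would typically unroll over compile-time gemm ids, e.g.
//     static_for<0, GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
//         const auto sz = GetGemmSize(gemm_id); // (GemmM, GemmN, GemmK0, GemmK1, GemmK2)
//         if(sz[2] != 0 && sz[3] != 0)          // skip empty YDotSlice/XDotSlice
//             /* launch Run(p_in, p_wei, p_out, gemm_id) */;
//     });
// Each gemm_id selects one (iYTilda, iXTilda) phase, decomposed below as
// iYTilda = GemmId / XTilda and iXTilda = GemmId % XTilda; the phases write
// disjoint input pixels, which is why no atomic add is needed here.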
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t iYTilda = GemmId / XTilda; - constexpr index_t iXTilda = GemmId % XTilda; - - static_assert(iYTilda < YTilda && iXTilda < XTilda, "wrong! iYtilda, iXtilda"); - - RunImpl(p_in_global, p_wei_global, p_out_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp deleted file mode 100644 index d270a24467..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ /dev/null @@ -1,454 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP -#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -template -struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto True = integral_constant{}; - - // this is a mess - // TODO: find more elegent way of specifying (or calculating) performance parameters - constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThread; - - static_assert( - (N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, - "wrong!"); - - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - static_assert(N % (N1 * N2) == 0, "wrong! cannot divice N evenly among thread"); - - constexpr index_t N0 = N / (N1 * N2); - - constexpr index_t B = N0 * Ho * Wo; - - constexpr index_t E = C * Y * X; - - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopySrcDataPerRead_B == 1)) && - (X == 1 || ConvDilationW % InBlockCopySrcDataPerRead_B == 0), - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); - - // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t KBlockWork = K / KPerBlock; - constexpr index_t BBlockWork = B / BPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t k_block_data_on_global = block_work_id[I0] * KPerBlock; - const index_t b_block_data_on_global = block_work_id[I1] * BPerBlock; - - // input tensor - // global tensor in global memory - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple( - PassThrough{}, PassThrough{}, Pad, LeftPads, RightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(UnMerge>{}, - PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); - - // global tensor in global memory, src of blockwise copy - constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor( - in_n0_n1_n2_c_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, - PassThrough{}, - Merge>{}, - PassThrough{}), - make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - // block tensor in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, - "GemmDataPerReadB alignment requirement is not satisfied"); - - // input tensor blockwise copy - auto blockwise_in_copy = - BlockwiseGenericTensorSliceCopy_v4( - make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0)); - - // weight tensor - // global tensor in global memory, src of blockwise copy - // It is constructed differently, depending on whether forward or backward weight - // convolution - constexpr auto wei_e_k_global_desc = - transform_tensor_descriptor(unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3), - make_tuple(Merge>{}, PassThrough{}), - make_tuple(Sequence<1, 2>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // block tensor in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, - Number{}); - - // this check is ad-hoc - // TODO: need to properly implement tensor descriptor with multiple alignment - // requirements - static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0, - "GemmDataPerReadA alignment requirement is not satisfied"); - - // weight tensor blockwise copy - auto blockwise_wei_copy = - 
BlockwiseGenericTensorSliceCopy_v4( - make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[EPerBlock, KPerBlock] is in LDS - // b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS - // c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in - // registers - constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc); - - constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor( - in_e_n1_b_n2_block_desc.GetLength(I0), - in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) * - in_e_n1_b_n2_block_desc.GetLength(I3), - in_e_n1_b_n2_block_desc.GetStride(I0)); - - // sanity check - static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = - KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO: find a more elegant way of defining c_thread_mtx - constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_e_k_block_mtx_desc), - decltype(b_e_n1bn2_block_mtx_desc), - decltype(c_k0k1_n1n2_thread_mtx_desc), - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmDataPerReadA, - GemmDataPerReadB>{}; - - // LDS allocation for input and weight: be careful of alignment - constexpr index_t max_align = math::lcm(InBlockCopyDstDataPerWrite_N2, - WeiBlockCopyDstDataPerWrite_K, - GemmDataPerReadA, - GemmDataPerReadB); - - constexpr index_t in_block_space = - math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align); - - constexpr index_t wei_block_space = - math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); - - __shared__ Float p_in_block_double[2 * in_block_space]; - __shared__ Float p_wei_block_double[2 * wei_block_space]; - - // register allocation for output - AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread); - - // LDS double buffer: preload data into LDS - { - blockwise_in_copy.Run(p_in_global, p_in_block_double); - blockwise_wei_copy.Run(p_wei_global, p_wei_block_double); - } - - constexpr auto in_block_slice_copy_steps = Sequence{}; - constexpr auto wei_block_slice_copy_steps = Sequence{}; - - // LDS double buffer: main body - for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E; - e_block_data_begin += 2 * EPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_in_block_now = - even_loop ? p_in_block_double : p_in_block_double + in_block_space; - Float* p_wei_block_now = - even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space; - - Float* p_in_block_next = - even_loop ? p_in_block_double + in_block_space : p_in_block_double; - Float* p_wei_block_next = - even_loop ? 
p_wei_block_double + wei_block_space : p_wei_block_double; - - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True); - blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True); - - __syncthreads(); - - // LDS double buffer: load next data from device mem - blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread); - - // LDS double buffer: store next data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0); - - if(has_two_iteration_left) // two iterations left - { - Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()]; - Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()]; - - blockwise_in_copy.MoveSrcSliceWindow(in_block_slice_copy_steps, True); - blockwise_wei_copy.MoveSrcSliceWindow(wei_block_slice_copy_steps, True); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer); - blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - - // LDS double buffer: store last data to LDS - blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, - p_in_block_double + in_block_space); - blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, - p_wei_block_double + wei_block_space); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double + wei_block_space, - p_in_block_double + in_block_space, - p_out_thread); - } - else // one iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread); - } - } - - // copy output: register to global memory - { - constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t K0 = K / K1; - - // define output tensor descriptor for threadwise copy - // thread output tensor, src of threadwise copy - constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - // global output tensor - constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor( - out_n_k_ho_wo_global_desc, - make_tuple(UnMerge>{}, - UnMerge>{}, - PassThrough{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{})); - - // global output tensor, dst of threadwise copy - constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor( - out_n0_n1_n2_k0_k1_ho_wo_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - PassThrough{}, - Merge>{}, - PassThrough{}), - make_tuple(Sequence<3>{}, - Sequence<4>{}, - Sequence<1>{}, - Sequence<0, 5, 6>{}, - Sequence<2>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); - 
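The main loop above deliberately stops at e_block_data_begin + 2 * EPerBlock < E, so the tail always has either two E-slices left (when E is a multiple of 2 * EPerBlock) or exactly one. A minimal host-side replay of that schedule, with invented example sizes and none of the kernel's real API, which checks that every E-slice is GEMMed exactly once:

#include <cassert>
#include <cstdio>

// Replay of the LDS double-buffer schedule (simplified, names invented).
// buf 0/1 model the two halves of p_in_block_double / p_wei_block_double.
int main()
{
    const int E = 96, EPerBlock = 16; // 6 E-slices; E % (2 * EPerBlock) == 0 here
    int gemms_done = 0;
    int loaded = 0, buf = 0; // slice 0 is preloaded into buffer 0

    for(int e = 0; e + 2 * EPerBlock < E; e += 2 * EPerBlock)
    {
        for(int i = 0; i < 2; ++i) // the unroll-by-2 body
        {
            std::printf("load slice %d into buf %d, GEMM slice %d in buf %d\n",
                        loaded + 1, buf ^ 1, loaded, buf);
            ++gemms_done;
            ++loaded;
            buf ^= 1;
        }
    }

    if(E % (2 * EPerBlock) == 0) // two slices left
    {
        std::printf("load slice %d, GEMM slice %d\n", loaded + 1, loaded);
        ++gemms_done;
        ++loaded;
        std::printf("GEMM slice %d (last)\n", loaded);
        ++gemms_done;
    }
    else // one slice left
    {
        std::printf("GEMM slice %d (last)\n", loaded);
        ++gemms_done;
    }

    assert(gemms_done == E / EPerBlock); // every E-slice consumed exactly once
    return 0;
}

One plausible reason for unrolling the body by two: even_loop is a constant in each unrolled copy, so the "current" and "next" buffer pointers are compile-time-selectable and no buffer index has to live in a register across the loop.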
- // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t k_thread_data_on_global = - k_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t b_thread_data_on_global = - b_block_data_on_global + c_thread_mtx_on_block.col / N2; - - ThreadwiseGenericTensorSliceCopy_v4r2< - decltype(out_k0_k1_n1_b_n2_thread_desc), - decltype(out_k0_k1_n1_b_n2_global_desc), - decltype(out_k0_k1_n1_b_n2_thread_desc.GetLengths()), - arithmetic_sequence_gen<0, 5, 1>::type, - 3, - 1, - 1, - AddressSpace::Vgpr, - AddressSpace::Global, - InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0), - make_multi_index(k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - 0, - b_thread_data_on_global, - 0)) - .Run(p_out_thread, p_out_global); - } - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index b8090321a9..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,171 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP -#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{}; - constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{}; - constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_c_hi_wi_global_desc.GetLengths()[0]; - constexpr index_t C = in_n_c_hi_wi_global_desc.GetLengths()[1]; - constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2]; - constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3]; - - constexpr index_t K = out_n_k_ho_wo_global_desc.GetLengths()[1]; - constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2]; - constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3]; - - constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2]; - constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - -#if 0 - // sanity-check for vectorized memory load - static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) && - (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) && - InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 && - InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0, - "wrong! 
aligment requirement for vectorized global load of input tensor will " - "be violated"); -#endif - - // weight tensor - constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower( - unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{}); - - // input tensor - constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor( - in_n_c_hi_wi_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Pad, InLeftPads, InRightPads>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); - - constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2]; - constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3]; - - constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor( - in_n_c_hip_wip_global_desc, - make_tuple(PassThrough{}, - PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); - - constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( - in_n_c_y_ho_x_wo_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - constexpr auto out_gemmm_gemmn_global_desc = - transform_tensor_descriptor(unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3), - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // GEMM - constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1, - Sequence<1, 0>, - 0, - GemmABlockCopySrcDataPerRead_GemmK, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadCopyDstDataPerWrite_GemmN1>{}; - - gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index ac3e35f2db..0000000000 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,162 +0,0 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP -#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm.hpp" - -namespace ck { - -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template -struct GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk -{ - __device__ void Run(const Float* const __restrict__ p_in_global, - const Float* const __restrict__ p_wei_global, - Float* const __restrict__ p_out_global) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{}; - constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{}; - 
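Both v4r4 kernels view the convolution as a single GEMM with GemmM = K, GemmN = N * Ho * Wo and GemmK = C * Y * X; only the descriptor transforms differ between the NCHW and NHWC layouts. A self-contained host helper (hypothetical, not part of the repo) that computes the GEMM problem size implied by this mapping, using the standard output-size formula:

#include <cstdio>

struct GemmSize { long M, N, K; };

// GEMM problem size implied by the v4r4 mapping (GemmM = K,
// GemmN = N*Ho*Wo, GemmK = C*Y*X), with Ho/Wo derived from
// pads, strides and dilations.
GemmSize implicit_gemm_size(long n, long c, long hi, long wi, // input NCHW
                            long k, long y, long x,           // weight KCYX
                            long stride_h, long stride_w,
                            long dilation_h, long dilation_w,
                            long pad_h, long pad_w)
{
    const long ho = (hi + 2 * pad_h - dilation_h * (y - 1) - 1) / stride_h + 1;
    const long wo = (wi + 2 * pad_w - dilation_w * (x - 1) - 1) / stride_w + 1;
    return {k, n * ho * wo, c * y * x};
}

int main()
{
    // example: a 3x3, stride-1, pad-1, dilation-1 layer
    const GemmSize g = implicit_gemm_size(256, 128, 28, 28, 128, 3, 3, 1, 1, 1, 1, 1, 1);
    std::printf("GemmM=%ld GemmN=%ld GemmK=%ld\n", g.M, g.N, g.K); // 128, 200704, 1152
    return 0;
}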
constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{}; - - constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[I0]; - constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[I1]; - constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[I2]; - constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[I3]; - - constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[I3]; - constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[I1]; - constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[I2]; - - constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[I1]; - constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[I2]; - - constexpr index_t ConvStrideH = ConvStrides{}[I0]; - constexpr index_t ConvStrideW = ConvStrides{}[I1]; - - constexpr index_t ConvDilationH = ConvDilations{}[I0]; - constexpr index_t ConvDilationW = ConvDilations{}[I1]; - - // weight tensor - constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower( - unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{}); - - // input tensor - constexpr auto in_n_hip_wip_c_global_desc = - transform_tensor_descriptor(in_n_hi_wi_c_global_desc, - make_tuple(PassThrough{}, - Pad, InLeftPads, InRightPads>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1]; - constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2]; - - constexpr auto in_n_y_ho_x_wo_c_global_desc = transform_tensor_descriptor( - in_n_hip_wip_c_global_desc, - make_tuple(PassThrough{}, - Embed, Sequence>{}, - Embed, Sequence>{}, - PassThrough{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_global_desc, - make_tuple(Merge>{}, Merge>{}), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // output tensor - constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( - unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2), - make_tuple(PassThrough{}, Merge>{}), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - // GEMM - constexpr auto gridwise_gemm = - GridwiseGemmTransposedANormalBNormalC_v1, - Sequence<1, 0>, - 0, - GemmABlockCopySrcDataPerRead_GemmK, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmBBlockCopySrcDataPerRead_GemmK, - GemmBBlockCopyDstDataPerWrite_GemmN, - Sequence<2, 3, 0, 1>, - 1, - GemmCThreadCopyDstDataPerWrite_GemmM1>{}; - - gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp b/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp deleted file mode 100644 index 48f1b718f1..0000000000 --- a/composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp +++ /dev/null @@ -1,80 +0,0 @@ -#ifndef CK_CONSTANT_MATRIX_DESCRIPTOR_HPP -#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" - -namespace ck { - -template -struct 
ConstantMatrixDescriptor -{ - __host__ __device__ constexpr ConstantMatrixDescriptor() - { - static_assert(NCol_ <= RowStride_, "wrong! NCol > RowStride!"); - } - - __host__ __device__ static constexpr index_t NRow() { return NRow_; } - - __host__ __device__ static constexpr index_t NCol() { return NCol_; } - - __host__ __device__ static constexpr index_t RowStride() { return RowStride_; } - - __host__ __device__ static constexpr auto GetLengths() { return Sequence{}; } - - __host__ __device__ static constexpr index_t GetElementSize() { return NRow_ * NCol_; } - - __host__ __device__ static constexpr index_t GetElementSpace() { return NRow_ * RowStride_; } - - __host__ __device__ static index_t GetOffsetFromMultiIndex(index_t irow, index_t icol) - { - return irow * RowStride_ + icol; - } - - __host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol) - { - return GetOffsetFromMultiIndex(irow, icol); - } - - template - __host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number, - Number) - { - return ConstantMatrixDescriptor{}; - } -}; - -template -__host__ __device__ constexpr auto make_ConstantMatrixDescriptor_packed(Number, Number) -{ - return ConstantMatrixDescriptor{}; -} - -template -__host__ __device__ constexpr auto - make_ConstantMatrixDescriptor(Number, Number, Number) -{ - return ConstantMatrixDescriptor{}; -} - -template -__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(NativeTensorDescriptor) -{ - using TDesc = NativeTensorDescriptor; - static_assert(TDesc::GetNumOfDimension() == 2, "wrong"); - static_assert(TDesc::GetStrides()[1] == 1, "wrong"); - return ConstantMatrixDescriptor{}; -} - -template -__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s) -{ - printf( - "%s NRow %u NCol %u RowStride %u\n", s, TDesc::NRow(), TDesc::NCol(), TDesc::RowStride()); -} - -} // namespace ck - -#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/composable_kernel/include/tensor_description/cluster_descriptor.hpp index 7793dc242a..c3523623d9 100644 --- a/composable_kernel/include/tensor_description/cluster_descriptor.hpp +++ b/composable_kernel/include/tensor_description/cluster_descriptor.hpp @@ -2,50 +2,10 @@ #define CK_CLUSTER_DESCRIPTOR_HPP #include "common_header.hpp" - -// TODO remove dependency on deprecated tensor descriptor -#include "tensor_descriptor.hpp" #include "tensor_adaptor.hpp" namespace ck { -// a cluster map 1d index to N-d index -template -struct ClusterDescriptor -{ - static constexpr index_t nDim = Lengths::Size(); - - static constexpr auto mDesc = transform_tensor_descriptor( - make_native_tensor_descriptor_packed(Lengths{}), - make_tuple(Merge{}), - make_tuple(ArrangeOrder{}), - make_tuple(Sequence<0>{})); - - __host__ __device__ constexpr ClusterDescriptor() - { - static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim, - "wrong! size not the same"); - - static_assert(is_valid_sequence_map{}, "wrong! 
ArrangeOrder is wrong"); - } - - __host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); } - - __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d) - { - return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d}); - } -}; - -template ::type> -__host__ __device__ constexpr auto make_cluster_descriptor( - Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) -{ - return ClusterDescriptor{}; -} - -#if 1 template ::type> __host__ __device__ constexpr auto make_cluster_descriptor_v2( @@ -68,7 +28,6 @@ __host__ __device__ constexpr auto make_cluster_descriptor_v2( return make_single_stage_tensor_adaptor( make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); } -#endif } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/dimension.hpp b/composable_kernel/include/tensor_description/dimension.hpp deleted file mode 100644 index 566895b9a4..0000000000 --- a/composable_kernel/include/tensor_description/dimension.hpp +++ /dev/null @@ -1,17 +0,0 @@ -#ifndef CK_DIMENSION_HPP -#define CK_DIMENSION_HPP - -#include "common_header.hpp" - -namespace ck { - -template -struct NativeDimension -{ - __host__ __device__ static constexpr auto GetLength() { return Number{}; } - - __host__ __device__ static constexpr auto GetStride() { return Number{}; } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp deleted file mode 100644 index d4f23b8459..0000000000 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ /dev/null @@ -1,523 +0,0 @@ -#ifndef CK_MULTI_INDEX_TRANSFORM_HPP -#define CK_MULTI_INDEX_TRANSFORM_HPP - -#include "common_header.hpp" -#include "multi_index.hpp" - -namespace ck { - -template -struct PassThrough -{ - using LowerIndex = MultiIndex<1>; - using UpperIndex = MultiIndex<1>; - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence{}; } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - return idx_up; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return idx_up_diff; - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is -// necessary -// However, the check will be skipped if SkipIsValidCheck is set to true by user -// LowerLengths: Sequence<...> -template -struct Pad -{ - static constexpr index_t nDim = LowerLengths::Size(); - - using LowerIndex = MultiIndex; - using UpperIndex = MultiIndex; - - __host__ __device__ constexpr Pad() - { - static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim && - RightPads::GetSize() == nDim, - "wrong! 
# of dimensions not consistent"); - } - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() - { - return LowerLengths{} + LeftPads{} + RightPads{}; - } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - return idx_up - LeftPads{}; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return idx_up_diff; - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - // skip valid check if user request it - if(SkipIsValidCheck) - { - return true; - } - - bool flag = true; - - for(index_t i = 0; i < nDim; ++i) - { - flag = flag && LeftPads::At(i) == 0 && RightPads::At(i) == 0; - } - - return flag; - } -}; - -// LowerLengths: Sequence<...> -// SliceBegins: Sequence<...> -// SliceEnds: Sequence<...> -template -struct Slice -{ - static constexpr index_t nDim = LowerLengths::Size(); - - using LowerIndex = MultiIndex; - using UpperIndex = MultiIndex; - - __host__ __device__ constexpr Slice() - { - static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim && - SliceEnds::GetSize() == nDim, - "wrong! # of dimensions not consistent"); - -#if 0 - // TODO: would not compile, error on constexpr - static_for<0, nDim, 1>{}([&](auto idim) { - static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) && - SliceBegins::At(idim) >= 0 && - SliceEnds::At(idim) <= LowerLengths::At(idim), - "wrong! 
Slice config is wrong"); - }); -#endif - } - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() - { - return SliceEnds{} - SliceBegins{}; - } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - return idx_up + SliceBegins{}; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return idx_up_diff; - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -// LowerLengths: Sequence<...> -template -struct Merge -{ - static constexpr index_t nDimLow = LowerLengths::Size(); - static constexpr index_t nDimUp = 1; - - using LowerIndex = MultiIndex; - using UpperIndex = MultiIndex; - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() - { - return Sequence{}, Number<1>{})>{}; - } - - // emulate constexpr lambda - template - struct lambda_CalculateLowerIndex - { - index_t& itmp; - LowerIndex& idx_low; - - __host__ __device__ constexpr lambda_CalculateLowerIndex(index_t& itmp_, - LowerIndex& idx_low_) - : itmp(itmp_), idx_low(idx_low_) - { - } - - template - __host__ __device__ constexpr void operator()(IDim idim) const - { - constexpr index_t stride = PseudoLowStrides::At(idim); - idx_low(idim) = itmp / stride; - itmp -= idx_low[idim] * stride; - } - }; - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - LowerIndex idx_low; - - index_t itmp = idx_up[Number<0>{}]; - - constexpr auto pseudo_low_strides = - reverse_inclusive_scan_sequence( - LowerLengths::PopFront(), math::multiplies{}, Number<1>{}) - .PushBack(Number<1>{}); - - static_for<0, nDimLow - 1, 1>{}( - lambda_CalculateLowerIndex(itmp, idx_low)); - - idx_low(Number{}) = itmp / pseudo_low_strides[Number{}]; - - return idx_low; - } - - // idx_low_diff depends on idx_low_old, so idx_low need to be up-to-date - // If idx_up_diff is known at compile-time, many calculations can be optimized - // away by compiler - // This function assume idx_low_old is not out-of-bound - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& idx_low_old) - { - if(idx_up_diff[Number<0>{}] == 0) - { - return make_zero_multi_index(); - } - else - { - // CalculateLowerIndex(idx_up_diff) has multiple integer divisions. - // If idx_up_diff is known at compile-time, the calculation can - // be done at compile-time. However, if idx_up_diff is only known - // at run-time, then the calculation will also be computed at - // run-time, and can be very expensive. 
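Concretely, the lowering this comment worries about is a digit decomposition by "pseudo strides" (suffix products of the merged lengths): one integer divide plus one multiply-subtract per lower dimension. A stand-alone host sketch of the same arithmetic (names invented):

#include <array>
#include <cassert>

// What Merge::CalculateLowerIndex computes, modeled on the host: split a flat
// index over merged dimensions into a multi-index via pseudo strides.
template <std::size_t N>
std::array<long, N> decompose(long flat, const std::array<long, N>& len)
{
    std::array<long, N> stride{}; // stride[i] = len[i+1] * ... * len[N-1]
    stride[N - 1] = 1;
    for(std::size_t i = N - 1; i > 0; --i)
        stride[i - 1] = stride[i] * len[i];

    std::array<long, N> idx{};
    for(std::size_t i = 0; i < N; ++i)
    {
        idx[i] = flat / stride[i]; // the run-time division the comment refers to
        flat -= idx[i] * stride[i];
    }
    return idx;
}

int main()
{
    // merged lengths {C=4, Y=3, X=5}: 37 = 2*15 + 1*5 + 2 -> (2, 1, 2)
    const auto idx = decompose<3>(37, {4, 3, 5});
    assert(idx[0] == 2 && idx[1] == 1 && idx[2] == 2);
    return 0;
}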
- LowerIndex idx_low_diff_tmp = CalculateLowerIndex(idx_up_diff); - - // find out the last low dimension that changed - index_t last_changed_low_dim = 0; - - static_for<0, nDimLow, 1>{}([&](auto i) { - if(idx_low_diff_tmp[i] != 0) - { - last_changed_low_dim = i; - } - }); - - LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp; - - if(idx_up_diff[Number<0>{}] > 0) - { - // do carry check on each low dimension in reversed order - // starting from the first digit that changed - // don't check the highest dimension - bool carry = false; - - static_for{}([&](auto i) { - if(i <= last_changed_low_dim) - { - if(carry) - { - ++idx_low_new(i); - } - - carry = false; - - if(idx_low_new[i] >= LowerLengths::At(i)) - { - idx_low_new(i) -= LowerLengths::At(i); - carry = true; - } - } - }); - - // highest dimension, no out-of-bound check - if(carry) - { - ++idx_low_new(Number<0>{}); - } - } - else - { - // do borrow check on each low dimension in reversed order - // starting from the first digit that changed - // don't check the highest dimension - bool borrow = false; - - static_for{}([&](auto i) { - if(i <= last_changed_low_dim) - { - if(borrow) - { - --idx_low_new(i); - } - - borrow = false; - - if(idx_low_new[i] < 0) - { - idx_low_new(i) += LowerLengths::At(i); - borrow = true; - } - } - }); - - // highest dimension, no out-of-bound check - if(borrow) - { - --idx_low_new(Number<0>{}); - } - } - - return idx_low_new - idx_low_old; - } - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return false; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -// UpperLengths: Sequence<...> -template -struct UnMerge -{ - static constexpr index_t nDimLow = 1; - static constexpr index_t nDimUp = UpperLengths::Size(); - - using LowerIndex = MultiIndex; - using UpperIndex = MultiIndex; - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - LowerIndex idx_low = make_multi_index(0); - - constexpr auto pseudo_up_strides = - reverse_inclusive_scan_sequence( - UpperLengths::PopFront(), math::multiplies{}, Number<1>{}) - .PushBack(Number<1>{}); - - static_for<0, nDimUp, 1>{}( - [&](auto idim) { idx_low(Number<0>{}) += idx_up[idim] * pseudo_up_strides[idim]; }); - - return idx_low; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return CalculateLowerIndex(idx_up_diff); - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -// By default, will automatically judge if is-valid check for upper-to-lower-index-mapping is -// necessary -// However, the check will be skipped if SkipIsValidCheck is set to true by user -// UpperLengths: Sequence<...> -// Coefficients: Sequence<...> -// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] + coefficients[nDimUp] -template -struct Embed -{ - static constexpr index_t nDimLow = 1; - static constexpr index_t nDimUp = UpperLengths::Size(); - - using LowerIndex = MultiIndex; - 
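The carry/borrow path above exists to avoid those divisions when the step is small: it adds the per-digit diff and propagates carries from the fastest-moving digit upward, like multi-digit addition. A host sketch of the positive-diff case (the deleted code also handles negative diffs with borrows; this sketch assumes each digit grows by less than its length, matching the single subtraction there):

#include <array>
#include <cassert>

// Carry propagation for an index over merged dimensions, modeled on the host.
// Assumes digit[i] + diff[i] + carry < 2 * len[i]; the highest digit is not
// range-checked, as in the deleted Merge::CalculateLowerIndexDiff.
template <std::size_t N>
void add_with_carry(std::array<long, N>& digit,
                    const std::array<long, N>& len,
                    std::array<long, N> diff)
{
    long carry = 0;
    for(std::size_t i = N; i-- > 0;)
    {
        digit[i] += diff[i] + carry;
        carry = 0;
        if(i > 0 && digit[i] >= len[i])
        {
            digit[i] -= len[i];
            carry = 1;
        }
    }
}

int main()
{
    std::array<long, 3> len{4, 3, 5}, idx{2, 1, 4};
    add_with_carry(idx, len, {0, 0, 1}); // innermost digit wraps 4 -> 0, carries into the middle digit
    assert(idx[0] == 2 && idx[1] == 2 && idx[2] == 0);
    return 0;
}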
using UpperIndex = MultiIndex; - - __host__ __device__ constexpr Embed() - { - static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, - "wrong! # of dimensions not consistent"); - } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() { return UpperLengths{}; } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - LowerIndex idx_low = make_multi_index(Coefficients{}[Number{}]); - - static_for<0, nDimUp, 1>{}( - [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * Coefficients{}[i]; }); - - return idx_low; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - LowerIndex idx_low_diff = make_multi_index(0); - - static_for<0, nDimUp, 1>{}( - [&](auto i) { idx_low_diff(Number<0>{}) += idx_up_diff[i] * Coefficients{}[i]; }); - - return idx_low_diff; - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - // skip valid check if user request it - if(SkipIsValidCheck) - { - return true; - } - - bool flag = true; - - index_t ncorner = 1; - - for(index_t idim = 0; idim < nDimUp; ++idim) - { - ncorner *= 2; - } - - // loop over each corner of the upper tensor - for(index_t icorner = 0; icorner < ncorner; ++icorner) - { - // generate upper index for each corner - auto idx_up = make_zero_multi_index(); - - index_t itmp = icorner; - - static_for{}([&](auto idim) { - auto idim_m1 = idim - Number<1>{}; - idx_up(idim_m1) = itmp % 2 == 0 ? 
0 : UpperLengths::At(idim_m1) - 1; - itmp /= 2; - }); - - // calculate lower index - auto idx_low = CalculateLowerIndex(idx_up); - - // judge if lower index is valid - flag = flag && idx_low[Number<0>{}] >= 0 && idx_low[Number<0>{}] < LowerLength; - } - - return flag; - } -}; - -// LowerLengths: Sequence<...> -// LowerFreezePoint: Sequence<...> -template -struct Freeze -{ - static constexpr index_t nDimLow = LowerLengths::Size(); - static constexpr index_t nDimUp = 0; - - using LowerIndex = MultiIndex; - using UpperIndex = MultiIndex; - - __host__ __device__ constexpr Freeze() - { - // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths - } - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<0>{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() { return Sequence<>{}; } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/) - { - return to_multi_index(LowerFreezePoint{}); - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& /* idx_up_diff */, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return make_zero_multi_index(); - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/print_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/print_tensor_descriptor.hpp deleted file mode 100644 index 89174e27b3..0000000000 --- a/composable_kernel/include/tensor_description/print_tensor_descriptor.hpp +++ /dev/null @@ -1,173 +0,0 @@ -#ifndef CK_PRINT_TENSOR_DESCRIPTOR_HPP -#define CK_PRINT_TENSOR_DESCRIPTOR_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" - -namespace ck { - -template -__host__ __device__ void -print_tensor_descriptor(const char* s, const NativeTensorDescriptor& desc) -{ - print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides()); -} - -template -__host__ __device__ void print_tensor_descriptor(const char* s, - const TransformedTensorDescriptor& desc) -{ - print_tensor_descriptor_impl(s, desc.GetLengths()); -} - -template -__host__ __device__ void -print_tensor_descriptor_impl(const char* s, Sequence, Sequence) -{ - constexpr index_t nDim = sizeof...(Lengths); - - static_assert(nDim > 0 && nDim <= 12, "wrong!"); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf( - "%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides 
{%u %u %u %u %u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u " - "%u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u " - "%u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u " - "%u %u " - "%u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u " - "%u %u %u %u " - "%u %u %u}\n", - s, - nDim, - Lengths..., - Strides...); - }); -} - -template -__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence) -{ - constexpr index_t nDim = sizeof...(Lengths); - - static_assert(nDim > 0 && nDim <= 12, "wrong!"); - - static_if{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); }); - - static_if{}( - [&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); - }); - - static_if{}([&](auto) { - printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...); - }); -} - -} // namespace ck - -#endif diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp deleted file mode 100644 index efd80beaf8..0000000000 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ /dev/null @@ -1,289 +0,0 @@ -#ifndef CK_TENSOR_COORDINATE_HPP -#define CK_TENSOR_COORDINATE_HPP - -#include "common_header.hpp" -#include "dimension.hpp" -#include "multi_index_transform.hpp" -#include "tensor_descriptor.hpp" - -namespace ck { - -// A "tensor cooridnate" is an opaque object that represents a "point of location" inside a tensor -// At the bare minimun, user should be able to query the following information from a tensor -// coordinate: -// 1. Tensor descriptor -// 2. Location, represented in the form of multi-index -// 3. Location, represented in the form of the offset to the origin of the tensor -// 4. If the location is inside invalid area or not, i.e. 
the padding area of an implicitly padded -// tensor is considered invalid, because the padding area doesn't have any physical memory -// allocation -// A tensor coordinate also provides the following functionality: -// 1. Given step size in each dimension, update itself, or return a new tensor coordinate, so user -// can freely move the "point of location" inside the tensor - -// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate -template -struct TensorCoordinate; - -// tensor coordinate for native tensor -template -struct NativeTensorCoordinate -{ - using type = NativeTensorCoordinate; - using tensor_desc_type = NativeTensorDesc; - static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); - using Index = MultiIndex; - - __host__ __device__ constexpr NativeTensorCoordinate(Index idx) - : mIndex(idx), mOffset(tensor_desc_type::CalculateOffset(idx)) - { - } - - template - __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs) - : NativeTensorCoordinate(make_multi_index(xs...)) - { - } - - template - __host__ __device__ constexpr NativeTensorCoordinate(Sequence) - : NativeTensorCoordinate(make_multi_index(Xs...)) - { - } - - __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } - - __host__ __device__ constexpr const Index& GetUpperIndex() const { return mIndex; } - - __host__ __device__ constexpr const Index& GetIndex() const { return mIndex; } - - __host__ __device__ constexpr const index_t& GetOffset() const { return mOffset; } - - __host__ __device__ constexpr type operator+=(const Index& idx_diff) - { - // mIndex is updated here, but some (or all) of its entries may never be used - // compiler should remove those entries as dead code - mIndex += idx_diff; - - mOffset += tensor_desc_type::CalculateOffsetDiff(idx_diff); - - return *this; - } - - __host__ __device__ constexpr type operator-=(const Index& idx_diff) - { - // mIndex is updated here, but some (or all) of its entries may never be used - // compiler should remove those entries as dead code - mIndex -= idx_diff; - - mOffset -= tensor_desc_type::CalculateOffsetDiff(idx_diff); - - return *this; - } - - __host__ __device__ constexpr type operator+(const Index& idx_diff) const - { - type coord = *this; - coord += idx_diff; - return coord; - } - - __host__ __device__ constexpr type operator-(const Index& idx_diff) const - { - type coord = *this; - coord -= idx_diff; - return coord; - } - - __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff) - { - return tensor_desc_type::CalculateOffsetDiff(idx_diff); - } - - // evaluated at run-time - __host__ __device__ constexpr bool IsUpperIndexValid() const - { - return tensor_desc_type::IsUpperIndexValid(GetUpperIndex()); - } - - // evaluated at run-time - __host__ __device__ constexpr bool IsOffsetValid() const - { - // For native tensor, offset is valid if upper-index is valid - return IsUpperIndexValid(); - } - - // evaluated at compile-time - __host__ __device__ static constexpr bool IsOffsetValidAssumingUpperIndexIsValid() - { - return true; - } - - private: - // mIndex may be saved and updated, however, the value of some (or all) of its entries may - // never be used. Compiler should be able to remove these entries as well as its calculation - // as dead code. 
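The essential invariant of NativeTensorCoordinate is that mOffset is kept in sync incrementally: moving by idx_diff costs one dot product with the strides instead of a full CalculateOffset. A simplified 2-D host model of that invariant (not the real class, which is N-dimensional and constexpr-heavy):

#include <array>
#include <cassert>

struct Coord2D
{
    std::array<long, 2> idx, stride;
    long offset;

    Coord2D(std::array<long, 2> i, std::array<long, 2> s)
        : idx(i), stride(s), offset(i[0] * s[0] + i[1] * s[1]) {}

    Coord2D& operator+=(std::array<long, 2> diff)
    {
        idx[0] += diff[0];
        idx[1] += diff[1];
        offset += diff[0] * stride[0] + diff[1] * stride[1]; // CalculateOffsetDiff
        return *this;
    }
};

int main()
{
    Coord2D c({1, 2}, {8, 1}); // row-major 2-D tensor with row stride 8
    c += {0, 3};               // step 3 columns
    assert(c.offset == 1 * 8 + 5); // matches recomputing the offset from scratch
    return 0;
}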
- // TODO: make sure compiler indeed remove these dead code - Index mIndex; - index_t mOffset; -}; - -// tensor coordinate for transformed tensor -template -struct TransformedTensorCoordinate -{ - using tensor_desc_type = TransformedTensorDesc; - using LowerCoord = - typename TensorCoordinate::type; - using UpperCoord = TransformedTensorCoordinate; - static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension(); - using UpperIndex = MultiIndex; - - __host__ __device__ constexpr TransformedTensorCoordinate(UpperIndex idx) - : mIndexUp{idx}, mCoordLow{tensor_desc_type::CalculateLowerIndex(idx)} - { - } - - template - __host__ __device__ constexpr TransformedTensorCoordinate(Xs... xs) - : TransformedTensorCoordinate(UpperIndex{xs...}) - { - } - - template - __host__ __device__ constexpr TransformedTensorCoordinate(Sequence) - : TransformedTensorCoordinate(UpperIndex{Xs...}) - { - } - - __host__ __device__ static constexpr auto GetTensorDescriptor() { return tensor_desc_type{}; } - - __host__ __device__ constexpr const LowerCoord& GetLowerCoordinate() const { return mCoordLow; } - - __host__ __device__ constexpr const UpperIndex& GetUpperIndex() const { return mIndexUp; } - - __host__ __device__ constexpr const UpperIndex& GetIndex() const { return GetUpperIndex(); } - - __host__ __device__ constexpr const index_t& GetOffset() const - { - return GetLowerCoordinate().GetOffset(); - } - - __host__ __device__ constexpr UpperCoord operator+=(const UpperIndex& idx_up_diff) - { - // For transformation of multi-index difference, not all transformation functions need to - // know the old lower-index or the old upper-index. We pass both of them to the - // transformation function. The transformation function itself decides to use them or not. - mCoordLow += tensor_desc_type::CalculateLowerIndexDiff( - idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex()); - - // mIndexUp is updated here, but some (or all) of its entries may never be used - // compiler should remove those entries as dead code - mIndexUp += idx_up_diff; - - return *this; - } - - __host__ __device__ constexpr UpperCoord operator-=(const UpperIndex& idx_up_diff) - { - mCoordLow -= tensor_desc_type::CalculateLowerIndexDiff( - idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex()); - - // mIndex is updated here, but some (or all) of its entries may never be used - // compiler should remove those entries as dead code - mIndexUp -= idx_up_diff; - - return *this; - } - - __host__ __device__ constexpr UpperCoord operator+(const UpperIndex& idx_up_diff) const - { - UpperCoord coord_up = *this; - coord_up += idx_up_diff; - return coord_up; - } - - __host__ __device__ constexpr UpperCoord operator-(const UpperIndex& idx_up_diff) const - { - UpperCoord coord_up = *this; - coord_up -= idx_up_diff; - return coord_up; - } - - // Calculate offset diff without updating tensor-coordinate - // If idx_up_diff is know at compile time, and has only non-zero entries on linear dimensions, - // then all calculation can be done at compile-time. - // TODO: this function is not compiled to expected ISA - __host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const - { - // For transformation of multi-index difference, not all transformation functions need to - // know the old lower-index or the old upper-index. We pass both of them to the - // transformation function. The transformation function itself decides to use them or not. 
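When every transform in the stack is linear (PassThrough, Pad, Slice, UnMerge, Embed), the lower-index diff depends only on the upper-index diff, so stepping a TransformedTensorCoordinate is a constant-cost update with no dependence on the old index. A small host model (invented names) of the kind of map the convolution descriptors express with Embed, hi = ConvDilationH * y + ConvStrideH * ho:

#include <array>
#include <cassert>

struct EmbedCoord
{
    std::array<long, 2> up;    // upper index (y, ho)
    std::array<long, 2> coeff; // (ConvDilationH, ConvStrideH)
    long low;                  // lower 1-D index hi

    EmbedCoord(std::array<long, 2> u, std::array<long, 2> c)
        : up(u), coeff(c), low(u[0] * c[0] + u[1] * c[1]) {}

    EmbedCoord& operator+=(std::array<long, 2> up_diff)
    {
        up[0] += up_diff[0];
        up[1] += up_diff[1];
        // linear transform: the lower diff is independent of the old state
        low += up_diff[0] * coeff[0] + up_diff[1] * coeff[1];
        return *this;
    }
};

int main()
{
    EmbedCoord hi({0, 0}, {2, 1}); // dilation 2, stride 1
    hi += {1, 3};                  // advance y by 1 and ho by 3
    assert(hi.low == 2 * 1 + 1 * 3);
    return 0;
}

Merge is the one transform here that breaks this property (its diff needs the old lower index for carry/borrow), which is why the coordinate keeps the full lower coordinate around at all.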
- const auto idx_low_diff = tensor_desc_type::CalculateLowerIndexDiff( - idx_up_diff, GetIndex(), GetLowerCoordinate().GetIndex()); - - return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff); - } - - // evaluated at run-time - __host__ __device__ constexpr bool IsUpperIndexValid() const - { - return tensor_desc_type::IsUpperIndexValid(GetUpperIndex()); - } - - // evaluated at run-time - __host__ __device__ constexpr bool IsOffsetValid() const - { - return IsUpperIndexValid() && GetLowerCoordinate().IsOffsetValid(); - } - - // most evaluation is done at compile-time - __host__ __device__ constexpr bool IsLowerIndexValidAssumingUpperIndexIsValid() const - { - return tensor_desc_type::IsLowerIndexValidAssumingUpperIndexIsValid( - GetLowerCoordinate().GetIndex()); - } - - // most evaluation is done at compile-time - __host__ __device__ constexpr bool IsOffsetValidAssumingUpperIndexIsValid() const - { - return IsLowerIndexValidAssumingUpperIndexIsValid() && - GetLowerCoordinate().IsOffsetValidAssumingUpperIndexIsValid(); - } - - private: - // mIndexUp may be calculated and updated, however, the value of some (or all) of its entries - // may - // never be used. Compiler should be able to remove these entries as well as its calculation - // as dead code. - // TODO: make sure the compiler indeed removes this dead code - UpperIndex mIndexUp; - LowerCoord mCoordLow; -}; - -template -struct TensorCoordinate -{ - private: - template - __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(NativeTensorDescriptor) - { - return NativeTensorCoordinate>( - make_zero_multi_index()); - } - - template - __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(TransformedTensorDescriptor) - { - return TransformedTensorCoordinate>( - make_zero_multi_index()); - } - - public: - using type = decltype(MakeDummyTensorCoordinate(TensorDesc{})); -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp deleted file mode 100644 index 7b57723341..0000000000 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ /dev/null @@ -1,526 +0,0 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HPP -#define CK_TENSOR_DESCRIPTOR_HPP - -#include "common_header.hpp" -#include "dimension.hpp" -#include "multi_index_transform.hpp" - -namespace ck { - -// tensor descriptor for "native tensor" -// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides -template -struct NativeTensorDescriptor -{ - using type = NativeTensorDescriptor; - static constexpr index_t nDim = sizeof...(NativeDimensions); - static constexpr auto mDimensions = make_tuple(NativeDimensions{}...); - - using Index = MultiIndex; - - __host__ __device__ static constexpr auto GetNumOfDimension() { return Number{}; } - - template - __host__ __device__ static constexpr auto GetLength(Number) - { - return mDimensions.At(Number{}).GetLength(); - } - - template - __host__ __device__ static constexpr auto GetStride(Number) - { - return mDimensions.At(Number{}).GetStride(); - } - - template - __host__ __device__ static constexpr auto GetLengths(Sequence) - { - return Sequence{})...>{}; - } - - template - __host__ __device__ static constexpr auto GetStrides(Sequence) - { - return Sequence{})...>{}; - } - - template - __host__ __device__ static constexpr auto GetLengths(Number, Number...) 
- { - return GetLengths(Sequence{}); - } - - template - __host__ __device__ static constexpr auto GetStrides(Number, Number...) - { - return GetStrides(Sequence{}); - } - - __host__ __device__ static constexpr auto GetLengths() - { - return GetLengths(typename arithmetic_sequence_gen<0, nDim, 1>::type{}); - } - - __host__ __device__ static constexpr auto GetStrides() - { - return GetStrides(typename arithmetic_sequence_gen<0, nDim, 1>::type{}); - } - - __host__ __device__ static constexpr index_t GetElementSize() - { - return reduce_on_sequence(GetLengths(), math::multiplies{}, Number<1>{}); - } - - __host__ __device__ static constexpr index_t GetElementSpace() - { - return reduce_on_sequence( - (GetLengths() - Number<1>{}) * GetStrides(), math::plus{}, Number<1>{}); - } - - // TODO: this cannot return constepxr because of use of lambda - __host__ __device__ static constexpr index_t CalculateOffset(const Index& idx) - { - index_t offset = 0; - - static_for<0, nDim, 1>{}([&](auto idim) { offset += idx[idim] * GetStride(idim); }); - - return offset; - } - - __host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff) - { - index_t offset_diff = 0; - - static_for<0, nDim, 1>{}( - [&](auto idim) { offset_diff += idx_diff[idim] * GetStride(idim); }); - - return offset_diff; - } - - template - __host__ __device__ static constexpr bool IsLinearDimension(Number) - { - return true; - } - - __host__ __device__ static constexpr auto GetLinearDimensionMask() - { - return typename uniform_sequence_gen::type{}; - } - - __host__ __device__ static constexpr auto GetNonLinearDimensionMask() - { - return typename uniform_sequence_gen::type{}; - } - - __host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; } - - __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups() - { - return Tuple<>{}; - } - - // a multi-index is valid if there is a corresponding point for it in the tensor - __host__ __device__ static constexpr bool IsUpperIndexValid(const Index& idx) - { - bool flag = true; - - for(index_t i = 0; i < nDim; ++i) - { - flag = flag && idx[i] >= 0 && idx[i] < GetLengths()[i]; - } - - return flag; - } -}; - -// Tensor descriptor for "transformed tensor" -template - typename LowDimensionIds, // Tuple> - typename UpDimensionIds> // Tuple> -struct TransformedTensorDescriptor -{ - using type = TransformedTensorDescriptor; - static constexpr index_t nTransform = Transforms::Size(); - - struct lambda_merge_sequences - { - template - __host__ __device__ constexpr auto operator()(Seqs... 
seqs) const - { - return merge_sequences(seqs...); - } - }; - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() - { - // Here, we assume all lower-dimensions are active - // TODO: sanity-check all lower-dimension are indeed active - - using duplicated_low_active_dims = - decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{})); - - using low_active_dims = typename sequence_unique_sort, - math::equal>::type; - - return low_active_dims::Size(); - } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() - { - using duplicated_up_active_dims = - decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{})); - - using up_active_dims = typename sequence_unique_sort, - math::equal>::type; - - return up_active_dims::Size(); - } - - static constexpr index_t nDimUp = GetNumOfUpperDimension(); - static constexpr index_t nDimLow = GetNumOfLowerDimension(); - - using UpperIndex = MultiIndex; - using LowerIndex = MultiIndex; - - __host__ __device__ constexpr TransformedTensorDescriptor() - { - static_assert(nTransform == Transforms::Size() && nTransform == LowDimensionIds::Size() && - nTransform == UpDimensionIds::Size(), - "wrong! # of transformations not the same"); - - // sanity check: - // LowDimensionIds should include all low-dimensions, - // UpDimensionIds should include all up-dimensions - using mingled_up_dimension_ids = - decltype(unpack(lambda_merge_sequences{}, UpDimensionIds{})); - - using sorted_up_dimension_ids = - typename sequence_sort>::type; - - static_assert(sorted_up_dimension_ids::Size() == nDimUp && - is_valid_sequence_map{}, - "wrong! UpDimensionIds is not configured correctly"); - - using mingled_low_dimension_ids = - decltype(unpack(lambda_merge_sequences{}, LowDimensionIds{})); - - using sorted_low_dimension_ids = - typename sequence_sort>::type; - - static_assert(sorted_low_dimension_ids::Size() == nDimLow && - is_valid_sequence_map{}, - "wrong! LowDimensionIds is not configured correctly"); - - // TODO: sanity check: while a up-dimension could be associated with multille - // transformation, a low-dimension should be associated with only one transformation - - // TODO: sanity-check: GetLowerLengths of each transform should be consistent with lengths - // of lower-tensor-descriptor - } - - __host__ __device__ static constexpr auto GetNumOfDimension() - { - return GetNumOfUpperDimension(); - } - - __host__ __device__ static constexpr auto GetLowerTensorDescriptor() - { - return LowTensorDescriptor{}; - } - - struct lambda_GetUpperLengths - { - template - __host__ __device__ constexpr auto operator()(const Transform& tran) const - { - return tran.GetUpperLengths(); - } - }; - - __host__ __device__ static constexpr auto GetUpperLengths() - { - constexpr auto tuple_of_up_lengths = - transform_tuples(lambda_GetUpperLengths{}, Transforms{}); - - constexpr auto mingled_up_lengths = unpack(lambda_merge_sequences{}, tuple_of_up_lengths); - - constexpr auto mingled_up_dimension_ids = - unpack(lambda_merge_sequences{}, UpDimensionIds{}); - - // TODO: sanity-check mingled_up_dimension_ids contain all upper-dimensions - // TODO: sanity-check mingled_up_lengths have no conflicting upper-length - - // sort by upper-dimension-ids - using sort_up_dimension_ids = sequence_unique_sort, - math::equal>; - - // sanity-check sorted-upper-dimension-ids should be Sequence<0, 1, ... nDimUp-1> - static_assert(is_same::type>{}, - "wrong! 
UpDimensionIds is not configured correctly");
-
-        constexpr auto sorted2unsorted_map = typename sort_up_dimension_ids::sorted2unsorted_map{};
-
-        constexpr auto sorted_up_lengths =
-            pick_sequence_elements_by_ids(mingled_up_lengths, sorted2unsorted_map);
-
-        return sorted_up_lengths;
-    }
-
-    __host__ __device__ static constexpr auto GetLengths() { return GetUpperLengths(); }
-
-    template
-    __host__ __device__ static constexpr auto GetLength(Number)
-    {
-        return GetLengths()[IDim];
-    }
-
-    template
-    __host__ __device__ static constexpr auto GetLengths(Sequence)
-    {
-        return Sequence{})...>{};
-    }
-
-    template
-    __host__ __device__ static constexpr auto GetLengths(Number, Number...)
-    {
-        return GetLengths(Sequence{});
-    }
-
-    __host__ __device__ static constexpr index_t GetElementSize()
-    {
-        return reduce_on_sequence(GetLengths(), math::multiplies{}, Number<1>{});
-    }
-
-    __host__ __device__ static constexpr index_t GetElementSpace()
-    {
-        // TODO: Is this the correct definition for transformed tensor?
-        return GetLowerTensorDescriptor().GetElementSpace();
-    }
-
-    // TODO: right now return value is not constexpr because of the use of a non-constexpr lambda
-    __host__ __device__ static constexpr LowerIndex CalculateLowerIndex(const UpperIndex& idx_up)
-    {
-        LowerIndex idx_low;
-
-        static_for<0, nTransform, 1>{}([&](auto itran) {
-            constexpr auto tran = Transforms{}.At(itran);
-
-            const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran));
-            auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran));
-
-            // this assumes each lower (single) index is associated with only one transformation,
-            // which is required for index transformation, and has been checked during constructor
-            // of TransformedTensorDescriptor
-            idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part));
-        });
-
-        return idx_low;
-    }
-
-    // TODO: right now return value is not constexpr because of the use of a non-constexpr lambda
-    __host__ __device__ static constexpr LowerIndex CalculateLowerIndexDiff(
-        const UpperIndex& idx_up_diff, const UpperIndex& idx_up_old, const LowerIndex& idx_low_old)
-    {
-        LowerIndex idx_low_diff;
-
-        static_for<0, nTransform, 1>{}([&](auto itran) {
-            constexpr auto tran = Transforms{}.At(itran);
-
-            const auto idx_up_diff_part =
-                pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran));
-
-            const auto idx_up_old_part =
-                pick_container_element(idx_up_old, UpDimensionIds{}.At(itran));
-
-            const auto idx_low_old_part =
-                pick_container_element(idx_low_old, LowDimensionIds{}.At(itran));
-
-            auto idx_low_diff_part =
-                pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran));
-
-            // this assumes each lower (single) index is associated with only one transformation,
-            // which is required for index transformation, and has been checked during constructor
-            // of TransformedTensorDescriptor
-            idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part),
-                                                             to_multi_index(idx_up_old_part),
-                                                             to_multi_index(idx_low_old_part));
-        });
-
-        return idx_low_diff;
-    }
-
-    __host__ __device__ static constexpr index_t CalculateOffset(const UpperIndex& idx_up)
-    {
-        return GetLowerTensorDescriptor().CalculateOffset(CalculateLowerIndex(idx_up));
-    }
-
-    struct lambda_sequence_logical_and
-    {
-        template
-        __host__ __device__ constexpr auto operator()(Seqs...)
const - { - return typename sequence_reduce, Seqs...>::type{}; - } - }; - - template - struct lambda_is_true - { - __host__ __device__ constexpr auto operator()(const T& x) const - { - // TODO: remove static_cast once Sequence can take bool as entries - return static_cast(x) == true; - } - }; - - struct lambda_get_linear_dimension_mask_of_single_tranform - { - // check only one transform at a time - template - __host__ __device__ constexpr auto - operator()(Transform, LowDimensionId, UpDimensionId) const - { - // judge if transformation is linear - constexpr bool is_linear_transform = Transform::IsLinearTransform(); - - // judge if all lower dimension are linear - constexpr bool are_all_low_dim_linear = sequence_all_of( - pick_sequence_elements_by_ids(GetLowerTensorDescriptor().GetLinearDimensionMask(), - LowDimensionId{}), - lambda_is_true{}); - - // create linear mask for upper dimensions - constexpr bool are_up_dim_linear = is_linear_transform && are_all_low_dim_linear; - - constexpr auto mask_of_up_linear_dims = modify_sequence_elements_by_ids( - typename uniform_sequence_gen::type{}, - typename uniform_sequence_gen::type{}, - UpDimensionId{}); - - return mask_of_up_linear_dims; - } - }; - - // TODO: this is a hack, transform_tuples() doesn't compile, would complain about constexpr - template - __host__ __device__ static constexpr auto - dummy_transform_tuples_impl(F f, X x, Y y, Z z, Sequence) - { - return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); - } - - __host__ __device__ static constexpr auto GetLinearDimensionMask() - { -#if 0 - // create tuple of linear dimension masks, for all transformations - // TODO: this doesn't compile, because transform_tuples() complain about constexpr - constexpr auto tuple_of_linear_dimension_mask = - transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{}, - Transforms{}, - LowDimensionIds{}, - UpDimensionIds{}); -#else - // create tuple of linear dimension masks, for all transformations - // TODO: this is a hack - constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl( - lambda_get_linear_dimension_mask_of_single_tranform{}, - Transforms{}, - LowDimensionIds{}, - UpDimensionIds{}, - typename arithmetic_sequence_gen<0, Transforms::Size(), 1>::type{}); -#endif - - // reduce tuple of masks into one mask - constexpr auto linear_dimension_mask = - unpack(lambda_sequence_logical_and{}, tuple_of_linear_dimension_mask); - - return linear_dimension_mask; - } - - __host__ __device__ static constexpr auto GetNonLinearDimensionMask() - { - return GetLinearDimensionMask().Transform(logical_not{}); - } - - template - __host__ __device__ static constexpr bool IsLinearDimension(Number) - { - return GetLinearDimensionMask().At(Number{}); - } - - __host__ __device__ static constexpr auto GetLinearDimensions() - { - constexpr auto linear_dimension_mask = GetLinearDimensionMask(); - - return pick_sequence_elements_by_mask( - typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, linear_dimension_mask); - } - - __host__ __device__ static constexpr auto GetNonLinearDimensions() - { - constexpr auto nonlinear_dimension_mask = GetNonLinearDimensionMask(); - - return pick_sequence_elements_by_mask( - typename arithmetic_sequence_gen<0, nDimUp, 1>::type{}, nonlinear_dimension_mask); - } - -#if 0 - __host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups() - { - // TODO: not implemented - } -#endif - - // a multi-index is valid if there is a corresponding point for it in the tensor - __host__ 
__device__ constexpr bool IsUpperIndexValid(const UpperIndex& idx_up) const - { - bool flag = true; - - for(index_t i = 0; i < nDimUp; ++i) - { - flag = flag && idx_up[i] >= 0 && idx_up[i] < GetLengths()[i]; - } - - return flag; - } - - // this function is for optimization purpose, it's called by tensor coordinate - // this function tells you: If a lower-index is valid or not, assuming upper index is valid - __host__ __device__ static constexpr bool - IsLowerIndexValidAssumingUpperIndexIsValid(const LowerIndex& idx_low) - { - bool flag = true; - - static_for<0, nTransform, 1>{}([&](auto itran) { - constexpr auto tran = Transforms{}.At(itran); - - // check a indtransformation if it does not always has a valid mapping - constexpr bool is_valid_up_always_mapped_to_valid_low = - decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex(); - - if(!is_valid_up_always_mapped_to_valid_low) - { - constexpr auto low_dims_part = LowDimensionIds{}.At(itran); - constexpr auto low_lengths_part = - GetLowerTensorDescriptor().GetLengths(low_dims_part); - const auto idx_low_part = - to_multi_index(pick_container_element(idx_low, low_dims_part)); - - static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) { - flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i]; - }); - } - }); - - return flag; - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp deleted file mode 100644 index bed6de6d1e..0000000000 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ /dev/null @@ -1,176 +0,0 @@ -#ifndef CK_TENSOR_DESCRIPTOR_HELPER_HPP -#define CK_TENSOR_DESCRIPTOR_HELPER_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" - -namespace ck { - -template -__host__ __device__ constexpr auto calculate_tensor_strides_packed(Lengths) -{ - return reverse_inclusive_scan_sequence( - Lengths{}.PopFront(), math::multiplies{}, Number<1>{}) - .PushBack(Number<1>{}); -} - -template -__host__ __device__ constexpr auto calculate_tensor_strides_aligned(Lengths, Number) -{ - constexpr index_t L_back_align = - Align * math::integer_divide_ceiler{}(Lengths{}.Back(), Align); - - return calculate_tensor_strides_packed( - Lengths{}.Modify(Number{}, Number{})); -} - -template -__host__ __device__ constexpr auto make_native_tensor_descriptor(Sequence, - Sequence) -{ - return NativeTensorDescriptor...>{}; -} - -template -__host__ __device__ constexpr auto make_native_tensor_descriptor_packed(Lengths) -{ - constexpr auto strides = calculate_tensor_strides_packed(Lengths{}); - - return make_native_tensor_descriptor(Lengths{}, strides); -} - -template -__host__ __device__ constexpr auto make_native_tensor_descriptor_aligned(Lengths, Number) -{ - constexpr auto strides = calculate_tensor_strides_aligned(Lengths{}, Number{}); - return make_native_tensor_descriptor(Lengths{}, strides); -} - -template -__host__ __device__ constexpr auto - transform_tensor_descriptor(LowTensorDescriptor, Transforms, LowDimensionIds, UpDimensionIds) -{ - return TransformedTensorDescriptor{}; -} - -template -__host__ __device__ constexpr auto -reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, - Sequence, - Sequence, - Sequence) -{ - return TransformedTensorDescriptor...>, - Tuple...>, - Tuple...>>{}; -} - -// reorder a NativeTensorDescriptor -template -__host__ __device__ constexpr auto 
-reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLower2Upper) -{ - static_assert(is_valid_sequence_map{}, - "wrong! MapLower2Upper is not a valid map"); - - constexpr auto old_desc = NativeTensorDescriptor{}; - - static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!"); - - constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{}); - constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{}); - - return make_native_tensor_descriptor(new_lengths, new_strides); -} - -// reorder a TransformedTensorDescriptor -template -__host__ __device__ constexpr auto -reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor, MapLower2Upper) -{ - static_assert(is_valid_sequence_map{}, - "wrong! MapLower2Upper is not a valid map"); - - constexpr auto low_desc = TransformedTensorDescriptor{}; - - static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!"); - - return reorder_transformed_tensor_descriptor_impl( - low_desc, - low_desc.GetLengths(), - typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{}, - MapLower2Upper{}); -} - -template -__host__ __device__ constexpr auto - reorder_tensor_descriptor_given_upper2lower(LowerTensorDescriptor, MapUpper2Lower) -{ - return reorder_tensor_descriptor_given_lower2upper( - LowerTensorDescriptor{}, typename sequence_map_inverse::type{}); -} - -template -__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides) -{ - static_assert(Lengths::Size() == Strides::Size(), "wrong!"); - - bool flag = true; - - for(index_t i = 0; i < Lengths::Size() - 1; ++i) - { - flag = flag && Strides::At(i) == Strides::At(i + 1) * Lengths::At(i + 1); - } - - return flag; -} - -// unfold only support NativeTennsorDescriptor, for now -template -__host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescriptor desc, - Number, - Number) -{ - constexpr index_t nDim = desc.GetNumOfDimension(); - - static_assert(FirstUnfoldDim >= 0 && LastUnfoldDim < nDim && FirstUnfoldDim <= LastUnfoldDim, - "wrong! should have FirstUnfoldDim <= LastUnfoldDim!"); - - // left and right - constexpr auto left = typename arithmetic_sequence_gen<0, FirstUnfoldDim, 1>::type{}; - constexpr auto middle = - typename arithmetic_sequence_gen::type{}; - constexpr auto right = typename arithmetic_sequence_gen::type{}; - - // sanity-check if unfold-able - static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)), - "wrong! 
not unfold-able"); - - // unfolded length, stride - constexpr index_t unfold_length = - reduce_on_sequence(desc.GetLengths(middle), math::multiplies{}, Number<1>{}); - - constexpr index_t unfold_stride = desc.GetStride(Number{}); - - // new lengths, strides - constexpr auto new_lengths = - desc.GetLengths(left).PushBack(Number{}).PushBack(desc.GetLengths(right)); - - constexpr auto new_strides = - desc.GetStrides(left).PushBack(Number{}).PushBack(desc.GetStrides(right)); - - return make_native_tensor_descriptor(new_lengths, new_strides); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp deleted file mode 100644 index f5c0df4d7d..0000000000 --- a/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp +++ /dev/null @@ -1,406 +0,0 @@ -#ifndef CK_BLOCKWISE_BATCHED_GEMM_HPP -#define CK_BLOCKWISE_BATCHED_GEMM_HPP - -#include "common_header.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "threadwise_gemm.hpp" - -#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM -#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1 -#endif - -namespace ck { - -template -struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 -{ - index_t mMyThreadOffsetA = 0; - index_t mMyThreadOffsetB = 0; - - struct MatrixIndex - { - index_t batch; - index_t row; - index_t col; - }; - - __device__ BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2() - { - static_assert(BatchSize % BatchPerThread == 0, - "wrong! BatchSize is not dividable by BatchPerThread"); - - constexpr index_t BatchThreadWork = BatchSize / BatchPerThread; - - constexpr index_t ThreadPerLevel1Cluster = - MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; - - static_assert(BlockSize == BatchThreadWork * ThreadPerLevel1Cluster, - "wrong! wrong blocksize\n"); - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - static_assert(a_block_mtx.NRow() == b_block_mtx.NRow(), - "wrong! K dimension not consistent\n"); - - constexpr index_t M = a_block_mtx.NCol(); // A is transposed - constexpr index_t N = b_block_mtx.NCol(); - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - static_assert((MPerThread % MPerThreadSubC == 0) && (NPerThread % NPerThreadSubC == 0), - "wrong! Cannot evenly divide thread work among repeat \n"); - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - static_assert((M % MRepeat == 0) && (N % NRepeat == 0), - "wrong! Cannot evenly divide work among repeat\n"); - - constexpr index_t MPerLevel1Cluster = M / MRepeat; - constexpr index_t NPerLevel1Cluster = N / NRepeat; - - static_assert((MPerLevel1Cluster % MLevel1Cluster == 0) && - (NPerLevel1Cluster % NLevel1Cluster == 0), - "wrong! Cannot evenly divide work among Level1Cluster\n"); - - constexpr index_t MPerLevel0Cluster = MPerLevel1Cluster / MLevel1Cluster; - constexpr index_t NPerLevel0Cluster = NPerLevel1Cluster / NLevel1Cluster; - - static_assert((MPerLevel0Cluster % MLevel0Cluster == 0) && - (NPerLevel0Cluster % NLevel0Cluster == 0), - "wrong! Cannot evenly divide work among Level0Cluster\n"); - - static_assert((MPerThreadSubC == MPerLevel0Cluster / MLevel0Cluster) && - (NPerThreadSubC == NPerLevel0Cluster / NLevel0Cluster), - "wrong! 
thread work size is wrong\n"); - - const auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = c_thread_mtx_index.batch * BlockMatrixStrideA + - a_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.row); - - mMyThreadOffsetB = c_thread_mtx_index.batch * BlockMatrixStrideB + - b_block_mtx.GetOffsetFromMultiIndex(0, c_thread_mtx_index.col); - } - - __device__ MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) const - { - constexpr index_t ThreadPerLevel1Cluster = - MLevel0Cluster * NLevel0Cluster * MLevel1Cluster * NLevel1Cluster; - - constexpr index_t ThreadPerLevel0Cluster = MLevel0Cluster * NLevel0Cluster; - - index_t batch_work_id = thread_id / ThreadPerLevel1Cluster; - index_t cluster_id = thread_id - batch_work_id * ThreadPerLevel1Cluster; - - index_t level1_id = cluster_id / ThreadPerLevel0Cluster; - index_t level1_m_id = level1_id / NLevel1Cluster; - index_t level1_n_id = level1_id % NLevel1Cluster; - - index_t level0_id = cluster_id % ThreadPerLevel0Cluster; - index_t level0_m_id = level0_id / NLevel0Cluster; - index_t level0_n_id = level0_id % NLevel0Cluster; - - constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0Cluster; - constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0Cluster; - - return MatrixIndex{batch_work_id * BatchPerThread, - level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, - level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; - } - - // this should be optimized away because input will be known at compile time - __device__ static MatrixIndex - GetDistanceFromBeginOfThreadMatrixC(index_t batch_in_c, index_t m_in_c, index_t n_in_c) - { - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - index_t m_repeat = m_in_c / MPerThreadSubC; - index_t n_repeat = n_in_c / NPerThreadSubC; - - index_t m_in_sub_c = m_in_c % MPerThreadSubC; - index_t n_in_sub_c = n_in_c % NPerThreadSubC; - - return MatrixIndex{batch_in_c, - m_repeat * MPerLevel1Cluster + m_in_sub_c, - n_repeat * NPerLevel1Cluster + n_in_sub_c}; - } - - template - __device__ void Run_source(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread) const - { - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t KPerBlock = a_block_mtx.NRow(); // A is transposed - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - // A is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - FloatA 
p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - -// loop over k -#pragma unroll - for(index_t k_begin = 0; k_begin < KPerBlock; k_begin += KPerThreadLoop) - { -// loop over batch -#pragma unroll - for(index_t ib = 0; ib < BatchPerThread; ++ib) - { - // read next batch of a, b - if(BlockMatrixStrideA != 0 or ib == 0) - { -#pragma unroll - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - threadwise_matrix_copy(a_block_mtx, - p_a_block + - a_block_mtx.GetOffsetFromMultiIndex( - k_begin, m_repeat * MPerLevel1Cluster) + - ib * BlockMatrixStrideA + mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex( - 0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths(), - Number{}); - } - } - - if(BlockMatrixStrideB != 0 or ib == 0) - { -#pragma unroll - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy(b_block_mtx, - p_b_block + - b_block_mtx.GetOffsetFromMultiIndex( - k_begin, n_repeat * NPerLevel1Cluster) + - ib * BlockMatrixStrideB + mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex( - 0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths(), - Number{}); - } - } - - threadwise_gemm(a_thread_mtx, - True, - p_a_thread, - b_thread_mtx, - False, - p_b_thread, - c_thread_mtx, - False, - p_c_thread + ib * ThreadMatrixStrideC); - } - } - } - -#if CK_USE_AMD_INLINE_ASM - template - __device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread) const - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t K = a_block_mtx.NRow(); // A is transposed - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - // thread A, B for GEMM - // A is transposed, b is not - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - // thread A-sub, B-sub for copy - constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - // assertion for inline asm - static_assert(is_same{} && is_same{} && - is_same{}, - "Run_amd_asm only deal with float\n"); - - static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 && - MPerThread == 8 && NPerThread == 8, - "Run_amd_asm cannot deal with this GEMM shape yet\n"); - - static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read\n"); - - static_assert(BlockMatrixStrideA == 0 && BatchPerThread == 1, - "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && 
BatchPerThread == " - "1 for now\n"); - - using Float4 = vector_type::type; - - Float4* reg_a = (Float4*)(p_a_thread); - Float4* reg_b = (Float4*)(p_b_thread); - Float4* reg_c = (Float4*)(p_c_thread); - - reg_a[0] = *reinterpret_cast(&p_a_block[mMyThreadOffsetA]); - reg_b[0] = *reinterpret_cast(&p_b_block[mMyThreadOffsetB]); - reg_b[1] = *reinterpret_cast( - &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(0, NPerLevel1Cluster) + - mMyThreadOffsetB]); - reg_a[1] = *reinterpret_cast( - &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(0, MPerLevel1Cluster) + - mMyThreadOffsetA]); - outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); - -#pragma unroll - for(index_t k = 1; k < K; ++k) - { - reg_a[0] = *reinterpret_cast( - &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetA]); - outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - reg_b[0] = *reinterpret_cast( - &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, 0) + mMyThreadOffsetB]); - outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); - reg_b[1] = *reinterpret_cast( - &p_b_block[b_block_mtx.GetOffsetFromMultiIndex(k, NPerLevel1Cluster) + - mMyThreadOffsetB]); - reg_a[1] = *reinterpret_cast( - &p_a_block[a_block_mtx.GetOffsetFromMultiIndex(k, MPerLevel1Cluster) + - mMyThreadOffsetA]); - outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]); - outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]); - } - outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]); - outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]); - } -#endif - - template - __device__ void Run(const FloatA* __restrict__ p_a_block, - const FloatB* __restrict__ p_b_block, - FloatC* __restrict__ p_c_thread) const - - { -#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM - Run_amd_asm(p_a_block, p_b_block, p_c_thread); -#else - Run_source(p_a_block, p_b_block, p_c_thread); -#endif - } - - template - __device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread, - FloatC* __restrict__ p_c_block) const - { - constexpr auto c_block_mtx = BlockMatrixC{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor( - Number{}, Number{}, Number{}); - - constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster; - constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t c_thread_offset = - c_thread_mtx_begin.batch * BlockMatrixStrideC + - c_block_mtx.GetOffsetFromMultiIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col); - - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - threadwise_matrix_copy( - c_thread_sub_mtx, - p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex( - m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster), - c_block_mtx, - p_c_block + - c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster, - n_repeat * NPerLevel1Cluster) + 
-                        c_thread_offset,
-                    c_thread_sub_mtx.GetLengths());
-            }
-        }
-    }
-};
-
-} // namespace ck
-#endif
diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp
deleted file mode 100644
index 3ffeb3f16f..0000000000
--- a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp
+++ /dev/null
@@ -1,334 +0,0 @@
-#ifndef CK_BLOCKWISE_GEMM_HPP
-#define CK_BLOCKWISE_GEMM_HPP
-
-#include "common_header.hpp"
-#include "ConstantMatrixDescriptor.hpp"
-#include "threadwise_gemm.hpp"
-
-namespace ck {
-
-// blockwise GEMM: C += transpose(A) * B
-// A and B are visible to the whole block, C is distributed among the threads
-// If the following numbers are powers of 2, the cost of index calculation is greatly reduced:
-// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
-// MLevel1ThreadCluster, NLevel1ThreadCluster
-template
-struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
-{
-    struct MatrixIndex
-    {
-        index_t row;
-        index_t col;
-    };
-
-    index_t mMyThreadOffsetA;
-    index_t mMyThreadOffsetB;
-
-    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2()
-    {
-        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
-                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;
-
-        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
-
-        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
-                      "wrong! K dimension not consistent\n");
-
-        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
-        constexpr index_t N = BlockMatrixB::NCol();
-
-        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
-                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
-                      "wrong! Cannot evenly divide work among\n");
-
-        static_assert(
-            is_same{},
-            "wrong! 
ThreadMatrixC lengths is wrong"); - - auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - mMyThreadOffsetA = BlockMatrixA::GetOffsetFromMultiIndex(0, c_thread_mtx_index.row); - mMyThreadOffsetB = BlockMatrixB::GetOffsetFromMultiIndex(0, c_thread_mtx_index.col); - } - - __device__ static constexpr auto GetThreadMatrixCLengths() - { - constexpr index_t M = BlockMatrixA::NCol(); // A is transposed - constexpr index_t N = BlockMatrixB::NCol(); - - constexpr index_t MRepeat = - M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster); - constexpr index_t NRepeat = - N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster); - - return Sequence{}; - } - - __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) - { - constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; - - index_t level1_id = thread_id / ThreadPerLevel0Cluster; - index_t level1_m_id = level1_id / NLevel1ThreadCluster; - index_t level1_n_id = level1_id % NLevel1ThreadCluster; - - index_t level0_id = thread_id % ThreadPerLevel0Cluster; - index_t level0_m_id = level0_id / NLevel0ThreadCluster; - index_t level0_n_id = level0_id % NLevel0ThreadCluster; - - constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; - constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; - - return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, - level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; - } - - template - __device__ void - Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t K = a_block_mtx.NRow(); - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - constexpr index_t MPerLevel1Cluster = - MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; - constexpr index_t NPerLevel1Cluster = - NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - // thread A, B for GEMM - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy{}; - - constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy{}; - - constexpr auto threadwise_gemm = - ThreadwiseGemmTransANormalBNormalC{}; -#pragma unroll - // loop over k - for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) - { -#pragma unroll - // read A - for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) - { - a_thread_copy.Run( - p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) + - mMyThreadOffsetA, - p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC)); - } - -#pragma unroll - // read B - for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) - { - b_thread_copy.Run( - p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) + - mMyThreadOffsetB, - p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC)); - } - - // C += A * B - 
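
// For reference, the threadwise GEMM step below is, in effect, the following
// triple loop (a sketch only, assuming packed row-major register tiles with the
// dimensions named above; the actual ThreadwiseGemmTransANormalBNormalC may
// unroll or vectorize differently):
#if 0
            for(index_t k = 0; k < KPerThreadLoop; ++k)
                for(index_t m = 0; m < MPerThread; ++m)
                    for(index_t n = 0; n < NPerThread; ++n)
                        p_c_thread[m * NPerThread + n] +=
                            p_a_thread[k * MPerThread + m] * p_b_thread[k * NPerThread + n];
#endif
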
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); - } - } - - template - __device__ void - Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const - { - constexpr auto a_block_mtx = BlockMatrixA{}; - constexpr auto b_block_mtx = BlockMatrixB{}; - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t K = a_block_mtx.NRow(); - - constexpr index_t MPerThread = c_thread_mtx.NRow(); - constexpr index_t NPerThread = c_thread_mtx.NCol(); - - constexpr index_t MPerLevel1Cluster = - MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; - constexpr index_t NPerLevel1Cluster = - NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - static_assert(MRepeat == 2 && NRepeat == 2, - "wrong! inline asm cannot deal with this GEMM config yet"); - - // thread A, B - constexpr auto a_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - constexpr auto b_thread_mtx = - make_ConstantMatrixDescriptor_packed(Number{}, Number{}); - - // thread A-sub, B-sub - constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor( - Number{}, Number{}); - constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor( - Number{}, Number{}); - - // thread C-sub - constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor( - Number{}, Number{}); - - FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; - FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; - - constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy{}; - - constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy{}; - - constexpr auto threadwise_gemm = - ThreadwiseGemmTransANormalBNormalC{}; - - const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA; - const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB; - - // read A_sub_0 - a_thread_copy.Run(p_a_block_off, p_a_thread); - - // read B_sub_0 - b_thread_copy.Run(p_b_block_off, p_b_thread); - - // read B_sub_1 - b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster), - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC)); - - // read A_sub_1 - a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster), - p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC)); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(p_a_thread, - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC), - p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC)); - -#pragma unroll - // loop over rest of k - for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop) - { - // read A_sub_0 - a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC), - p_b_thread, - p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0)); - - // read B_sub_0 - b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC), - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC), - p_c_thread + - ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC)); - - // read B_sub_1 - 
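
// (Interleaving note: each read above/below is paired with a sub-tile GEMM that
// consumes registers loaded earlier, so the LDS-read latency of one sub-tile is
// hidden behind the FMAs of another; this is the software pipelining that the
// name "pipelined_2x2" refers to.)
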
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster), - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC)); - - // read A_sub_1 - a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster), - p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC)); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - threadwise_gemm.Run(p_a_thread, - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC), - p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC)); - } - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC), - p_b_thread, - p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0)); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC), - p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC), - p_c_thread + - ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC)); - } - - template - __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const - { -#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE - constexpr index_t MPerThread = ThreadMatrixC::NRow(); - constexpr index_t NPerThread = ThreadMatrixC::NCol(); - - constexpr index_t MRepeat = MPerThread / MPerThreadSubC; - constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - - if constexpr(MRepeat == 2 && NRepeat == 2) - { - Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread); - } - else - { - Run_naive(p_a_block, p_b_block, p_c_thread); - } -#else - Run_naive(p_a_block, p_b_block, p_c_thread); -#endif - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp deleted file mode 100644 index d67101a935..0000000000 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ /dev/null @@ -1,189 +0,0 @@ -#ifndef CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP -#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_coordinate.hpp" -#include "cluster_descriptor.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" - -namespace ck { - -// This blockwise copy allow vector access of src and dst. -// It allows the vector size to be different on src and dst. -// The dimension of vector access can be different for src and dst. -// The dimension access order can be different for src and dst. 
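// Example (hypothetical numbers, for illustration only): a [8, 128] block slice
// with ThreadSliceLengths = [4, 4] and ThreadClusterLengths = [2, 32] uses
// 2 * 32 = 64 threads, each owning a 4x4 sub-slice; per dimension the block
// slice must satisfy BlockSliceLengths[d] = ThreadSliceLengths[d] *
// ThreadClusterLengths[d] (8 = 4 * 2, 128 = 4 * 32), which is what the
// static_assert in the constructor below enforces.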
-// Will do valid mapping check on src data: Read 0 if src data has an invalid mapping
-// Will do valid mapping check on dst data: No write if dst data has an invalid mapping
-// BlockSize can be equal to or larger than ThreadCluster size, which means some threads may not do
-// threadwise copy
-template
-struct BlockwiseGenericTensorSliceCopy_v4
-{
-    static constexpr index_t nDim = BlockSrcDesc::GetNumOfDimension();
-    using Index = MultiIndex;
-
-    __device__ constexpr BlockwiseGenericTensorSliceCopy_v4(const Index& src_block_slice_origin,
-                                                            const Index& dst_block_slice_origin)
-    {
-        static_assert(nDim == BlockSrcDesc::GetNumOfDimension() &&
-                          nDim == BlockDstDesc::GetNumOfDimension() &&
-                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
-                          nDim == ThreadClusterLengths::Size() &&
-                          nDim == ThreadClusterArrangeOrder::Size() &&
-                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
-                      "wrong! nDim not consistent");
-
-        static_assert(
-            is_same{},
-            "wrong! threads should be mapped to cover entire slicing window");
-
-        static_assert(BlockSize >= mThreadClusterDesc.GetElementSize(),
-                      "wrong! BlockSize too small");
-
-        if(BlockSize == mThreadClusterDesc.GetElementSize() or
-           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
-        {
-            const auto thread_cluster_id =
-                mThreadClusterDesc.CalculateClusterIndex(get_thread_local_1d_id());
-
-            const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{};
-
-            mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin);
-            mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index());
-
-            mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index());
-            mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin);
-        }
-    }
-
-    __device__ static constexpr index_t GetThreadBufferSize()
-    {
-        return ThreadBufferDesc::GetElementSpace();
-    }
-
-    template
-    __device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
-                                        ThreadBufferData* p_thread_buffer) const
-    {
-        if(BlockSize == mThreadClusterDesc.GetElementSize() or
-           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
-        {
-            mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
-        }
-    }
-
-    template
-    __device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
-                                         BlockDstData* p_block_dst) const
-    {
-        if(BlockSize == mThreadClusterDesc.GetElementSize() or
-           get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize())
-        {
-            mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
-        }
-    }
-
-    template
-    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
-    {
-        static_assert(ThreadBufferAddressSpace == AddressSpace::Vgpr,
-                      "wrong! This function uses vgpr as its thread "
-                      "buffer. However, you have set RunLoadThreadBuffer and RunStoreThreadBuffer "
-                      "to use ThreadBufferAddressSpace as their thread buffer, which is not vgpr. 
" - "Behavior may be different"); - - BlockSrcData p_thread_buffer[GetThreadBufferSize()]; - - if(BlockSize == mThreadClusterDesc.GetElementSize() or - get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) - { - RunLoadThreadBuffer(p_block_src, p_thread_buffer); - - // if there is type conversion, it's done during store - RunStoreThreadBuffer(p_thread_buffer, p_block_dst); - } - } - - template - __device__ void - MoveSrcSliceWindow(const T& step_sizes, - integral_constant positive_direction) - { - if(BlockSize == mThreadClusterDesc.GetElementSize() or - get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) - { - mThreadwiseLoad.MoveSrcSliceWindow(step_sizes, positive_direction); - } - } - - template - __device__ void - MoveDstSliceWindow(const T& step_sizes, - integral_constant positive_direction) - { - if(BlockSize == mThreadClusterDesc.GetElementSize() or - get_thread_local_1d_id() < mThreadClusterDesc.GetElementSize()) - { - mThreadwiseStore.MoveDstSliceWindow(step_sizes, positive_direction); - } - } - - private: - using ThreadBufferDesc = decltype(make_native_tensor_descriptor_packed(ThreadSliceLengths{})); - - using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v4r2; - - using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v4r2; - - static constexpr auto mThreadClusterDesc = - make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); - - ThreadwiseLoad mThreadwiseLoad; - ThreadwiseStore mThreadwiseStore; -}; - -} // namespace ck - -#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp deleted file mode 100644 index 7f9936bcd9..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ /dev/null @@ -1,785 +0,0 @@ -#ifndef CK_GRIDWISE_GEMM_HPP -#define CK_GRIDWISE_GEMM_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -template -struct GridwiseGemmTransposedANormalBNormalC_v1 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, - BBlockCopyDstDataPerWrite_N, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t b_block_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - - return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); - } - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global, - Float* __restrict__ p_shared_block) const - { - constexpr auto True = integral_constant{}; - - constexpr auto a_k_m_global_desc = AGlobalDesc{}; - 
constexpr auto b_k_n_global_desc = BGlobalDesc{}; - constexpr auto c_m_n_global_desc = CGlobalDesc{}; - - constexpr auto K = a_k_m_global_desc.GetLengths()[0]; - constexpr auto M = a_k_m_global_desc.GetLengths()[1]; - constexpr auto N = b_k_n_global_desc.GetLengths()[1]; - - // don't do anything if K == 0 - if(K == 0) - { - return; - } - - // lds max alignment - constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, - BBlockCopyDstDataPerWrite_N, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N); - - // divide block work by [M, N] - static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t MBlockWork = M / MPerBlock; - constexpr index_t NBlockWork = N / NPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock; - const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - ABlockCopySrcVectorReadDim, - 1, - ABlockCopySrcDataPerRead, - ABlockCopyDstDataPerWrite_M, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0)); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - BBlockCopySrcVectorReadDim, - 1, - BBlockCopySrcDataPerRead, - BBlockCopyDstDataPerWrite_N, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlocl, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc); - constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc); - - // sanity check - static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && - NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); - constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_k_m_block_mtx_desc), - decltype(b_k_n_block_mtx_desc), - decltype(c_m0m1_n0n1_thread_mtx_desc), - MPerThread, - NPerThread, - KPerThread, - MLevel0Cluster, - NLevel0Cluster, - MLevel1Cluster, - 
-            NLevel1Cluster,
-            ThreadGemmAThreadCopySrcDataPerRead_M,
-            ThreadGemmBThreadCopySrcDataPerRead_N>{};
-
-        // LDS allocation for A and B: be careful of alignment
-        constexpr index_t a_block_space_size =
-            math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align);
-
-        constexpr index_t b_block_space_size =
-            math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align);
-
-        Float* p_a_block_double = p_shared_block;
-        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;
-
-        // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
-
-        // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
-
-        // LDS double buffer: preload data into LDS
-        {
-            a_blockwise_copy.Run(p_a_global, p_a_block_double);
-            b_blockwise_copy.Run(p_b_global, p_b_block_double);
-        }
-
-        constexpr auto a_block_slice_copy_step = Sequence{};
-        constexpr auto b_block_slice_copy_step = Sequence{};
-
-        Float* p_a_block_even = p_a_block_double;
-        Float* p_b_block_even = p_b_block_double;
-
-        Float* p_a_block_odd = p_a_block_double + a_block_space_size;
-        Float* p_b_block_odd = p_b_block_double + b_block_space_size;
-
-        // LDS double buffer: main body
-        for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock;
-            k_block_data_begin += 2 * KPerBlock)
-        {
-            Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-            Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-            // even iteration
-            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
-            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
-
-            __syncthreads();
-
-            // LDS double buffer: load next data from device mem
-            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
-
-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
-
-            // LDS double buffer: store next data to LDS
-            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd);
-            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd);
-
-            // odd iteration
-            a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
-            b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
-
-            __syncthreads();
-
-            // LDS double buffer: load next data from device mem
-            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
-            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
-
-            // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
-
-            // LDS double buffer: store next data to LDS
-            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even);
-            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even);
-        }
-
-        // LDS double buffer: tail
-        {
-            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);
-
-            if(has_two_iteration_left) // if there are 2 iterations left
-            {
-                Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
-                Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
-
-                a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True);
-                b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True);
-
-                __syncthreads();
-
-                // LDS double buffer: load last data from device mem
-                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
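
// (The two RunLoadThreadBuffer calls here and below stage global memory into
// registers first; the matching RunStoreThreadBuffer calls write those
// registers to the other half of the LDS double buffer only after the GEMM on
// the resident half has been issued, overlapping global-memory latency with
// math.)
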
b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, - p_a_block_double + a_block_space_size); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, - p_b_block_double + b_block_space_size); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double + a_block_space_size, - p_b_block_double + b_block_space_size, - p_c_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - } - } - - // input: register to global memory - { - constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; - constexpr index_t M0 = M / M1; - - constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster; - constexpr index_t N0 = N / N1; - - // define input tensor descriptor for threadwise copy - // thread input tensor, src of threadwise copy - constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor( - c_m_n_global_desc, - make_tuple(UnMerge>{}, UnMerge>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // calculate origin of thread input tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t m_thread_data_on_global = - m_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t n_thread_data_on_global = - n_block_data_on_global + c_thread_mtx_on_block.col; - - ThreadwiseGenericTensorSliceCopy_v4r2( - make_multi_index(0, 0, 0, 0), - make_multi_index(m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1)) - .Run(p_c_thread, p_c_global); - } - } - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global) const - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); - - __shared__ Float p_shared_block[shared_block_size]; - - Run(p_a_global, p_b_global, p_c_global, p_shared_block); - } -}; - -template -struct GridwiseGemmTransposedANormalBNormalC_v2 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, - BBlockCopyDstDataPerWrite_N, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N); - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space_size = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t b_block_space_size = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - - return 2 * 
(a_block_space_size + b_block_space_size) * sizeof(Float); - } - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global, - Float* __restrict__ p_shared_block) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto True = integral_constant{}; - constexpr auto False = integral_constant{}; - - constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{}; - constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{}; - constexpr auto c_m_n_global_desc = CGlobalDesc{}; - - constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0]; - constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1]; - constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2]; - constexpr auto M = c_m_n_global_desc.GetLengths()[I0]; - constexpr auto N = c_m_n_global_desc.GetLengths()[I1]; - - // don't do anything if K == 0 - if(K == 0) - { - return; - } - - // lds max alignment - constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, - BBlockCopyDstDataPerWrite_N, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N); - - // divide block work by [M, N] - static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t MBlockWork = M / MPerBlock; - constexpr index_t NBlockWork = N / NPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock; - const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_k1_k2_m_block_desc = make_native_tensor_descriptor_aligned( - Sequence<1, 1, KPerBlock, MPerBlock>{}, Number{}); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - ABlockCopySrcVectorReadDim, - 3, - ABlockCopySrcDataPerRead, - ABlockCopyDstDataPerWrite_M, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0)); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_k1_k2_n_block_desc = make_native_tensor_descriptor_aligned( - Sequence<1, 1, KPerBlock, NPerBlock>{}, Number{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - BBlockCopySrcVectorReadDim, - 3, - BBlockCopySrcDataPerRead, - BBlockCopyDstDataPerWrite_N, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0)); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlocl, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor( - unfold_tensor_descriptor(a_k0_k1_k2_m_block_desc, I0, I2)); - constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor( - unfold_tensor_descriptor(b_k0_k1_k2_n_block_desc, I0, I2)); - - // sanity check - static_assert(MPerBlock % 
(MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && - NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); - constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_k_m_block_mtx_desc), - decltype(b_k_n_block_mtx_desc), - decltype(c_m0m1_n0n1_thread_mtx_desc), - MPerThread, - NPerThread, - KPerThread, - MLevel0Cluster, - NLevel0Cluster, - MLevel1Cluster, - NLevel1Cluster, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N>{}; - - // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space_size = - math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t b_block_space_size = - math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align); - - Float* p_a_block_double = p_shared_block; - Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; - - // register allocation for output - AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread); - - for(index_t k0 = 0; k0 < K0; ++k0) - { - for(index_t k1 = 0; k1 < K1; ++k1) - { - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.Run(p_a_global, p_a_block_double); - b_blockwise_copy.Run(p_b_global, p_b_block_double); - } - - constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{}; - constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{}; - - // LDS double buffer: main body - for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; - k_block_data_begin += 2 * KPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_a_block_now = - even_loop ? p_a_block_double : p_a_block_double + a_block_space_size; - Float* p_b_block_now = - even_loop ? p_b_block_double : p_b_block_double + b_block_space_size; - - Float* p_a_block_next = - even_loop ? p_a_block_double + a_block_space_size : p_a_block_double; - Float* p_b_block_next = - even_loop ? 
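// ping-pong: the GEMM consumes one half of LDS while the prefetch fills the other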
p_b_block_double + b_block_space_size : p_b_block_double; - - Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; - Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); - b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0); - - if(has_two_iteration_left) // if has 2 iteration left - { - Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; - Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); - b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunStoreThreadBuffer( - p_a_thread_buffer, p_a_block_double + a_block_space_size); - b_blockwise_copy.RunStoreThreadBuffer( - p_b_thread_buffer, p_b_block_double + b_block_space_size); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double + a_block_space_size, - p_b_block_double + b_block_space_size, - p_c_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - } - } - - // reset slice windoww on K2 dimension, then move forward on K1 dimension - a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False); - b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 0, K - KPerBlock, 0>{}, False); - - a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True); - b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, 1, 0, 0>{}, True); - } - - // reset slice windoww on K1 dimension, then move forward on K0 dimension - a_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False); - b_blockwise_copy.MoveSrcSliceWindow(Sequence<0, K1, 0, 0>{}, False); - - a_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True); - b_blockwise_copy.MoveSrcSliceWindow(Sequence<1, 0, 0, 0>{}, True); - } - - // input: register to global memory - { - constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; - constexpr index_t M0 = M / M1; - - constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster; - constexpr index_t N0 = N / N1; - - // define input tensor descriptor for threadwise copy - // thread input tensor, src of threadwise copy - constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - constexpr auto c_m0_m1_n0_n1_global_desc = 
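// unmerge C's M into (M0, M1) and N into (N0, N1) so the global 4-d view lines
// up with the per-thread (GemmMRepeat, MPerThread, GemmNRepeat, NPerThread) tile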
transform_tensor_descriptor( - c_m_n_global_desc, - make_tuple(UnMerge>{}, UnMerge>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // calculate origin of thread input tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t m_thread_data_on_global = - m_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t n_thread_data_on_global = - n_block_data_on_global + c_thread_mtx_on_block.col; - - ThreadwiseGenericTensorSliceCopy_v4r2( - make_multi_index(0, 0, 0, 0), - make_multi_index(m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1)) - .Run(p_c_thread, p_c_global); - } - } - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global) const - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); - - __shared__ Float p_shared_block[shared_block_size]; - - Run(p_a_global, p_b_global, p_c_global, p_shared_block); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_direct_convolution.hpp b/composable_kernel/include/tensor_operation/threadwise_direct_convolution.hpp deleted file mode 100644 index bae080b04c..0000000000 --- a/composable_kernel/include/tensor_operation/threadwise_direct_convolution.hpp +++ /dev/null @@ -1,228 +0,0 @@ -#ifndef CK_THREADWISE_DIRECT_CONVOLUTION_HPP -#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "threadwise_tensor_slice_copy.hpp" - -namespace ck { - -// optimized for scenario if p_in, p_wei, p_out are in register -template -__device__ void threadwise_direct_convolution_1(InDesc, - TInWei* const __restrict__ p_in, - WeiDesc, - TInWei* const __restrict__ p_wei, - OutDesc, - TOut* __restrict__ p_out) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - -#if 0 - if(blockIdx.x == 0 && get_thread_local_1d_id() == 0) - { - print_ConstantTensorDescriptor(in_desc, "threadwise_direct_convolution: in_desc: "); - print_ConstantTensorDescriptor(wei_desc, "threadwise_direct_convolution: wei_desc: "); - print_ConstantTensorDescriptor(out_desc, "threadwise_direct_convolution: out_desc: "); - } -#endif - - for(index_t n = 0; n < out_desc.GetLength(I0); ++n) - { - for(index_t k = 0; k < out_desc.GetLength(I1); ++k) - { - for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho) - { - for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo) - { - for(index_t c = 0; c < wei_desc.GetLength(I1); ++c) - { - for(index_t y = 0; y < wei_desc.GetLength(I2); ++y) - { - for(index_t x = 0; x < wei_desc.GetLength(I3); ++x) - { - const index_t hi = ho + y; - const index_t wi = wo + x; - - const index_t in_index = - in_desc.GetOffsetFromMultiIndex(n, c, hi, wi); - - const index_t wei_index = - wei_desc.GetOffsetFromMultiIndex(k, c, y, x); - - const index_t out_index = - out_desc.GetOffsetFromMultiIndex(n, k, ho, wo); - - fused_multiply_accumulate( - p_out[out_index], p_wei[wei_index], p_in[in_index]); - } - } - } - } - } - } - } -} - -// Optimized for scenario if p_in and p_wei are 
in LDS, p_out are in register -// Copy in and wei into register before doing convolution -template -__device__ void threadwise_direct_convolution_2(InDesc, - TInWei* const __restrict__ p_in, - WeiDesc, - TInWei* const __restrict__ p_wei, - OutDesc, - TOut* __restrict__ p_out) -{ - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - - constexpr auto in_reg_desc = make_ConstantTensorDescriptor_packed(in_desc.GetLengths()); - constexpr auto wei_reg_desc = make_ConstantTensorDescriptor_packed(wei_desc.GetLengths()); - - // register - TInWei p_in_reg[in_reg_desc.GetElementSpace()]; - TInWei p_wei_reg[wei_reg_desc.GetElementSpace()]; - - // copy input tensor into register - threadwise_tensor_slice_copy( - in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths(), Number<1>{}); - - // copy input tensor into register - threadwise_tensor_slice_copy( - wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths(), Number<1>{}); - - // do convolution - threadwise_direct_convolution_1( - in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); -} - -// optimized for scenario where p_in and p_wei are in LDS, p_out is in register -// break down a non-1x1 convolution into a sequence of 1x1 convolutions, -// load 1x1 weight into register, and do 1x1 convolution in register. -template -__device__ void threadwise_direct_convolution_3(InDesc, - Data* const __restrict__ p_in, - WeiDesc, - Data* const __restrict__ p_wei, - OutDesc, - Data* __restrict__ p_out) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - - constexpr auto in_reg_desc = make_ConstantTensorDescriptor(Sequence{}); - - constexpr auto wei_reg_desc = make_ConstantTensorDescriptor( - Sequence{}); - - Data p_in_reg[in_reg_desc.GetElementSpace()]; - Data p_wei_reg[wei_reg_desc.GetElementSpace()]; - - constexpr index_t in_w_new_read = 1; - - constexpr auto in_desc_reg_new_read = - make_ConstantTensorDescriptor(Sequence{}); - -#if 0 - // this verison reused old input data in register, and read new data from LDS - // loop over vertical direction - for(index_t y = 0; y < wei_desc.GetLength(I2); ++y) - { - // read first input - threadwise_4d_tensor_copy(in_desc, - p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, 0), - in_reg_desc, - p_in_reg, - in_reg_desc.GetLengths()); - - // read first 1x1 weight - threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, 0), - wei_reg_desc, - p_wei_reg, - wei_reg_desc.GetLengths()); - - // do first 1x1 conv - threadwise_direct_convolution_1( - in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); - - // loop over horizontal direction - for(index_t x = 1; x < wei_desc.GetLength(I3); ++x) - { - // read new weight - threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x), - wei_reg_desc, - p_wei_reg, - wei_reg_desc.GetLengths()); - - // shift old input to the left - threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number{}); - - // read new input - threadwise_4d_tensor_copy( - in_desc, - p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1), - in_reg_desc, - p_in_reg + - in_reg_desc.GetOffsetFromMultiIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read), - in_desc_reg_new_read.GetLengths()); - - // do 1x1 conv - 
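// only in_w_new_read new input columns were fetched above: the register tile
// was shifted left and refilled at its right edge, so every step after the
// first reads a single column from LDS instead of the whole tile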
threadwise_direct_convolution_1( - in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); - } - } -#elif 1 - // this version read all input from LDS when filter moves - // loop over vertical direction - for(index_t y = 0; y < wei_desc.GetLength(I2); ++y) - { - // loop over horizontal direction - for(index_t x = 0; x < wei_desc.GetLength(I3); ++x) - { - // read new weight - threadwise_4d_tensor_copy(wei_desc, - p_wei + wei_desc.GetOffsetFromMultiIndex(0, 0, y, x), - wei_reg_desc, - p_wei_reg, - wei_reg_desc.GetLengths()); - - // read new input - threadwise_4d_tensor_copy(in_desc, - p_in + in_desc.GetOffsetFromMultiIndex(0, 0, y, x), - in_reg_desc, - p_in_reg, - in_reg_desc.GetLengths()); - - // do 1x1 conv - threadwise_direct_convolution_1( - in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out); - } - } -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm.hpp deleted file mode 100644 index 56440bc2b7..0000000000 --- a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp +++ /dev/null @@ -1,165 +0,0 @@ -#ifndef CK_THREADWISE_GEMM_HPP -#define CK_THREADWISE_GEMM_HPP - -#include "common_header.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "math.hpp" - -namespace ck { - -template -__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread) -{ - for(index_t i = 0; i < Matrix::NRow(); ++i) - { - for(index_t j = 0; j < Matrix::NCol(); ++j) - { - const index_t id = Matrix::CalculateOffset(i, j); - p_thread[id] = Float(0); - } - } -} - -template -struct ThreadwiseMatrixSliceCopy -{ - __device__ constexpr ThreadwiseMatrixSliceCopy() - { - static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 && - DstMatrix::RowStride() % DataPerAccess == 0, - "wrong! wrong alignment"); - static_assert(NSliceCol % DataPerAccess == 0, - "wrong! 
should be NSliceCol % DataPerAccess == 0"); - } - - template - __device__ static void Run(const Data* p_src, Data* p_dst) - { - using vector_t = typename vector_type::type; - - for(index_t i = 0; i < NSliceRow; ++i) - { - for(index_t j = 0; j < NSliceCol; j += DataPerAccess) - { - const index_t src_index = SrcMatrix::CalculateOffset(i, j); - const index_t dst_index = DstMatrix::CalculateOffset(i, j); - - *reinterpret_cast(&p_dst[dst_index]) = - *reinterpret_cast(&p_src[src_index]); - } - } - } -}; - -// C += transpose(A) * B -// Element of matrix can be vectorized data -template -struct ThreadwiseGemmTransANormalBNormalC -{ - __device__ constexpr ThreadwiseGemmTransANormalBNormalC() - { - static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() && - MatrixB::NCol() == MatrixC::NCol(), - "wrong!"); - } - - template - __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { - constexpr index_t M = MatrixC::NRow(); - constexpr index_t N = MatrixC::NCol(); - constexpr index_t K = MatrixA::NRow(); // A is transposed - - for(index_t k = 0; k < K; ++k) - { - for(index_t m = 0; m < M; ++m) - { - for(index_t n = 0; n < N; ++n) - { - const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed - const index_t bindex = MatrixB::CalculateOffset(k, n); - const index_t cindex = MatrixC::CalculateOffset(m, n); - - p_c[cindex] += - inner_product_with_conversion{}(p_a[aindex], p_b[bindex]); - } - } - } - } - -#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM - template - __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { - constexpr index_t M = MatrixC::NRow(); - constexpr index_t N = MatrixC::NCol(); - constexpr index_t K = MatrixA::NRow(); // A is transposed - - static_assert(N == 4 || N == 2, "wrong! 
this config not supported by asm yet"); - - for(index_t k = 0; k < K; ++k) - { - for(index_t m = 0; m < M; ++m) - { - const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed - - static_if{}([&](auto) { - const index_t bindex_0 = MatrixB::CalculateOffset(k, 0); - const index_t bindex_1 = MatrixB::CalculateOffset(k, 1); - - const index_t cindex_0 = MatrixC::CalculateOffset(m, 0); - const index_t cindex_1 = MatrixC::CalculateOffset(m, 1); - - amd_assembly_outer_product_1x2( - p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]); - }); - - static_if{}([&](auto) { - const index_t bindex_0 = MatrixB::CalculateOffset(k, 0); - const index_t bindex_1 = MatrixB::CalculateOffset(k, 1); - const index_t bindex_2 = MatrixB::CalculateOffset(k, 2); - const index_t bindex_3 = MatrixB::CalculateOffset(k, 3); - - const index_t cindex_0 = MatrixC::CalculateOffset(m, 0); - const index_t cindex_1 = MatrixC::CalculateOffset(m, 1); - const index_t cindex_2 = MatrixC::CalculateOffset(m, 2); - const index_t cindex_3 = MatrixC::CalculateOffset(m, 3); - - amd_assembly_outer_product_1x4(p_a[aindex], - p_b[bindex_0], - p_b[bindex_1], - p_b[bindex_2], - p_b[bindex_3], - p_c[cindex_0], - p_c[cindex_1], - p_c[cindex_2], - p_c[cindex_3]); - }); - } - } - } -#endif - - template - __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) - { -#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM - constexpr bool has_amd_asm = is_same{} && - ((is_same{} && is_same{}) || - (is_same{} && is_same{}) || - (is_same{} && is_same{})); - - static_if{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); }) - .Else([&](auto) { Run_source(p_a, p_b, p_c); }); -#else - Run_source(p_a, p_b, p_c); -#endif - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp deleted file mode 100644 index 8b83b68c76..0000000000 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_op.hpp +++ /dev/null @@ -1,20 +0,0 @@ -#ifndef CK_THREADWISE_GENERIC_TENSOR_OP_HPP -#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP - -#include "common_header.hpp" -#include "ConstantTensorDescriptor_deprecated.hpp" -#include "ConstantMergedTensorDescriptor_deprecated.hpp" - -namespace ck { -template -__device__ void threadwise_generic_tensor_set_zero(TDesc, Float* __restrict__ p) -{ - static_ford{}([&](auto multi_id) { - constexpr index_t offset = TDesc::GetOffsetFromMultiIndex(multi_id); - - p[offset] = static_cast(0); - }); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp deleted file mode 100644 index f9f48a18b7..0000000000 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ /dev/null @@ -1,191 +0,0 @@ -#ifndef CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP -#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "tensor_coordinate.hpp" - -namespace ck { - -// This threadwise copy allow vector access of src and dst. -// It allows the vector size to be different on src and dst. -// The dimensions of vector access should be the same on src and dst. -// The dimension access order should be the same on src and dst. 
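// For example, with SrcDataPerRead == 4 and DstDataPerWrite == 2, Run() moves
// long-vectors of lcm(4, 2) == 4 elements along SrcDstVectorReadWriteDim:
// each step does one 4-wide read and two 2-wide writes through a register
// staging buffer, converting SrcData to DstData in between.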
-// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping -// Will do valid mapping check on dst data: No write if dst data has a invalid mapping -template -struct ThreadwiseGenericTensorSliceCopy_v4r2 -{ - static constexpr index_t nDim = SliceLengths::Size(); - using Index = MultiIndex; - - using SrcCoord = typename TensorCoordinate::type; - using DstCoord = typename TensorCoordinate::type; - - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2(const Index& src_slice_origin, - const Index& dst_slice_origin) - : mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin) - { - static_assert(nDim == SrcDesc::GetNumOfDimension() && - nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() && - nDim == SrcDstDimAccessOrder::Size(), - "wrong! # of dimensions not the same"); - - static_assert(is_valid_sequence_map{}, "wrong! map is not valid"); - - static_assert(SliceLengths{}[SrcDstVectorReadWriteDim] % - math::lcm(SrcDataPerRead, DstDataPerWrite) == - 0, - "wrong! cannot evenly divide"); - - // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst - } - - __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2() - : ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index(), - make_zero_multi_index()) - { - } - - __device__ void SetSrcSliceOrigin(SrcCoord src_slice_origin) - { - mSrcSliceOrigin = src_slice_origin; - } - - __device__ void SetDstSliceOrigin(DstCoord dst_slice_origin) - { - mDstSliceOrigin = dst_slice_origin; - } - - template - __device__ void Run(const SrcData* p_src, DstData* p_dst) const - { - constexpr auto vector_access_dim = Number{}; - - constexpr auto src_data_per_access = Number{}; - constexpr auto dst_data_per_access = Number{}; - - constexpr auto long_vector_size = Number{}; - - constexpr auto long_vector_access_lengths = SliceLengths::Modify( - vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); - - ford{}( - [&](auto long_vector_access_id) { - - // data id w.r.t slicing-window - auto long_vector_data_begin_id = long_vector_access_id; - long_vector_data_begin_id(vector_access_dim) = - long_vector_size * long_vector_access_id[vector_access_dim]; - - // buffer to hold a src long-vector - SrcData p_src_long_vector[long_vector_size]; - - // load data from src to the long-vector buffer - static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) { - auto scalar_id = make_zero_multi_index(); - scalar_id(vector_access_dim) = i * src_data_per_access; - - const index_t buffer_offset = i * src_data_per_access; - - const auto src_coord = - mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); - - // Check src data's valid mapping situation, only check the first data in this - // src - // vector. 
It's user's responsiblity to make sure all data in the src vector - // has the valid/invalid mapping situation - transfer_data(p_src, - src_coord.GetOffset(), - src_coord.IsOffsetValidAssumingUpperIndexIsValid(), - SrcDesc::GetElementSpace(), - p_src_long_vector, - buffer_offset, - true, - long_vector_size); - }); - - // SrcData to DstData conversion - DstData p_dst_long_vector[long_vector_size]; - - static_for<0, long_vector_size, 1>{}([&](auto i) { - p_dst_long_vector[i] = type_convert{}(p_src_long_vector[i]); - }); - - // store data from the long-vector buffer to dst - static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) { - auto scalar_id = make_zero_multi_index(); - scalar_id(vector_access_dim) = i * dst_data_per_access; - - const index_t buffer_offset = i * dst_data_per_access; - - const auto dst_coord = - mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); - - // Check dst data's valid mapping situation, only check the first data in this - // dst - // vector. It's user's responsiblity to make sure all data in the dst vector - // has the valid/invalid mapping situation - transfer_data(p_dst_long_vector, - buffer_offset, - true, - long_vector_size, - p_dst, - dst_coord.GetOffset(), - dst_coord.IsOffsetValidAssumingUpperIndexIsValid(), - DstDesc::GetElementSpace()); - }); - }); - } - - template - __device__ void MoveSrcSliceWindow(const T& step_sizes_, - integral_constant) - { - const auto step_sizes = to_multi_index(step_sizes_); - - static_if{}([&](auto) { mSrcSliceOrigin += to_multi_index(step_sizes); }) - .Else([&](auto) { mSrcSliceOrigin -= step_sizes; }); - } - - template - __device__ void MoveDstSliceWindow(const T& step_sizes_, - integral_constant) - { - const auto step_sizes = to_multi_index(step_sizes_); - - static_if{}([&](auto) { mDstSliceOrigin += step_sizes; }) - .Else([&](auto) { mDstSliceOrigin -= step_sizes; }); - } - - private: - SrcCoord mSrcSliceOrigin; - DstCoord mDstSliceOrigin; -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 5fbc22c807..876a1174e7 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -2,7 +2,6 @@ #define CK_XDLOPS_GEMM_HPP #include "common_header.hpp" -#include "ConstantMatrixDescriptor.hpp" #include "math.hpp" #include "amd_xdlops.hpp" diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp deleted file mode 100644 index 380a14003d..0000000000 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ /dev/null @@ -1,1042 +0,0 @@ -#ifndef CK_AMD_BUFFER_ADDRESSING_HPP -#define CK_AMD_BUFFER_ADDRESSING_HPP - -#include "float_type.hpp" -#include "amd_buffer_addressing_v2.hpp" - -namespace ck { - -template -union BufferResource -{ - // 128 bit SGPRs to supply buffer resource in buffer instructions - // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions - int32x4_t data; - T* address[2]; - int32_t range[4]; - int32_t config[4]; -}; - -__device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.f32"); - -__device__ float2_t -__llvm_amdgcn_buffer_load_f32x2(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v2f32"); - 
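// Every specialization below fills the 128-bit descriptor the same way:
// dwords 0-1 hold the wavewise base address, dword 2 the buffer range in
// bytes (what the hardware checks offsets against), and dword 3 the
// CK_BUFFER_RESOURCE_3RD_DWORD configuration constant.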
-__device__ float4_t -__llvm_amdgcn_buffer_load_f32x4(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v4f32"); -__device__ half_t -__llvm_amdgcn_raw_buffer_load_f16(int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); - -__device__ ushort -__llvm_amdgcn_raw_buffer_load_bf16(int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.bf16"); - -__device__ void __llvm_amdgcn_buffer_store_f32(float vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.f32"); - -__device__ void __llvm_amdgcn_buffer_store_f32x2(float2_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v2f32"); - -__device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); - -__device__ void -__llvm_amdgcn_raw_buffer_store_f16(half_t vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); - -__device__ void -__llvm_amdgcn_raw_buffer_store_bf16(ushort vdata, - int32x4_t rsrc, - index_t voffset, - index_t soffset, - index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.bf16"); - -#if CK_USE_AMD_BUFFER_ATOMIC_FADD -#if CK_HIP_VERSION_FLAT >= 3010020405 -// starting ROCm-3.10, the return type becomes float -__device__ float -#else -__device__ void -#endif -__llvm_amdgcn_buffer_atomic_add_f32(float vdata, - int32x4_t rsrc, - index_t vindex, - index_t offset, - bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32"); -#endif - -// buffer_load requires: -// 1) p_src_wave must be in global memory space -// 2) p_src_wave to be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -__device__ typename vector_type::type amd_buffer_load(const T* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_elemenst_space); - -// buffer_store requires: -// 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory -// 2) p_dst_thread to be a wavewise pointer. -// It is user's responsibility to make sure that is true. -template -__device__ void amd_buffer_store(const T* p_src_thread, - T* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range); - -// buffer_atomic requires: -// 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory -// 2) p_dst_thread to be a wavewise pointer. -// It is user's responsibility to make sure that is true. 
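To make the intended calling convention concrete, here is a minimal usage sketch (not part of this header): a guarded elementwise copy built on amd_buffer_load/amd_buffer_store. The kernel name and launch shape are invented for the example, and the <float, 1> template arguments assume the <T, VectorSize> parameter list used by the specializations below.

    __global__ void copy_guarded(const float* p_src, float* p_dst, index_t n)
    {
        const index_t i = blockIdx.x * blockDim.x + threadIdx.x;

        // tail guard expressed as a predicate, not a branch
        const bool valid = (i < n);

        // invalid lanes read back 0 instead of touching memory out of range
        float v = amd_buffer_load<float, 1>(p_src, i, valid, n);

        // invalid lanes write nothing
        amd_buffer_store<float, 1>(&v, p_dst, i, valid, n);
    }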
-template -__device__ void amd_buffer_atomic_add(const T* p_src_thread, - T* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range); - -template <> -__device__ float amd_buffer_load(const float* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - return __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#else - float tmp = __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? tmp : float(0); -#endif -} - -template <> -__device__ float2_t amd_buffer_load(const float* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - return __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#else - float2_t tmp = __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? tmp : float2_t(0); -#endif -} - -template <> -__device__ float4_t amd_buffer_load(const float* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - return __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#else - float4_t tmp = __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? 
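// software predication: invalid lanes substitute 0 for the loaded value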
tmp : float4_t(0); -#endif -} - -template <> -__device__ half_t amd_buffer_load(const half_t* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - return __llvm_amdgcn_raw_buffer_load_f16( - src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); -#else - half_t zero(0); - - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - return src_thread_data_valid ? __llvm_amdgcn_raw_buffer_load_f16( - src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0) - : zero; -#endif -} - -template <> -__device__ half2_t amd_buffer_load(const half_t* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - half2_t zeros(0); - - float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ half4_t amd_buffer_load(const half_t* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 
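// OOB trick: an invalid lane shifts its offset past the buffer range, and the
// hardware returns 0 for out-of-range buffer loads, giving branch-free predication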
0 : 0x7fffffff; - - float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - half4_t zeros(0); - - float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ half8_t amd_buffer_load(const half_t* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - half8_t zeros(0); - - float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ ushort amd_buffer_load(const ushort* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - return __llvm_amdgcn_raw_buffer_load_bf16( - src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); -#else - ushort zero(0); - - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - return src_thread_data_valid ? 
__llvm_amdgcn_raw_buffer_load_bf16( - src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0) - : zero; -#endif -} - -template <> -__device__ ushort2_t amd_buffer_load(const ushort* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - ushort2_t zeros(0); - - float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ ushort4_t amd_buffer_load(const ushort* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - - float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - ushort4_t zeros(0); - - float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ ushort8_t amd_buffer_load(const ushort* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_data_range) -{ - BufferResource src_wave_buffer_resource; - - // wavewise base address (64 bit) - src_wave_buffer_resource.address[0] = const_cast(p_src_wave); - // wavewise range (32 bit) - src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); - // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); - -#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK - uint32_t src_addr_shift = src_thread_data_valid ? 
0 : 0x7fffffff; - - float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); - - return *reinterpret_cast(&dst_out_tmp); -#else - ushort8_t zeros(0); - - float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); - - return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; -#endif -} - -template <> -__device__ void amd_buffer_store(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32( - *p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_thread_addr_offset, - false, - false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 
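// same trick for stores: out-of-range lanes have their write dropped by the hardware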
0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_thread_addr_offset, - false, - false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const half_t* p_src_thread, - half_t* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - __llvm_amdgcn_raw_buffer_store_f16(*p_src_thread, - dst_wave_buffer_resource.data, - dst_addr_shift + dst_thread_addr_offset, - 0, - 0); -#else - if(dst_thread_data_valid) - { - // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and - // everything is passed to Voffset - __llvm_amdgcn_raw_buffer_store_f16( - *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const half_t* p_src_thread, - half_t* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); - - const float* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const half_t* p_src_thread, - half_t* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); - - const float2_t* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x2( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const half_t* p_src_thread, - half_t* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); - - const float4_t* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x4( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const ushort* p_src_thread, - ushort* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; - - __llvm_amdgcn_raw_buffer_store_bf16(*p_src_thread, - dst_wave_buffer_resource.data, - dst_addr_shift + dst_thread_addr_offset, - 0, - 0); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_raw_buffer_store_bf16( - *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const ushort* p_src_thread, - ushort* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); - - const float* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const ushort* p_src_thread, - ushort* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); - - const float2_t* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x2( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -template <> -__device__ void amd_buffer_store(const ushort* p_src_thread, - ushort* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); - - const float4_t* p_src_tmp = reinterpret_cast(p_src_thread); - -#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; - - __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_store_f32x4( - *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); - } -#endif -} - -#if CK_USE_AMD_BUFFER_ATOMIC_FADD -template <> -__device__ void amd_buffer_atomic_add(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - __llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false); -#else - if(dst_thread_data_valid) - { - __llvm_amdgcn_buffer_atomic_add_f32( - *p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false); - } -#endif -} - -template <> -__device__ void amd_buffer_atomic_add(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range; - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - - for(index_t i = 0; i < 2; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset + - i * sizeof(float), - false); - } -#else - if(dst_thread_data_valid) - { - for(index_t i = 0; i < 2; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_thread_addr_offset + i * sizeof(float), - false); - } - } -#endif -} - -template <> -__device__ void amd_buffer_atomic_add(const float* p_src_thread, - float* p_dst_wave, - index_t dst_thread_data_offset, - bool dst_thread_data_valid, - index_t dst_data_range) -{ - BufferResource dst_wave_buffer_resource; - - // wavewise base address (64 bit) - dst_wave_buffer_resource.address[0] = p_dst_wave; - // wavewise range (32 bit) - dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); - // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; - - index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); - -#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK - uint32_t dst_addr_shift = dst_thread_data_valid ? 
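// and for atomics: shifting an invalid lane out of range suppresses its atomic add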
0 : 0x7fffffff; - - for(index_t i = 0; i < 4; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset + - i * sizeof(float), - false); - } -#else - if(dst_thread_data_valid) - { - for(index_t i = 0; i < 4; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_thread_addr_offset + i * sizeof(float), - false); - } - } -#endif -} -#endif // CK_USE_AMD_BUFFER_ATOMIC_FADD - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/amd_dlop.hpp b/composable_kernel/include/utility/amd_dlop.hpp index c67cfc7118..e5b9d901ba 100644 --- a/composable_kernel/include/utility/amd_dlop.hpp +++ b/composable_kernel/include/utility/amd_dlop.hpp @@ -23,6 +23,48 @@ amd_inner_product_dlop(const float& a, const float& b, floa #endif } +template <> +__device__ void +amd_inner_product_dlop(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +amd_inner_product_dlop(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + #if CK_USE_AMD_DLOP template <> __device__ void diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 9517c8e8bd..32e2abd99f 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -13,7 +13,6 @@ #include "functional2.hpp" #include "functional3.hpp" #include "functional4.hpp" -#include "in_memory_operation.hpp" #include "integral_constant.hpp" #include "math.hpp" #include "number.hpp" @@ -25,6 +24,7 @@ #include "type.hpp" #include "utility.hpp" #include "magic_division.hpp" +#include "amd_buffer_addressing_v2.hpp" #include "static_buffer.hpp" #include "dynamic_buffer.hpp" diff --git a/composable_kernel/include/utility/config.nvidia.hpp.in b/composable_kernel/include/utility/config.nvidia.hpp.in deleted file mode 100644 index 2c26d4d624..0000000000 --- a/composable_kernel/include/utility/config.nvidia.hpp.in +++ /dev/null @@ -1,54 +0,0 @@ -#ifndef CK_CONFIG_NVIDIA_HPP -#define CK_CONFIG_NVIDIA_HPP - -#include -#include -#include - -// index type: unsigned or signed -#define CK_UNSIGNED_INDEX_TYPE 0 - -// device backend -#define CK_DEVICE_BACKEND_NVIDIA 1 - -// disable AMD inline asm and intrinsic -#define CK_USE_AMD_INLINE_ASM 0 -#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0 -#define CK_USE_AMD_BUFFER_ADDRESSING 0 -#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0 -#define CK_USE_AMD_XDLOPS 0 -#define CK_USE_AMD_XDLOPS_INLINE_ASM 0 -#define CK_USE_AMD_XDLOPS_EMULATE 0 - -// experimental implementation -#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0 -#define 
CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0 -#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0 -#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0 -#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0 -#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0 - -namespace ck { - -enum AddressSpace -{ - Generic, - Global, - Lds, - Vgpr -}; - -enum InMemoryDataOperation -{ - Set, - AtomicAdd -}; - -#if CK_UNSIGNED_INDEX_TYPE -using index_t = uint32_t; -#else -using index_t = int32_t; -#endif - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/float_type.nvidia.hpp.in b/composable_kernel/include/utility/float_type.nvidia.hpp.in deleted file mode 100644 index 82b147483a..0000000000 --- a/composable_kernel/include/utility/float_type.nvidia.hpp.in +++ /dev/null @@ -1,180 +0,0 @@ -#ifndef CK_FLOAT_TYPE_NVIDIA_HPP -#define CK_FLOAT_TYPE_NVIDIA_HPP - -#include "number.hpp" - -namespace ck { - -// CUDA needs these type definitions; without them the compiler fails to -// generate optimal load/store instructions and the kernel produces wrong -// results. -// float -using float2_t = float2; -using float4_t = float4; - -// float -typedef float float32_t __attribute__((ext_vector_type(32))); - -// bfloat16 -typedef ushort ushort2_t __attribute__((ext_vector_type(2))); -typedef ushort ushort4_t __attribute__((ext_vector_type(4))); -typedef ushort ushort8_t __attribute__((ext_vector_type(8))); - -// fp16 -using half_t = half; -using half2_t = half2; -using half4_t = float2; - -template <typename T, index_t N> -struct vector_type -{ - typedef struct - { - T scalar[N]; - } type; -}; - -template <> -struct vector_type<float, 1> -{ - using type = float; - - template <index_t I> - __host__ __device__ static void SetScalar(type& v, float s, Number<I>) - { - static_assert(I < 1, "wrong"); - *(reinterpret_cast<float*>(&v) + I) = s; - } -}; - -template <> -struct vector_type<float, 2> -{ - using type = float2_t; - - union DataType - { - type vector; - float scalar[2]; - }; - - template <index_t I> - __host__ __device__ static void SetScalar(type& v, float s, Number<I>) - { - static_assert(I < 2, "wrong"); - *(reinterpret_cast<float*>(&v) + I) = s; - } - - __host__ __device__ static type Pack(float s0, float s1) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - return data.vector; - } -}; - -template <> -struct vector_type<float, 4> -{ - using type = float4_t; - - __host__ __device__ static constexpr index_t GetSize() { return 4; } - - template <index_t I> - __host__ __device__ static void SetScalar(type& v, float s, Number<I>) - { - static_assert(I < 4, "wrong"); - *(reinterpret_cast<float*>(&v) + I) = s; - } -}; - -template <> -struct vector_type<half_t, 1> -{ - using type = half_t; - - template <index_t I> - __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>) - { - static_assert(I < 1, "wrong"); - *(reinterpret_cast<half_t*>(&v) + I) = s; - } -}; - -template <> -struct vector_type<half_t, 2> -{ - using type = half2_t; - - union DataType - { - type vector; - half_t scalar[2]; - }; - - template <index_t I> - __host__ __device__ static void SetScalar(type& v, half_t s, Number<I>) - { - static_assert(I < 2, "wrong"); - *(reinterpret_cast<half_t*>(&v) + I) = s; - } - - __host__ __device__ static type Pack(half_t s0, half_t s1) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - return data.vector; - } -}; - -// data type conversion -template <typename T> -struct type_convert -{ - template <typename X> - __device__ T operator()(const X&
x) const - { - return static_cast(x); - } -}; - -template -struct inner_product_with_conversion -{ - static constexpr auto convert = type_convert(); - - __device__ T operator()(float a, float b) const { return convert(a) * convert(b); } - - __device__ T operator()(half2_t a, half2_t b) const - { - const half_t* p_a_half = reinterpret_cast(&a); - const half_t* p_b_half = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 2; ++v) - { - acc += convert(p_a_half[v]) * convert(p_b_half[v]); - } - - return acc; - } - - __device__ T operator()(half4_t a, half4_t b) const - { - const half_t* p_a_half = reinterpret_cast(&a); - const half_t* p_b_half = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 4; ++v) - { - acc += convert(p_a_half[v]) * convert(p_b_half[v]); - } - return acc; - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in deleted file mode 100644 index 97ea488a63..0000000000 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ /dev/null @@ -1,241 +0,0 @@ -#ifndef CK_IN_MEMORY_OPERATION_AMD_HPP -#define CK_IN_MEMORY_OPERATION_AMD_HPP - -#include "float_type.hpp" - -#if CK_USE_AMD_BUFFER_ADDRESSING -#include "amd_buffer_addressing.hpp" -#include "amd_buffer_addressing_v2.hpp" -#endif - -namespace ck { - -template -__device__ void atomic_add_impl(T* p_dst, T src) -{ - atomicAdd(p_dst, src); -} - -// atomicAdd for float does not support vector type -template <> -__device__ void atomic_add_impl(float2_t* p_dst, float2_t src) -{ - float* p_dst_float = reinterpret_cast(p_dst); - const float* p_src_float = reinterpret_cast(&src); - - for(index_t i = 0; i < 2; ++i) - { - atomicAdd(&(p_dst_float[i]), p_src_float[i]); - } -} - -template <> -__device__ void atomic_add_impl(float4_t* p_dst, float4_t src) -{ - float* p_dst_float = reinterpret_cast(p_dst); - const float* p_src_float = reinterpret_cast(&src); - - for(index_t i = 0; i < 4; ++i) - { - atomicAdd(&(p_dst_float[i]), p_src_float[i]); - } -} - -template -struct SetData -{ - using vector_t = typename vector_type::type; - - // This version is only for compatibility, don't use this version if possible - template - __device__ void Run(const T* p_src, - index_t src_offset, - bool src_valid, - index_t /* src_range */, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t /* dst_range */) const - { - if(dst_valid) - { - if(src_valid) - { -#if 0 - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); -#else - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[0x3fffffff & src_offset]); -#endif - } - else - { - *reinterpret_cast(&p_dst[dst_offset]) = 0; - } - } - } - -#if CK_USE_AMD_BUFFER_ADDRESSING - // buffer_load requires: - // 1) p_src_thread must be in global memory space, p_dst_thread must be vgpr - // 2) p_src_thread to be a wavewise pointer. - // It is user's responsibility to make sure that is true. - template <> - __device__ void Run(const T* p_src, - index_t src_offset, - bool src_valid, - index_t src_range, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t /* dst_range */) const - { - if(dst_valid) - { - *reinterpret_cast(&p_dst[dst_offset]) = - amd_buffer_load_v2(p_src, src_offset, src_valid, src_range); - } - } - - // buffer_store requires: - // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory - // 2) p_dst_thread to be a wavewise pointer. 
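// [Editorial aside, not part of the diff] A minimal sketch of the addressing
// pattern the surrounding buffer_load/buffer_store code relies on, with
// hypothetical names (BufferResourceSketch, store_with_oob_trick); the real
// descriptor and intrinsics live in amd_buffer_addressing.hpp. It illustrates
// the 0x7fffffff "offset trick" used throughout this diff: instead of
// branching, invalid lanes shift their byte offset far past the declared
// range so the hardware out-of-bounds check silently drops the access.
typedef int32_t int32x4_t_sketch __attribute__((ext_vector_type(4)));

union BufferResourceSketch
{
    int32x4_t_sketch data; // 128-bit descriptor consumed by the buffer intrinsics
    void* address[2];      // dwords 0-1: wavewise base address
    int32_t range[4];      // dword 2: range in bytes
    int32_t config[4];     // dword 3: CK_BUFFER_RESOURCE_3RD_DWORD
};

__device__ void store_with_oob_trick(float v,
                                     const BufferResourceSketch& res,
                                     uint32_t byte_offset,
                                     bool valid)
{
    // 0x7fffffff exceeds any 31-bit range[2], so for invalid lanes the
    // access falls outside the buffer and is discarded in hardware.
    uint32_t shift = valid ? 0 : 0x7fffffff;
    __llvm_amdgcn_buffer_store_f32(v, res.data, 0, shift + byte_offset, false, false);
}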
- // It is user's responsibility to make sure that is true. - template <> - __device__ void Run(const T* p_src, - index_t src_offset, - bool src_valid, - index_t /* src_range */, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t dst_range) const - { - const auto zeros = vector_t(0); - - amd_buffer_store_v2( - src_valid ? *reinterpret_cast(&(p_src[src_offset])) : zeros, - p_dst, - dst_offset, - dst_valid, - dst_range); - } -#endif -}; - -template -struct AtomicAddData -{ - using vector_t = typename vector_type::type; - - // This version is only for compatibility, don't use this version if possible - template - __device__ void Run(const T* p_src, - index_t src_offset, - bool src_valid, - index_t /* src_range */, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t /* dst_range */) const - { - if(src_valid && dst_valid) - { - atomic_add_impl(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); - } - } - -#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_FADD - // buffer_atomic requires: - // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory - // 2) p_dst_thread to be a wavewise pointer. - // It is user's responsibility to make sure that is true. - template <> - __device__ void Run(const T* p_src, - index_t src_offset, - bool src_valid, - index_t /* src_range */, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t dst_range) const - { - const auto zeros = vector_t(0); - - amd_buffer_atomic_add( - src_valid ? &(p_src[src_offset]) : &zeros, p_dst, dst_offset, dst_valid, dst_range); - } -#endif -}; - -template -__device__ void transfer_data(const T* p_src, - index_t src_offset, - bool src_valid, - index_t src_range, - T* p_dst, - index_t dst_offset, - bool dst_valid, - index_t dst_range) -{ - static_assert(DstInMemOp == InMemoryDataOperation::Set || - DstInMemOp == InMemoryDataOperation::AtomicAdd, - "wrong! 
InMemoryDataOperation not supported!"); - - // keep it simple, don't use static_if here, otherwise compiler will do weird things - if constexpr(SrcDataStride == 1 && DstDataStride == 1) - { - if constexpr(DstInMemOp == InMemoryDataOperation::Set) - { - SetData{}.template Run( - p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range); - } - else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd) - { - AtomicAddData{}.template Run( - p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range); - } - } - else - { -#pragma unroll - for(index_t i = 0; i < DataPerAccess; ++i) - { - if constexpr(DstInMemOp == InMemoryDataOperation::Set) - { - SetData{}.template Run( - p_src, - src_offset + i * SrcDataStride, - src_valid, - src_range, - p_dst, - dst_offset + i * DstDataStride, - dst_valid, - dst_range); - } - else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd) - { - AtomicAddData{}.template Run( - p_src, - src_offset + i * SrcDataStride, - src_valid, - src_range, - p_dst, - dst_offset + i * DstDataStride, - dst_valid, - dst_range); - } - } - } -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in deleted file mode 100644 index 2778321035..0000000000 --- a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in +++ /dev/null @@ -1,109 +0,0 @@ -#ifndef CK_IN_MEMORY_OPERATION_NVIDIA_HPP -#define CK_IN_MEMORY_OPERATION_NVIDIA_HPP - -namespace ck { - -template -__device__ void atomic_add_impl(T* p_dst, T src) -{ - atomicAdd(p_dst, src); -} - -// atomicAdd for float does not support vector type -template <> -__device__ void atomic_add_impl(float2_t* p_dst, float2_t src) -{ - float* p_dst_float = reinterpret_cast(p_dst); - const float* p_src_float = reinterpret_cast(&src); - - for(index_t i = 0; i < 2; ++i) - { - atomicAdd(&(p_dst_float[i]), p_src_float[i]); - } -} - -template <> -__device__ void atomic_add_impl(float4_t* p_dst, float4_t src) -{ - float* p_dst_float = reinterpret_cast(p_dst); - const float* p_src_float = reinterpret_cast(&src); - - for(index_t i = 0; i < 4; ++i) - { - atomicAdd(&(p_dst_float[i]), p_src_float[i]); - } -} - -template -struct SetData -{ - using vector_t = typename vector_type::type; - - template - __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const - { - *reinterpret_cast(&p_dst[dst_offset]) = - *reinterpret_cast(&p_src[src_offset]); - } -}; - -template -struct AtomicAddData -{ - using vector_t = typename vector_type::type; - - template - __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const - { - atomic_add_impl(reinterpret_cast(&p_dst[dst_offset]), - *reinterpret_cast(&p_src[src_offset])); - } -}; - -template -__device__ void transfer_data(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) -{ - static_assert(DstInMemOp == InMemoryDataOperation::Set || - DstInMemOp == InMemoryDataOperation::AtomicAdd, - "wrong! 
InMemoryDataOperation not supported!"); - - // keep it simple, don't use static_if here, otherwise compiler will do weird things - if(SrcDataStride == 1 && DstDataStride == 1) - { - // TODO: use static_if::ElseIf - static_if{}([&](auto) { - SetData{}.template Run( - p_src, src_offset, p_dst, dst_offset); - }); - - static_if{}([&](auto) { - AtomicAddData{}.template Run( - p_src, src_offset, p_dst, dst_offset); - }); - } - else - { - for(index_t i = 0; i < DataPerAccess; i++) - { - // TODO: use static_if::ElseIf - static_if{}([&](auto) { - SetData{}.template Run( - p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); - }); - - static_if{}([&](auto) { - AtomicAddData{}.template Run( - p_src, src_offset + i * SrcDataStride, p_dst, dst_offset + i * DstDataStride); - }); - } - } -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/synchronization.nvidia.hpp.in b/composable_kernel/include/utility/synchronization.nvidia.hpp.in deleted file mode 100644 index 030b86e12d..0000000000 --- a/composable_kernel/include/utility/synchronization.nvidia.hpp.in +++ /dev/null @@ -1,13 +0,0 @@ -#ifndef CK_SYNCHRONIZATION_NVIDIA_HPP -#define CK_SYNCHRONIZATION_NVIDIA_HPP - -#include "config.hpp" - -namespace ck { - -__device__ void block_sync_lds() { __syncthreads(); } - -__device__ void block_sync_lds_vmem() { __syncthreads(); } - -} // namespace ck -#endif diff --git a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.cpp b/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.cpp deleted file mode 100644 index ecd3af822f..0000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.cpp +++ /dev/null @@ -1,8 +0,0 @@ - -extern "C" __global__ void -gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer( - const void* const __restrict__ p_in_global, - const void* const __restrict__ p_wei_global, - void* const __restrict__ p_out_global){ - -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp deleted file mode 100644 index 820a0515ee..0000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.cpp +++ /dev/null @@ -1,7 +0,0 @@ - -extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( - const void* const __restrict__ p_in_global, - const void* const __restrict__ p_wei_global, - void* const __restrict__ p_out_global){ - -}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.cpp b/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.cpp deleted file mode 100644 index 4f646adbb7..0000000000 --- a/composable_kernel/src/kernel_wrapper/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.cpp +++ /dev/null @@ -1,8 +0,0 @@ - - -extern "C" __global__ void gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( - const void* const __restrict__ p_in_global, - const void* const __restrict__ p_wei_global, - void* const __restrict__ p_out_global){ - -}; diff --git a/driver/conv_bwd_data_driver.cpp b/driver/conv_bwd_data_driver.cpp deleted file mode 100644 index 
63723f5f4f..0000000000 --- a/driver/conv_bwd_data_driver.cpp +++ /dev/null @@ -1,299 +0,0 @@ -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor_generator.hpp" -#include "device_tensor.hpp" -#include "conv_common.hpp" -#include "host_conv_bwd_data.hpp" -#include "device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp" - -int main(int argc, char* argv[]) -{ - using namespace launcher; - -#if 1 - // 1x1 filter, 14x14 image - constexpr index_t N = 1; - constexpr index_t C = 256; - constexpr index_t HI = 1; - constexpr index_t WI = 128; - constexpr index_t K = 16; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 56; - constexpr index_t WI = 56; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 34x34 - constexpr index_t N = 64; - constexpr index_t C = 256; - constexpr index_t HI = 34; - constexpr index_t WI = 34; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 28x28 - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 28; - constexpr index_t WI = 28; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; -#elif 0 - // 1x1 filter, 8x8 image - constexpr index_t N = 256; - constexpr index_t C = 1024; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 1024; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 7x7 image - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 1024; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 1 - // 1x1 filter, 14x14 image - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t HI = 14; - constexpr index_t WI = 14; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 28x28 image - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 28; - constexpr index_t WI 
= 28; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 1x1 filter, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 1024; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - // 5x5 filter, 2x2 pad, 7x7 input - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t HI = 7; - constexpr index_t WI = 7; - constexpr index_t K = 1024; - constexpr index_t Y = 5; - constexpr index_t X = 5; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<2, 2>; - using RightPads = Sequence<2, 2>; -#elif 0 - // 1x7 filter, 0x3 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 7; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 3>; - using RightPads = Sequence<0, 3>; -#elif 0 - // 7x1 filter, 3x0 pad, 17x17 input - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 1024; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; -#elif 1 - // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t HI = 35; - constexpr index_t WI = 35; - constexpr index_t K = 1280; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<2, 2>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#endif - - constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( - in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); - - ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); - ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); - ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - print_array("LeftPads", LeftPads{}); - print_array("LeftPads", LeftPads{}); - print_array("RightPads", RightPads{}); - print_array("ConvStrides", ConvStrides{}); - print_array("ConvDilations", ConvDilations{}); - - Tensor in_nchw_device(make_HostTensorDescriptor(in_nchw_desc)); - Tensor in_nchw_host(make_HostTensorDescriptor(in_nchw_desc)); - Tensor wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); - Tensor out_nkhw(make_HostTensorDescriptor(out_nkhw_desc)); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - if(argc != 3) - { - printf("arg1: do_verification, arg2: nrepeat\n"); - exit(1); - } - - bool do_verification = atoi(argv[1]); - std::size_t nrepeat = atoi(argv[2]); - - if(do_verification) - 
{ -#if 0 - wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); - out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread); -#else - wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); -#endif - } - -#if 0 - device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw -#elif 0 - device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw -#elif 0 - device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw -#elif 1 - device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk -#endif - (in_nchw_desc, - in_nchw_device, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); - - if(do_verification) - { - host_direct_convolution_backward_data(in_nchw_host, - wei_kcyx, - out_nkhw, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}); - - check_error(in_nchw_host, in_nchw_device); - -#if 0 - LogRange(std::cout << "out_nkhw : ", out_nkhw.mData, ",") << std::endl; - LogRange(std::cout << "wei_kcyx : ", wei_kcyx.mData, ",") << std::endl; - LogRange(std::cout << "in_nchw_host : ", in_nchw_host.mData, ",") << std::endl; - LogRange(std::cout << "in_nchw_device : ", in_nchw_device.mData, ",") << std::endl; -#endif - } -} diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp deleted file mode 100644 index b116b21046..0000000000 --- a/driver/conv_driver.cpp +++ /dev/null @@ -1,780 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include "config.hpp" -#include "print.hpp" -#include "device.hpp" -#include "host_tensor_generator.hpp" -#include "conv_common.hpp" -#include "host_conv.hpp" -#include "device_tensor.hpp" -#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" - -int main(int argc, char* argv[]) -{ - using namespace ck; - - if(argc != 5) - { - printf("arg1: do_verification, arg2: do_log, arg3: init_method, arg4: nrepeat\n"); - exit(1); - } - - const bool do_verification = atoi(argv[1]); - const bool do_log = atoi(argv[2]); - const int init_method = atoi(argv[3]); - const int nrepeat = atoi(argv[4]); - -#if 0 - constexpr index_t N = 256; - constexpr index_t C = 256; - constexpr index_t HI = 16; - constexpr index_t WI = 16; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t HI = 1080; - constexpr index_t WI = 1920; - constexpr index_t K = 16; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t Hi = 540; - constexpr index_t Wi = 960; - constexpr index_t K = 16; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t Hi = 270; - constexpr 
index_t Wi = 480; - constexpr index_t K = 16; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t Hi = 1080; - constexpr index_t Wi = 1920; - constexpr index_t K = 16; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 1; - constexpr index_t Hi = 1024; - constexpr index_t Wi = 2048; - constexpr index_t K = 4; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t Hi = 540; - constexpr index_t Wi = 960; - constexpr index_t K = 16; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t Hi = 270; - constexpr index_t Wi = 480; - constexpr index_t K = 16; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 3x3, 36x36, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 37; - constexpr index_t Wi = 37; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 35x35, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 35; - constexpr index_t Wi = 35; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 71x71 - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t HI = 71; - constexpr index_t WI = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 1x1, 8x8 - constexpr index_t N = 128; - constexpr index_t C = 1536; - constexpr index_t Hi = 8; - constexpr index_t Wi = 8; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 73x73 - constexpr index_t N = 128; - constexpr index_t C = 160; - constexpr index_t Hi = 73; - constexpr index_t Wi = 73; - constexpr index_t K = 64; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; 
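// [Editorial aside, not part of the diff] Every configuration block in this
// deleted driver feeds the same output-size arithmetic computed after the
// final #endif of the catalog:
//   YEff = (Y - 1) * ConvDilationH + 1
//   Ho   = (Hi + LeftPadH + RightPadH - YEff) / ConvStrideH + 1
// For example, a 3x3 filter on a 28x28 input with pad 1 and stride 1 gives
// YEff = 3 and Ho = (28 + 1 + 1 - 3) / 1 + 1 = 28, i.e. the spatial size is
// preserved.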
- - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 35x35 - constexpr index_t N = 128; - constexpr index_t C = 96; - constexpr index_t Hi = 35; - constexpr index_t Wi = 35; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 3x3, 71x71 - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 192; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 7x1, 17x17 - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t Hi = 17; - constexpr index_t Wi = 17; - constexpr index_t K = 128; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<3, 0>; - using InRightPads = Sequence<3, 0>; -#elif 1 - // 1x7, 17x17 - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t Hi = 17; - constexpr index_t Wi = 17; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 7; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 3>; - using InRightPads = Sequence<0, 3>; -#elif 0 - // 3x3, 299x299 stride=2 - constexpr index_t N = 128; - constexpr index_t C = 3; - constexpr index_t Hi = 299; - constexpr index_t Wi = 299; - constexpr index_t K = 32; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 147x147 - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t Hi = 147; - constexpr index_t Wi = 147; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 3x3, 149x149 - constexpr index_t N = 128; - constexpr index_t C = 32; - constexpr index_t Hi = 149; - constexpr index_t Wi = 149; - constexpr index_t K = 32; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 17x17, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 17; - constexpr index_t Wi = 17; - constexpr index_t K = 192; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 35x35 - constexpr index_t N = 128; - constexpr index_t C = 384; - constexpr index_t Hi = 35; - constexpr index_t Wi = 35; - constexpr index_t K = 96; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 
35x35, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 288; - constexpr index_t Hi = 35; - constexpr index_t Wi = 35; - constexpr index_t K = 384; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x3, 8x8 - constexpr index_t N = 128; - constexpr index_t C = 384; - constexpr index_t Hi = 8; - constexpr index_t Wi = 8; - constexpr index_t K = 448; - constexpr index_t Y = 1; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 1>; - using InRightPads = Sequence<0, 1>; -#elif 0 - // 3x1, 8x8 - constexpr index_t N = 128; - constexpr index_t C = 448; - constexpr index_t Hi = 8; - constexpr index_t Wi = 8; - constexpr index_t K = 512; - constexpr index_t Y = 3; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 0>; - using InRightPads = Sequence<1, 0>; -#elif 0 - // 3x3, 147x147 - constexpr index_t N = 128; - constexpr index_t C = 64; - constexpr index_t Hi = 147; - constexpr index_t Wi = 147; - constexpr index_t K = 96; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 7x1, 73x73 - constexpr index_t N = 128; - constexpr index_t C = 64; - constexpr index_t Hi = 73; - constexpr index_t Wi = 73; - constexpr index_t K = 64; - constexpr index_t Y = 7; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<3, 0>; - using InRightPads = Sequence<3, 0>; -#elif 0 - // 3x3, 73x73 - constexpr index_t N = 128; - constexpr index_t C = 64; - constexpr index_t Hi = 73; - constexpr index_t Wi = 73; - constexpr index_t K = 96; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 14x14, stride 2 - constexpr index_t N = 256; - constexpr index_t C = 1024; - constexpr index_t Hi = 14; - constexpr index_t Wi = 14; - constexpr index_t K = 2048; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 14x14 - constexpr index_t N = 256; - constexpr index_t C = 1024; - constexpr index_t Hi = 14; - constexpr index_t Wi = 14; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 14x14, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t Hi = 14; - constexpr index_t Wi = 14; - constexpr index_t K = 512; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 1 - // 3x3, 28x28 - constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t Hi = 28; - 
constexpr index_t Wi = 28; - constexpr index_t K = 128; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 1 - // 3x3, 14x14 - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t Hi = 14; - constexpr index_t Wi = 14; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 1x1, 56x56, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 256; - constexpr index_t Hi = 56; - constexpr index_t Wi = 56; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 7x7, 230x230 stride=2 - constexpr index_t N = 128; - constexpr index_t C = 3; - constexpr index_t Hi = 230; - constexpr index_t Wi = 230; - constexpr index_t K = 64; - constexpr index_t Y = 7; - constexpr index_t X = 7; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 28x28, stride = 2 - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t Hi = 28; - constexpr index_t Wi = 28; - constexpr index_t K = 1024; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 1x1, 28x28, stride 2 - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t Hi = 28; - constexpr index_t Wi = 28; - constexpr index_t K = 256; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 1 - // 1x1, 7x7 - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t Hi = 7; - constexpr index_t Wi = 7; - constexpr index_t K = 2048; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 7x7 - constexpr index_t N = 128; - constexpr index_t C = 512; - constexpr index_t Hi = 7; - constexpr index_t Wi = 7; - constexpr index_t K = 512; - constexpr index_t Y = 3; - constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#elif 0 - // 1x1, 56x56 - constexpr index_t N = 128; - constexpr index_t C = 64; - constexpr index_t Hi = 56; - constexpr index_t Wi = 56; - constexpr index_t K = 64; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<0, 0>; - using InRightPads = Sequence<0, 0>; -#elif 0 - // 3x3, 56x56 - constexpr index_t N = 128; - constexpr index_t C = 64; - constexpr index_t Hi = 56; - constexpr index_t Wi = 56; - constexpr index_t K = 64; - constexpr index_t Y = 3; - 
constexpr index_t X = 3; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using InLeftPads = Sequence<1, 1>; - using InRightPads = Sequence<1, 1>; -#endif - - constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; - constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - - constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1; - constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1; - -#if 1 - constexpr index_t in_vector_size = 1; - using in_data_t = typename vector_type::type; - using acc_data_t = float; - using out_data_t = float; -#elif 1 - using in_data_t = half_t; - constexpr index_t in_vector_size = 1; - using acc_data_t = float; - using out_data_t = half_t; -#elif 0 - constexpr index_t in_vector_size = 1; - using in_data_t = typename vector_type::type; - using acc_data_t = float; - using out_data_t = int8_t; -#elif 1 - constexpr index_t in_vector_size = 16; - using in_data_t = typename vector_type::type; - using acc_data_t = int32_t; - using out_data_t = int8_t; -#endif - - Tensor in_nchw(HostTensorDescriptor(std::initializer_list{N, C, Hi, Wi})); - Tensor wei_kcyx(HostTensorDescriptor(std::initializer_list{K, C, Y, X})); - Tensor out_nkhw_host( - HostTensorDescriptor(std::initializer_list{N, K, Ho, Wo})); - Tensor out_nkhw_device( - HostTensorDescriptor(std::initializer_list{N, K, Ho, Wo})); - - ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: "); - ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: "); - ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: "); - - print_array("InLeftPads", InLeftPads{}); - print_array("InRightPads", InRightPads{}); - print_array("ConvStrides", ConvStrides{}); - print_array("ConvDilations", ConvDilations{}); - - std::size_t num_thread = std::thread::hardware_concurrency(); - - if(do_verification) - { - switch(init_method) - { - case 0: - in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 1: - in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - case 2: - in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - break; - case 3: - in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - break; - default: - in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); - }; - wei_kcyx.GenerateTensorValue(gen_wei, num_thread); - } - } - - constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence{}); - -#if 1 - device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - InLeftPads{}, - InRightPads{}, - nrepeat); -#elif 0 - device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - InLeftPads{}, - InRightPads{}, - nrepeat); -#elif 0 - device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - InLeftPads{}, - InRightPads{}, - nrepeat); -#endif - - if(do_verification) - { - host_direct_convolution(in_nchw, - wei_kcyx, - out_nkhw_host, - ConvStrides{}, - ConvDilations{}, - InLeftPads{}, - InRightPads{}); - - check_error(out_nkhw_host, out_nkhw_device); - - if(do_log) - { - LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; - LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; - LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; - LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; - } - } -} diff --git a/driver/conv_driver_v2.cpp b/driver/conv_driver_v2.cpp index 38b93395f9..93b13caaa4 100644 --- a/driver/conv_driver_v2.cpp +++ b/driver/conv_driver_v2.cpp @@ -24,16 +24,16 @@ #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 0 -#define USE_CONV_FWD_V4R4_NHWC 0 -#define USE_CONV_FWD_V4R4R2_NHWC 0 -#define USE_CONV_FWD_V4R5_NCHW 0 +#define USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V4R4_NHWC 1 +#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define USE_CONV_FWD_V4R5_NCHW 1 #define USE_CONV_FWD_V4R5R2_NCHW 1 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 -#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R2_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo { diff --git a/driver/include/conv_common.hpp b/driver/include/conv_common.hpp index 2b89bc876e..73126b3c79 100644 --- a/driver/include/conv_common.hpp +++ b/driver/include/conv_common.hpp @@ -1,7 +1,6 @@ #ifndef CONV_COMMON_HPP #define CONV_COMMON_HPP -#include "tensor_descriptor.hpp" #include "dynamic_tensor_descriptor.hpp" enum ConvTensorLayout @@ -13,53 +12,6 @@ enum ConvTensorLayout NHWCc }; -template -constexpr auto get_convolution_output_default_4d_tensor_descriptor( - InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads) -{ - using namespace ck; - - constexpr auto in_desc = InDesc{}; - constexpr auto wei_desc = WeiDesc{}; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4"); - static_assert(wei_desc.GetNumOfDimension() == 
4, "weight nDim is not 4"); - static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1), - "input & weight dimension not consistent"); - - constexpr index_t N = in_desc.GetLength(I0); - constexpr index_t Hi = in_desc.GetLength(I2); - constexpr index_t Wi = in_desc.GetLength(I3); - - constexpr index_t K = wei_desc.GetLength(I0); - constexpr index_t Y = wei_desc.GetLength(I2); - constexpr index_t X = wei_desc.GetLength(I3); - - constexpr index_t LeftPadH = LeftPads{}.Get(I0); - constexpr index_t LeftPadW = LeftPads{}.Get(I1); - - constexpr index_t RightPadH = RightPads{}.Get(I0); - constexpr index_t RightPadW = RightPads{}.Get(I1); - - constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; - constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - - constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1; - constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1; - - return make_native_tensor_descriptor_packed(Sequence{}); -} - template -constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc) -{ - using namespace ck; - - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr index_t N = out_desc.GetLength(I0); - constexpr index_t K = out_desc.GetLength(I1); - constexpr index_t Ho = out_desc.GetLength(I2); - constexpr index_t Wo = out_desc.GetLength(I3); - - constexpr index_t C = wei_desc.GetLength(I1); - constexpr index_t Y = wei_desc.GetLength(I2); - constexpr index_t X = wei_desc.GetLength(I3); - - return sizeof(Float) * - (InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace()); -} - #endif diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 1b8e70878a..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,221 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - std::size_t nrepeat) -{ - using namespace ck; - - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - 
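// [Editorial aside, not part of the diff] The tuning blocks that follow pick
// a 128x128 GEMM tile per 256-thread block, i.e. 128 * 128 / 256 = 64
// accumulators per thread (the "each thread hold 64 data" comment), and the
// launcher later sizes the grid as
//   GridSize = ceil(GemmM / GemmMPerBlock) * ceil(GemmN / GemmNPerBlock).
// A sketch of that rounding-up division, with a hypothetical name:
constexpr int sketch_integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }
static_assert(sketch_integer_divide_ceil(128 * 128, 256) == 64, "values per thread");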
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#elif 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; - - constexpr index_t 
GemmBBlockCopySrcDataPerRead_GemmN = 4; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#endif - - constexpr index_t GemmM = C * Y * X; - constexpr index_t GemmN = N * Ho * Wo; - - constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * - math::integer_divide_ceil(GemmN, GemmNPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; - - for(index_t i = 0; i < 1; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nchw_device_buf.FromDevice(in_nchw.mData.data()); -} - -} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp deleted file mode 100644 index aeeef9ab87..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,171 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - std::size_t nrepeat) -{ - using namespace ck; - - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = 
wei_kcyx_desc.GetLengths()[3]; - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // BlockSize = 256, each thread hold 64 data - constexpr index_t BlockSize = 256; - - constexpr index_t BPerBlock = 32; - constexpr index_t EPerBlock = 32; - constexpr index_t KPerBlock = 16; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>; - using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>; - - constexpr index_t OutBlockCopySrcDataPerRead_B = 1; - constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4; - - using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>; - using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>; - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_C0 = 1; - - constexpr index_t InThreadCopyDstDataPerWrite_B = 1; -#endif - - constexpr index_t C0 = GemmMPerThread; - constexpr index_t N0 = GemmNPerThread; - - constexpr index_t C1 = C / C0; - constexpr index_t N1 = N / N0; - - constexpr index_t E = C1 * Y * X; - constexpr index_t B = (N1 * Ho * Wo); - - constexpr index_t GridSize = - ((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - using gridwise_conv_bwd_data = - GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - EPerBlock, - BPerBlock, - KPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmDataPerReadA, - GemmDataPerReadB, - OutBlockCopySubLengths_K_B_N0, - OutBlockCopyClusterLengths_K_B_N0, - OutBlockCopySrcDataPerRead_B, - OutBlockCopyDstDataPerWrite_N0, - WeiBlockCopySubLengths_K_E_C0, - WeiBlockCopyClusterLengths_K_E_C0, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_C0, - InThreadCopyDstDataPerWrite_B>; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nchw_device_buf.FromDevice(in_nchw.mData.data()); -} - -} // namespace launcher diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index b4f421131c..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,267 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - std::size_t nrepeat) -{ - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - - constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; - constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; - - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 1 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using 
GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x8 - // GemmABlockCopySrcDataPerRead_GemmM = 4 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr 
index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * - math::integer_divide_ceil(GemmN, GemmNPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - using GridwiseConvBwdData = - GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, - GemmBBlockCopySrcDataPerRead_GemmN, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; - - static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) { - constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); - constexpr index_t gemm_k = gemm_sizes.At(2); - constexpr bool is_gemm_not_empty = gemm_k > 0; - - // only compile and run if the GEMM is not empty - static_if<is_gemm_not_empty>{}([&](auto fwd) { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()), - static_cast<const T*>(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast<const T*>(out_nkhw_device_buf.GetDeviceBuffer()), - fwd(gemm_id)); - }); - }); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nchw_device_buf.FromDevice(in_nchw.mData.data()); -} - -} // namespace launcher
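A minimal standalone sketch (not part of the deleted sources) of the stride/dilation tiling arithmetic used by the v4r1 backward-data launcher above, evaluated for assumed sizes (3x3 filter, 14x14 output, stride 2, dilation 1); std::gcd and a hand-rolled ceiling divide stand in for math::gcd and math::integer_divide_ceil, and each (ytilda, xtilda) pair appears to correspond to one gemm_id in the static_for above:

#include <cstdio>
#include <numeric> // std::gcd (C++17)

int main()
{
    const int Y = 3, X = 3;             // filter size (assumed)
    const int Ho = 14, Wo = 14;         // output size (assumed)
    const int StrideH = 2, StrideW = 2; // conv strides (assumed)
    const int DilationH = 1, DilationW = 1;

    // YTilda * XTilda is how many sub-GEMMs the backward-data pass is split into
    const int YTilda = StrideH / std::gcd(StrideH, DilationH); // 2
    const int XTilda = StrideW / std::gcd(StrideW, DilationW); // 2

    // each sub-GEMM sees at most YDot * XDot filter taps
    const int YDot = (Y + YTilda - 1) / YTilda; // ceil(3 / 2) = 2
    const int XDot = (X + XTilda - 1) / XTilda; // 2

    // padded output window large enough to cover every contributing tap
    const int HTilda = Ho + (DilationH * (Y - 1) + StrideH - 1) / StrideH; // 14 + 1 = 15
    const int WTilda = Wo + (DilationW * (X - 1) + StrideW - 1) / StrideW; // 15

    std::printf("sub-GEMMs: %d, taps per sub-GEMM: %dx%d, window: %dx%d\n",
                YTilda * XTilda, YDot, XDot, HTilda, WTilda);
    return 0;
}

diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index b534215637..0000000000 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,266 +0,0 @@ 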
-#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp" - -namespace launcher { - -using namespace ck; - -template -void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc, - Tensor& in_nchw, - WeiDesc wei_kcyx_desc, - const Tensor& wei_kcyx, - OutDesc out_nkhw_desc, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - std::size_t nrepeat) -{ - constexpr index_t N = out_nkhw_desc.GetLengths()[0]; - constexpr index_t K = out_nkhw_desc.GetLengths()[1]; - constexpr index_t C = wei_kcyx_desc.GetLengths()[1]; - - constexpr index_t Hi = in_nchw_desc.GetLengths()[2]; - constexpr index_t Wi = in_nchw_desc.GetLengths()[3]; - - constexpr index_t Ho = out_nkhw_desc.GetLengths()[2]; - constexpr index_t Wo = out_nkhw_desc.GetLengths()[3]; - - constexpr index_t Y = wei_kcyx_desc.GetLengths()[2]; - constexpr index_t X = wei_kcyx_desc.GetLengths()[3]; - - constexpr index_t ConvStrideH = ConvStrides{}[0]; - constexpr index_t ConvStrideW = ConvStrides{}[1]; - - constexpr index_t ConvDilationH = ConvDilations{}[0]; - constexpr index_t ConvDilationW = ConvDilations{}[1]; - - constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence{}); - constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence{}); - - Tensor in_nhwc(make_HostTensorDescriptor(in_nhwc_desc)); - Tensor wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc)); - Tensor out_nhwk(make_HostTensorDescriptor(out_nhwk_desc)); - - auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { - in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi); - }; - - auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { - wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x); - }; - - auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { - out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); - }; - - make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency()); - make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency()); - make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency()); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace()); - DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace()); - DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace()); - - in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); - wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); - out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); - -#if 0 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 1, 4>; - using 
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - constexpr index_t GemmThreadGemmDataPerReadM = 4; - constexpr index_t GemmThreadGemmDataPerReadN = 4; - - using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 4>; - using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>; - - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; - - using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 8, 1>; - using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>; - - constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4; - constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; - - constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#endif - - constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); - constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); - - constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH; - constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW; - - constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); - constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - - constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); - constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW); - - constexpr index_t HTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]); - constexpr index_t WTildaLeft = math::integer_divide_floor( - math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]); - - constexpr index_t HTildaRight = math::min( - HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1); - constexpr index_t WTildaRight = math::min( - WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1); - - constexpr index_t HTildaSlice = HTildaRight - HTildaLeft; - constexpr index_t WTildaSlice = WTildaRight - WTildaLeft; - - constexpr index_t GemmM = C; - constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; - - constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * - math::integer_divide_ceil(GemmN, GemmNPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - for(index_t i = 0; i < 5; ++i) - { - 
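// note: each of the 5 outer passes times nrepeat kernel launches and reports the average; since ave_time below is in milliseconds, flops / 1e9 / ave_time yields a TFlop/s figure. - 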
std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t i = 0; i < nrepeat; ++i) - { - using GridwiseConvBwdData = - GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk< - GridSize, - BlockSize, - T, - T, - decltype(in_nhwc_desc), - decltype(wei_kyxc_desc), - decltype(out_nhwk_desc), - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmThreadGemmDataPerReadM, - GemmThreadGemmDataPerReadN, - GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM, - GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM, - GemmABlockCopySrcDataPerRead_GemmM, - GemmABlockCopyDstDataPerWrite_GemmM, - GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN, - GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN, - GemmBBlockCopySrcDataPerRead_GemmK2, - GemmBBlockCopyDstDataPerWrite_GemmN, - GemmCThreadCopyDstDataPerWrite_GemmN1>; - - static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) { - constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); - constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}]; - constexpr bool is_gemm_not_empty = gemm_k2 > 0; - - // only compile and run if GEMM is no empty - static_if{}([&](auto fwd) { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nhwc_device_buf.GetDeviceBuffer()), - static_cast(wei_kyxc_device_buf.GetDeviceBuffer()), - static_cast(out_nhwk_device_buf.GetDeviceBuffer()), - fwd(gemm_id)); - }); - }); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - in_nhwc_device_buf.FromDevice(in_nhwc.mData.data()); - - auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) { - in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c); - }; - - make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency()); -} - -} // namespace launcher diff --git a/driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 04eec6b9da..0000000000 --- a/driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,849 +0,0 @@ -#pragma once -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - ck::index_t nrepeat) -{ - std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl; - - using namespace ck; - - using TDevice = typename conditional::value, half_t, T>::type; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = - 
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides()); - constexpr auto wei_kcyx_desc = - make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides()); - constexpr auto out_nkhw_desc = - make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides()); - - constexpr index_t N = out_nkhw_desc.GetLength(I0); - constexpr index_t K = out_nkhw_desc.GetLength(I1); - constexpr index_t Ho = out_nkhw_desc.GetLength(I2); - constexpr index_t Wo = out_nkhw_desc.GetLength(I3); - - std::size_t data_sz = sizeof(T); - DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace()); - DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace()); - DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace()); - - in_nchw_device_buf.ToDevice(in_nchw.mData.data()); - wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); - out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); - -#if 0 - // cdata = 64, BlockSize = 256, 64x256x8 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 32; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 16; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 32, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 256, 128x128x4 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 4; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t 
InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x16 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 16; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 0 - // cdata = 4, 
BlockSize = 256, 128x128x16 - // for 1x1 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 16; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 0 - // cdata = 64, BlockSize = 128, 64x128x4 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 4; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 128, 64x128x8 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - 
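// note: the four level-0/level-1 cluster factors are chosen to multiply to BlockSize (here 2 * 2 * 4 * 8 = 128). - 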
constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 128, 64x128x16 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 16; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 0 - // cdata = 64, BlockSize = 128, 128x64x4 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 4; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, 
N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 0 - // cdata = 64, BlockSize = 128, 128x64x8 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2; -#elif 0 - // cdata = 64, BlockSize = 128, 128x64x16 - constexpr index_t BlockSize = 128; - - constexpr index_t KPerBlock = 128; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 16; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 4>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - 
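// note: per-thread slice lengths times cluster lengths cover the block tile: <4, 4> x <4, 32> = [16, 128] = [EPerBlock, KPerBlock], and the 4 * 32 = 128 copy threads match BlockSize. - 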
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4; -#elif 0 - // cdata = 64, BlockSize = 64, 64x64x8 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 2; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 2>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 32, 32x64x3 - constexpr index_t BlockSize = 32; - - constexpr index_t KPerBlock = 32; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 3; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 2; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 64, 32x128x3 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 32; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 3; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - 
constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2; - - using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 64, 64x64x3 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 8; - constexpr index_t EPerBlock = 3; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 1>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1; - - using WeiBlockCopySubLengths_E_K = Sequence<3, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<1, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 1; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 64, 32x128x4 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 32; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 4; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using 
InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 64, BlockSize = 64, 32x128x8 - constexpr index_t BlockSize = 64; - - constexpr index_t KPerBlock = 32; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 4; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<4, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - constexpr index_t WeiBlockCopySrcDataPerRead_E = 4; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 - // cdata = 32, BlockSize = 256, 64x128x8 - constexpr index_t BlockSize = 256; - - constexpr index_t KPerBlock = 64; - constexpr index_t BPerBlock = 16; - constexpr index_t EPerBlock = 8; - - constexpr index_t GemmNRepeat = 2; - - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - constexpr index_t GemmDataPerReadA = 2; - constexpr index_t GemmDataPerReadB = 4; - - using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>; - using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>; - using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B] - using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2] - using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2] - - constexpr index_t InBlockCopySrcDataPerRead_B = 1; - constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4; - - using WeiBlockCopySubLengths_E_K = Sequence<2, 1>; - using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>; - using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E] - using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K] - - 
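// note: per the order sequences above, weights are fetched from global memory in [K, E] order but stored to LDS in [E, K] order; the arrange/src/dst sequences encode exactly that transpose. - 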
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; - constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#endif - - constexpr index_t N1 = GemmNRepeat; - constexpr index_t N2 = GemmNPerThread; - - constexpr index_t B = (N * Ho * Wo) / (N1 * N2); - - constexpr index_t GridSize = - ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock); - - printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - - using gridwise_conv = - GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - BPerBlock, - KPerBlock, - EPerBlock, - GemmNRepeat, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N1_B_N2, - InBlockCopyClusterLengths_E_N1_B_N2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2, - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>; - - for(index_t i = 0; i < 5; ++i) - { - std::cout << "Start running " << nrepeat << " times..." << std::endl; - - KernelTimer timer; - timer.Start(); - - for(index_t j = 0; j < nrepeat; ++j) - { - launch_kernel(run_gridwise_operation, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - static_cast(in_nchw_device_buf.GetDeviceBuffer()), - static_cast(wei_kcyx_device_buf.GetDeviceBuffer()), - static_cast(out_nkhw_device_buf.GetDeviceBuffer())); - } - - timer.End(); - - float ave_time = timer.GetElapsedTime() / nrepeat; - - float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - out_nkhw_device_buf.FromDevice(out_nkhw.mData.data()); -} diff --git a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp deleted file mode 100644 index f1c0eebde7..0000000000 --- a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,1246 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - ck::index_t nrepeat) -{ - std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl; - - using namespace ck; - - using TDevice = typename conditional::value, half_t, T>::type; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto in_nchw_desc = - make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides()); - constexpr auto wei_kcyx_desc = - make_native_tensor_descriptor(WeiDesc::GetLengths(), 
diff --git a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
deleted file mode 100644
index f1c0eebde7..0000000000
--- a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+++ /dev/null
@@ -1,1246 +0,0 @@
-#include
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
-
-template <typename T,
-          typename InDesc,
-          typename WeiDesc,
-          typename OutDesc,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename InLeftPads,
-          typename InRightPads>
-void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
-                                                                  const Tensor<T>& in_nchw,
-                                                                  WeiDesc,
-                                                                  const Tensor<T>& wei_kcyx,
-                                                                  OutDesc,
-                                                                  Tensor<T>& out_nkhw,
-                                                                  ConvStrides,
-                                                                  ConvDilations,
-                                                                  InLeftPads,
-                                                                  InRightPads,
-                                                                  ck::index_t nrepeat)
-{
-    std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl;
-
-    using namespace ck;
-
-    using TDevice = typename conditional<is_same<T, half_float::half>::value, half_t, T>::type;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
-    constexpr auto in_nchw_desc =
-        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
-    constexpr auto wei_kcyx_desc =
-        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
-    constexpr auto out_nkhw_desc =
-        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
-
-    constexpr index_t N = out_nkhw_desc.GetLength(I0);
-    constexpr index_t K = out_nkhw_desc.GetLength(I1);
-    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
-    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
-
-    std::size_t data_sz = sizeof(T);
-    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
-    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
-    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
-
-    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
-    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
-    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
-
-#if 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
-#elif 0
-    // cdata = 16, BlockSize = 64, 16x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2;
-#elif 0
-    // cdata = 32, BlockSize = 64, 16x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4;
-
-    constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 256, 64x256x8
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 256;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 16;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 256>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x2
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b threadwise copy 4x1
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // b threadwise copy 2x2
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // vector 4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 8;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x16
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 256, 128x128x8
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x16
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 1
-    // cdata = 64, BlockSize = 256, 128x128x16
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 128x64x4
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 128, 128x64x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 2
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 128>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 2;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
-#elif 0
-    // cdata = 64, BlockSize = 128, 128x64x8
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 128, 128x64x8
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 16>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 128x64x16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 4>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 128, 64x128x4
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // BlockSize = 128, 64x128x4
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 0
-    // cdata = 64, BlockSize = 128, 64x128x8
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // BlockSize = 128, 64x128x8
-    // GemmBBlockCopySrcDataPerRead_GemmN = 4
-    // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
-#elif 0
-    // BlockSize = 128, 64x128x16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<16, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 64x64x4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 2;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 64x64x8
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 2;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 2;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 32x128x2
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 2;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 32x128x4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 32, 32x64x3
-    constexpr index_t BlockSize = 32;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 3;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 32>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 32x128x3
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 3;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 64x64x3
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 64;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 3;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 4;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<3, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 64>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<3, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 64, BlockSize = 64, 32x128x8
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 4;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 4;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
-    // cdata = 32, BlockSize = 128, 32x128x8
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 8;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 2;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 0
-    // cdata = 32, BlockSize = 128, 32x128x16
-    constexpr index_t BlockSize = 128;
-
-    constexpr index_t GemmMPerBlock = 32;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 16;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 4;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 4;
-    constexpr index_t GemmNLevel0Cluster = 4;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 4;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 4;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<16, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#endif
-
-    constexpr index_t GemmM = K;
-    constexpr index_t GemmN = N * Ho * Wo;
-
-    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
-                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);
-
-    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-
-    using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw<
-        GridSize,
-        BlockSize,
-        TDevice,
-        TDevice,
-        decltype(in_nchw_desc),
-        decltype(wei_kcyx_desc),
-        decltype(out_nkhw_desc),
-        ConvStrides,
-        ConvDilations,
-        InLeftPads,
-        InRightPads,
-        GemmMPerBlock,
-        GemmNPerBlock,
-        GemmKPerBlock,
-        GemmMPerThread,
-        GemmNPerThread,
-        GemmKPerThread,
-        GemmMLevel0Cluster,
-        GemmNLevel0Cluster,
-        GemmMLevel1Cluster,
-        GemmNLevel1Cluster,
-        ThreadGemmDataPerReadM,
-        ThreadGemmDataPerReadN,
-        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-        GemmABlockCopySrcDataPerRead_GemmK,
-        GemmABlockCopyDstDataPerWrite_GemmM,
-        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-        GemmBBlockCopySrcDataPerRead_GemmN,
-        GemmBBlockCopyDstDataPerWrite_GemmN,
-        GemmCThreadCopyDstDataPerWrite_GemmN1>;
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-        KernelTimer timer;
-        timer.Start();
-
-        for(index_t j = 0; j < nrepeat; ++j)
-        {
-            launch_kernel(run_gridwise_operation,
-                          dim3(GridSize),
-                          dim3(BlockSize),
-                          0,
-                          0,
-                          static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
-                          static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
-                          static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
-        }
-
-        timer.End();
-
-        float ave_time = timer.GetElapsedTime() / nrepeat;
-
-        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
-}
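The deleted v4r4 driver mapped forward convolution onto a single GEMM: GemmM = K and GemmN = N * Ho * Wo appear directly in the code above, and the contraction dimension is the remaining C * Y * X. A small worked example in the same spirit, using hypothetical layer sizes and the 128x128 tile from the active #elif 1 branch above:

#include <cstdio>

int main()
{
    // Hypothetical 3x3 layer, for illustration only.
    const long N = 128, K = 256, C = 256, Y = 3, X = 3, Ho = 14, Wo = 14;

    // Implicit-GEMM view: A = wei [GemmK, GemmM], B = in [GemmK, GemmN],
    // C = out [GemmM, GemmN].
    const long GemmM = K;           // 256
    const long GemmN = N * Ho * Wo; // 25088
    const long GemmK = C * Y * X;   // 2304

    const long GemmMPerBlock = 128, GemmNPerBlock = 128;
    const long GridSize = ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
                          ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock); // 2 * 196 = 392

    std::printf("GemmM %ld, GemmN %ld, GemmK %ld, GridSize %ld\n",
                GemmM, GemmN, GemmK, GridSize);
}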
diff --git a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
deleted file mode 100644
index 238eebf2ee..0000000000
--- a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
+++ /dev/null
@@ -1,207 +0,0 @@
-#include
-#include "device.hpp"
-#include "host_tensor.hpp"
-#include "gridwise_operation_wrapper.hpp"
-#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
-
-template <typename T,
-          typename InDesc,
-          typename WeiDesc,
-          typename OutDesc,
-          typename ConvStrides,
-          typename ConvDilations,
-          typename InLeftPads,
-          typename InRightPads>
-void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
-                                                                  const Tensor<T>& in_nchw,
-                                                                  WeiDesc,
-                                                                  const Tensor<T>& wei_kcyx,
-                                                                  OutDesc,
-                                                                  Tensor<T>& out_nkhw,
-                                                                  ConvStrides,
-                                                                  ConvDilations,
-                                                                  InLeftPads,
-                                                                  InRightPads,
-                                                                  ck::index_t nrepeat)
-{
-    std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl;
-
-    using namespace ck;
-
-    using TDevice = typename conditional<is_same<T, half_float::half>::value, half_t, T>::type;
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
-    constexpr auto N = OutDesc::GetLengths()[I0];
-    constexpr auto K = OutDesc::GetLengths()[I1];
-    constexpr auto C = WeiDesc::GetLengths()[I1];
-
-    constexpr auto Hi = InDesc::GetLengths()[I2];
-    constexpr auto Wi = InDesc::GetLengths()[I3];
-
-    constexpr auto Ho = OutDesc::GetLengths()[I2];
-    constexpr auto Wo = OutDesc::GetLengths()[I3];
-
-    constexpr auto Y = WeiDesc::GetLengths()[I2];
-    constexpr auto X = WeiDesc::GetLengths()[I3];
-
-    // compile-time variables
-    constexpr auto in_n_hi_wi_c_desc =
-        make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
-    constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
-    constexpr auto out_n_ho_wo_k_desc =
-        make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
-
-    Tensor<T> in_nhwc(
-        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
-    Tensor<T> wei_kyxc(
-        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
-    Tensor<T> out_nhwk(
-        make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
-
-    auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
-        in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
-    };
-
-    auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
-        wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
-    };
-
-    auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
-        out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
-    };
-
-    make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
-    make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
-    make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
-
-    std::size_t data_sz = sizeof(T);
-
-    DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
-    DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
-    DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
-
-    in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
-    wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
-    out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
-
-#if 1
-    // cdata = 16, BlockSize = 64, 16x64x4
-    constexpr index_t BlockSize = 64;
-
-    constexpr index_t GemmMPerBlock = 16;
-    constexpr index_t GemmNPerBlock = 64;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerThread = 2;
-    constexpr index_t GemmNPerThread = 2;
-    constexpr index_t GemmKPerThread = 1;
-
-    constexpr index_t GemmMLevel0Cluster = 2;
-    constexpr index_t GemmNLevel0Cluster = 2;
-    constexpr index_t GemmMLevel1Cluster = 2;
-    constexpr index_t GemmNLevel1Cluster = 8;
-
-    constexpr index_t ThreadGemmDataPerReadM = 2;
-    constexpr index_t ThreadGemmDataPerReadN = 2;
-
-    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
-    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
-
-    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
-    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
-
-    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
-    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
-
-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
-
-    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2;
-#endif
-
-    constexpr index_t GemmM = K;
-    constexpr index_t GemmN = N * Ho * Wo;
-
-    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
-                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);
-
-    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
-
-    using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk<
-        GridSize,
-        BlockSize,
-        TDevice,
-        TDevice,
-        decltype(in_n_hi_wi_c_desc),
-        decltype(wei_k_y_x_c_desc),
-        decltype(out_n_ho_wo_k_desc),
-        ConvStrides,
-        ConvDilations,
-        InLeftPads,
-        InRightPads,
-        GemmMPerBlock,
-        GemmNPerBlock,
-        GemmKPerBlock,
-        GemmMPerThread,
-        GemmNPerThread,
-        GemmKPerThread,
-        GemmMLevel0Cluster,
-        GemmNLevel0Cluster,
-        GemmMLevel1Cluster,
-        GemmNLevel1Cluster,
-        ThreadGemmDataPerReadM,
-        ThreadGemmDataPerReadN,
-        GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
-        GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
-        GemmABlockCopySrcDataPerRead_GemmK,
-        GemmABlockCopyDstDataPerWrite_GemmM,
-        GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
-        GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
-        GemmBBlockCopySrcDataPerRead_GemmK,
-        GemmBBlockCopyDstDataPerWrite_GemmN,
-        GemmCThreadCopyDstDataPerWrite_GemmM1>;
-
-    for(index_t i = 0; i < 5; ++i)
-    {
-        std::cout << "Start running " << nrepeat << " times..." << std::endl;
-
-        KernelTimer timer;
-        timer.Start();
-
-        for(index_t j = 0; j < nrepeat; ++j)
-        {
-            launch_kernel(run_gridwise_operation,
-                          dim3(GridSize),
-                          dim3(BlockSize),
-                          0,
-                          0,
-                          static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
-                          static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
-                          static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
-        }
-
-        timer.End();
-
-        float ave_time = timer.GetElapsedTime() / nrepeat;
-
-        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
-                     (std::size_t(1000) * 1000 * 1000) / ave_time;
-
-        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
-    }
-
-    out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
-
-    auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
-        out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
-    };
-
-    make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
-}
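Worth noting: the NHWC driver deleted above sets GemmBBlockCopySrcDataPerRead_GemmK = 4, where its NCHW counterpart could only vectorize as ...SrcDataPerRead_GemmN. The difference is which GEMM dimension is contiguous in memory: in NHWC, C (the GemmK direction of the B matrix) is innermost. A toy address calculation illustrating the point; these layout helpers are hypothetical, not CK code:

#include <cstddef>

// Toy offset functions for the two layouts (illustration only).
constexpr std::size_t nchw_off(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                               std::size_t C, std::size_t H, std::size_t W)
{
    return ((n * C + c) * H + h) * W + w;
}

constexpr std::size_t nhwc_off(std::size_t n, std::size_t h, std::size_t w, std::size_t c,
                               std::size_t C, std::size_t H, std::size_t W)
{
    return ((n * H + h) * W + w) * C + c;
}

// Stepping along C is stride-1 in NHWC, so a float4 load per thread is legal
// there; in NCHW the same step strides by H*W elements.
static_assert(nhwc_off(0, 0, 0, 1, 64, 14, 14) - nhwc_off(0, 0, 0, 0, 64, 14, 14) == 1, "");
static_assert(nchw_off(0, 1, 0, 0, 64, 14, 14) - nchw_off(0, 0, 0, 0, 64, 14, 14) == 14 * 14, "");

int main() {}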
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
index 9054c09d28..e0a89d2af3 100644
--- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
@@ -257,7 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
         GemmKPerBlock,
         GemmMPerWave,
         GemmNPerWave,
-        GemmK1,
         MRepeat,
         NRepeat,
         GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
index 130f7c97e2..bb37ac309f 100644
--- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
@@ -57,7 +57,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
         make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
 
 #if 1
-    // [M, N, K0, K1] = [256, 128, 4, 8]
+    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 256;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerWave = 64;
+    constexpr index_t GemmNPerWave = 64;
+    constexpr index_t GemmK1 = 4;
+
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 1;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
+#elif 1
+    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
 
     constexpr index_t GemmMPerBlock = 256;
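A plausible reading of the new fp32 branch: GemmK1 = 4 with scalar-per-vector 4 keeps each block-transfer load at 16 bytes, the same per-load footprint the fp16 path (GemmK1 = 8, scalar-per-vector 8) already used. A compile-time check of that byte-count equivalence, as a sketch:

#include <cstdint>

// Both configurations move 16 bytes per vector load:
static_assert(4 * sizeof(float) == 16, "fp32: GemmK1 = 4, ScalarPerVector = 4");
static_assert(8 * sizeof(std::uint16_t) == 16, "fp16: GemmK1 = 8, ScalarPerVector = 8");

int main() {}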
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
index f030ed74eb..c1e63664e5 100644
--- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
@@ -56,7 +56,63 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
     const auto out_n_ho_wo_k_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
 
-#if 0
+#if 1
+    // [M, N, K0, K1] = [256, 128, 4, 4] for fp32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 256;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerWave = 32;
+    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmK1 = 4;
+
+    constexpr index_t MRepeat = 4;
+    constexpr index_t NRepeat = 2;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
+#elif 1
+    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerWave = 32;
+    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmK1 = 4;
+
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 2;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
+#elif 0
     // [M, N, K0, K1] = [256, 256, 4, 8] for fp16
     constexpr index_t BlockSize = 256;
 
@@ -111,34 +167,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
     constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
     constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
 
-    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
-#elif 0
-    // [M, N, K0, K1] = [128, 128, 4, 4] for fp32
-    constexpr index_t BlockSize = 256;
-
-    constexpr index_t GemmMPerBlock = 128;
-    constexpr index_t GemmNPerBlock = 128;
-    constexpr index_t GemmKPerBlock = 4;
-
-    constexpr index_t GemmMPerWave = 32;
-    constexpr index_t GemmNPerWave = 32;
-    constexpr index_t GemmK1 = 4;
-
-    constexpr index_t MRepeat = 2;
-    constexpr index_t NRepeat = 2;
-
-    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
-    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
-
-    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
-    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
-
-    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
-    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
-
-    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
-    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
-
     constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
 #endif
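The new [256, 128] fp32 branch decomposes across a 256-thread block as follows, assuming wave size 64 and a 2x2 wave grid; the wave arrangement is an assumption here, since it actually lives in the gridwise xdlops kernel rather than in this driver file:

#include <cstdio>

int main()
{
    const int BlockSize = 256, WaveSize = 64;
    const int GemmMPerWave = 32, GemmNPerWave = 32;
    const int MRepeat = 4, NRepeat = 2;

    const int waves = BlockSize / WaveSize; // 4
    // Assumed 2x2 wave grid: each wave owns MRepeat x NRepeat xdlops tiles.
    const int MWaves = 2, NWaves = 2;

    std::printf("waves %d, M tile %d, N tile %d\n",
                waves,
                MRepeat * GemmMPerWave * MWaves,  // 4 * 32 * 2 = 256 = GemmMPerBlock
                NRepeat * GemmNPerWave * NWaves); // 2 * 32 * 2 = 128 = GemmNPerBlock
}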
GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -139,34 +195,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp index b6799e1072..702ddc9e8f 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp @@ -215,7 +215,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, - Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder 5, // CThreadTransferSrcDstVectorDim GemmCThreadTransferDstScalarPerVector_BN1, diff --git a/driver/include/device_tensor.hpp b/driver/include/device_tensor.hpp index 07d98f87a3..1a7a34a4cf 100644 --- a/driver/include/device_tensor.hpp +++ b/driver/include/device_tensor.hpp @@ -1,23 +1,6 @@ #pragma once #include "host_tensor.hpp" #include "common_header.hpp" -#include "tensor_descriptor.hpp" - -template -auto make_HostTensorDescriptor_impl(TensorDesc, std::integer_sequence) -{ - std::initializer_list lengths = {TensorDesc::GetLengths()[Is]...}; - std::initializer_list strides = {TensorDesc::GetStrides()[Is]...}; - - return HostTensorDescriptor(lengths, strides); -} - -template -auto make_HostTensorDescriptor(TensorDesc) -{ - return make_HostTensorDescriptor_impl( - TensorDesc{}, std::make_integer_sequence{}); -} template void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout) diff --git a/script/cmake-cuda.sh b/script/cmake-cuda.sh deleted file 
mode 100755 index 106035d70a..0000000000 --- a/script/cmake-cuda.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -MY_PROJECT_SOURCE=../../../ - - export CUDA_ROOT=/usr/local/cuda - export CPATH=$CPATH:$CUDA_ROOT/include - export LIBRARY_PATH=$LIBRARY_PATH:$CUDA_ROOT/lib64 - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CUDA_ROOT/lib64 - -cmake \ --D CMAKE_CXX_COMPILER=clang++ \ --D CMAKE_BUILD_TYPE=Release \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ --D DEVICE_BACKEND=NVIDIA \ --D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_61,code=sm_61 -Xptxas -v -maxrregcount=128" \ -${MY_PROJECT_SOURCE} - - -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70" \ -#-D CMAKE_CUDA_FLAGS="-ccbin clang++ -m64 -Xcompiler -fopenmp -lineinfo --source-in-ptx -keep -Xptxas -v -gencode=arch=compute_70,code=sm_70 -Xptxas -v -maxrregcount=128" \ diff --git a/script/extract_asm-cuda.sh b/script/extract_asm-cuda.sh deleted file mode 100755 index 879e0b1a3d..0000000000 --- a/script/extract_asm-cuda.sh +++ /dev/null @@ -1,3 +0,0 @@ -DRIVER=$1 -ARCH=$2 -cuobjdump -xelf $ARCH ./driver/$DRIVER && nvdisasm --print-code -g $DRIVER.$ARCH.cubin > $DRIVER.$ARCH.asm
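
Note (not part of the patch): the two fp32 configurations enabled above must keep BlockSize, the per-block tile, and the per-wave tile mutually consistent. The standalone check below sketches that arithmetic; the wave size of 64 (true for the gfx9 targets this backend builds for) and the helper name waves_cover_block are assumptions for illustration, not code from this repository.

// Sanity check for the xdlops tile parameters chosen in the new defaults.
#include <cstdint>

using index_t = std::int32_t;

constexpr index_t wave_size = 64; // assumed wave size on the targeted AMD GPUs

// Each wave computes an (MPerWave * MRepeat) x (NPerWave * NRepeat) sub-tile,
// so the block tile must be covered exactly by BlockSize / wave_size waves.
constexpr bool waves_cover_block(index_t BlockSize,
                                 index_t MPerBlock,
                                 index_t NPerBlock,
                                 index_t MPerWave,
                                 index_t NPerWave,
                                 index_t MRepeat,
                                 index_t NRepeat)
{
    return MPerBlock % (MPerWave * MRepeat) == 0 &&
           NPerBlock % (NPerWave * NRepeat) == 0 &&
           (MPerBlock / (MPerWave * MRepeat)) * (NPerBlock / (NPerWave * NRepeat)) ==
               BlockSize / wave_size;
}

// [256, 128] config: (256 / (32 * 4)) * (128 / (32 * 2)) = 2 * 2 = 4 waves = 256 / 64.
static_assert(waves_cover_block(256, 256, 128, 32, 32, 4, 2), "256x128 fp32 tile mismatch");

// [128, 128] config: (128 / (32 * 2)) * (128 / (32 * 2)) = 2 * 2 = 4 waves.
static_assert(waves_cover_block(256, 128, 128, 32, 32, 2, 2), "128x128 fp32 tile mismatch");

// The A/B block-copy thread clusters must also use every thread in the block:
// Sequence<4, 64, 1> gives 4 * 64 * 1 = 256 threads == BlockSize.
static_assert(4 * 64 * 1 == 256, "block transfer cluster size mismatch");

int main() { return 0; }

Both enabled configs launch 256 threads; selecting a tile is a matter of flipping which #if/#elif branch evaluates to 1, with no change to the kernel itself.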