From e2753e68bdc0b0467e3fe64321cb18b86d0cf534 Mon Sep 17 00:00:00 2001
From: Chao Liu
Date: Thu, 25 Mar 2021 13:51:11 -0500
Subject: [PATCH] Dynamic tensor descriptor (#24)

* support dynamic tensor descriptor
* use buffer load OOB feature for padding case
* add navi support
* add int8x4 inference kernel

Co-authored-by: Chao Liu
Co-authored-by: Jing Zhang

[ROCm/composable_kernel commit: fcbb978828b308d8c367a3eeaebee485a61b548c]
---
 CMakeLists.txt                                |    3 +-
 ...ward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp | 2024 +++++++++++++++++
 ...ward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp | 1341 +++++++++++
 ...ward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp |  353 +++
 .../include/gridwise_operation_wrapper.hpp    |    6 +-
 ..._v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp |   24 +-
 ...data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp |    2 +-
 ...v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp} |   26 +-
 ...ard_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp} |    6 +-
 ...ward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp |  162 ++
 .../tensor_description/array_multi_index.hpp  |   77 +
 .../tensor_description/cluster_descriptor.hpp |   48 +
 .../dynamic_multi_index_transform.hpp         | 1157 ++++++++++
 .../dynamic_multi_index_transform_helper.hpp  |   74 +
 .../dynamic_tensor_descriptor.hpp             |  608 +++++
 .../dynamic_tensor_descriptor_helper.hpp      |  146 ++
 .../tensor_description/multi_index.hpp        |   12 +
 .../multi_index_transform.hpp                 |  114 +-
 .../statically_indexed_array_multi_index.hpp  |  107 +
 .../tensor_description/tensor_coordinate.hpp  |   12 +-
 .../tensor_description/tensor_descriptor.hpp  |   29 +-
 .../tensor_descriptor_helper.hpp              |   48 +-
 .../blockwise_batched_gemm.hpp                |   51 +-
 ...lockwise_dynamic_tensor_slice_transfer.hpp |  171 ++
 .../tensor_operation/blockwise_gemm.hpp       |   29 +-
 .../tensor_operation/blockwise_gemm_v2.hpp    |  370 +++
 .../tensor_operation/blockwise_gemm_v3.hpp    |  198 ++
 .../blockwise_generic_tensor_slice_copy.hpp   |    5 +-
 .../gridwise_dynamic_gemm.hpp                 |  509 +++++
 .../gridwise_dynamic_gemm_v2.hpp              |  471 ++++
 .../tensor_operation/gridwise_gemm.hpp        |  190 +-
 .../gridwise_tensor_contraction.hpp           |  330 ---
 ...readwise_dynamic_tensor_slice_transfer.hpp | 1298 +++++++++++
 .../tensor_operation/threadwise_gemm.hpp      |    7 +-
 .../tensor_operation/threadwise_gemm_v2.hpp   |  172 ++
 .../tensor_operation/threadwise_gemm_v3.hpp   |  141 ++
 .../threadwise_generic_tensor_slice_copy.hpp  |  157 +-
 .../include/utility/amd_buffer_addressing.hpp |  698 +++---
 .../utility/amd_buffer_addressing_v2.hpp      |  365 +++
 .../include/utility/amd_inline_asm.hpp        |  131 +-
 .../include/utility/amd_llvm_intrinsic.hpp    |   11 +
 composable_kernel/include/utility/array.hpp   |  399 +---
 .../include/utility/common_header.hpp         |   21 +-
 .../include/utility/config.amd.hpp.in         |   88 +-
 .../utility/container_element_picker.hpp      |  153 ++
 .../include/utility/container_helper.hpp      |  375 +++
 .../include/utility/float_type.amd.hpp.in     |  597 +++--
 .../include/utility/float_type.nvidia.hpp.in  |   30 +-
 .../include/utility/functional.hpp            |   13 +-
 .../include/utility/functional2.hpp           |    3 +-
 .../include/utility/functional3.hpp           |   20 +-
 .../include/utility/functional4.hpp           |   36 +-
 .../utility/in_memory_operation.amd.hpp.in    |   53 +-
 .../utility/in_memory_operation.nvidia.hpp.in |    4 +-
 composable_kernel/include/utility/math.hpp    |   31 +-
 composable_kernel/include/utility/print.hpp   |   70 +
 .../include/utility/print_array.hpp           |  177 --
 .../include/utility/print_sequence.hpp        |   46 -
 .../include/utility/sequence.hpp              |   15 +
 .../include/utility/sequence_helper.hpp       |   15 +
 .../utility/statically_indexed_array.hpp      |   40 +
 composable_kernel/include/utility/tuple.hpp   |  118 +-
 .../include/utility/tuple_helper.hpp          |   80 +
 composable_kernel/include/utility/type.hpp    |   39 +-
 composable_kernel/include/utility/utility.hpp |    2 +-
 driver/include/conv_common.hpp                |   20 +-
 ...data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp |    2 +-
 ...data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp |   47 +-
 ...data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp |    2 +-
 ...ard_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp} |  107 +-
 ...ard_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp} |  260 ++-
 ...ward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp |  207 ++
 ...ward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp |  508 +++++
 ...ward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp |  427 ++++
 ...ward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp |  167 ++
 driver/include/host_conv.hpp                  |    2 +-
 driver/include/host_tensor.hpp                |    2 +-
 driver/src/conv_bwd_data_driver.cpp           |   29 +-
 driver/src/conv_bwd_data_driver.cu            |    1 -
 driver/src/conv_driver.cpp                    |  378 ++-
 driver/src/conv_driver.cu                     |    1 -
 external/half/include/half.hpp                |   97 +-
 script/{cmake-rocm3.5.sh => cmake-rocm3.7.sh} |   12 +-
 script/count_vgpr.sh                          |  259 +++
 script/hipclang_opt.sh                        |   25 +
 85 files changed, 14129 insertions(+), 2532 deletions(-)
 create mode 100644 composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
 create mode 100644 composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
 create mode 100644 composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
 rename composable_kernel/include/kernel_algorithm/{gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp => gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp} (95%)
 rename composable_kernel/include/kernel_algorithm/{gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp => gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp} (97%)
 create mode 100644 composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
 create mode 100644 composable_kernel/include/tensor_description/array_multi_index.hpp
 create mode 100644 composable_kernel/include/tensor_description/cluster_descriptor.hpp
 create mode 100644 composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
 create mode 100644 composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp
 create mode 100644 composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp
 create mode 100644 composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp
 create mode 100644 composable_kernel/include/tensor_description/multi_index.hpp
 create mode 100644 composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp
 create mode 100644 composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp
 create mode 100644 composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
 create mode 100644 composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp
 create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
 create mode 100644 composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp
 delete mode 100644 composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp
 create mode 100644 composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
 create mode 100644 composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
 create mode 100644 composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
 create mode 100644 composable_kernel/include/utility/amd_buffer_addressing_v2.hpp
 create mode 100644 composable_kernel/include/utility/amd_llvm_intrinsic.hpp
 create mode 100644 composable_kernel/include/utility/container_element_picker.hpp
 create mode 100644 composable_kernel/include/utility/container_helper.hpp
 create mode 100644 composable_kernel/include/utility/print.hpp
 delete mode 100644 composable_kernel/include/utility/print_array.hpp
 delete mode 100644 composable_kernel/include/utility/print_sequence.hpp
 create mode 100644 composable_kernel/include/utility/sequence_helper.hpp
 create mode 100644 composable_kernel/include/utility/statically_indexed_array.hpp
 create mode 100644 composable_kernel/include/utility/tuple_helper.hpp
 rename driver/include/{device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp => device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp} (93%)
 rename driver/include/{device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp => device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp} (80%)
 create mode 100644 driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
 create mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
 create mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
 create mode 100644 driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
 delete mode 120000 driver/src/conv_bwd_data_driver.cu
 delete mode 120000 driver/src/conv_driver.cu
 rename script/{cmake-rocm3.5.sh => cmake-rocm3.7.sh} (65%)
 create mode 100755 script/count_vgpr.sh
 create mode 100755 script/hipclang_opt.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index d8e51761bd..db48a26202 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,7 @@ project(modular_convolution)
 #c++
 enable_language(CXX)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
 message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
@@ -53,6 +53,7 @@ include_directories(BEFORE
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_description
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/tensor_operation
 ${PROJECT_SOURCE_DIR}/composable_kernel/include/kernel_algorithm
+ ${PROJECT_SOURCE_DIR}/composable_kernel/include/driver
 ${PROJECT_SOURCE_DIR}/external/half/include
 ${PROJECT_SOURCE_DIR}/driver/include
 ${PROJECT_BINARY_DIR}/composable_kernel/include/utility
diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
new file mode 100644
index 0000000000..a7fa193da1
--- /dev/null
+++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
@@ -0,0 +1,2024 @@
+#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
+
+#include "common_header.hpp"
+#include "dynamic_tensor_descriptor.hpp"
+#include "dynamic_tensor_descriptor_helper.hpp"
+#include "gridwise_dynamic_gemm.hpp"
+#include "gridwise_operation_wrapper.hpp"
+
+namespace ck {
+
+// GemmM = K
+// GemmN = N * Ho * Wo
+// GemmK = C * Y * X
+template
+struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
+{
+ template + __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), + make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto 
GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + + if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && + GemmK % GemmKPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + constexpr auto GemmM1 = Number{}; + constexpr auto GemmN1 = Number{}; + + const auto GemmM0 = GemmM / GemmM1; + const auto GemmN0 = GemmN / GemmN1; + + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = + transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto b_k_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_gemmn_global_desc), + decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + 
decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); + + const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << 
nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad +{ + template + __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + 
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + if(!(InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && InRightPadW == 0)) + { + throw std::runtime_error("wrong! no padding"); + } + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), + make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + + if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && + GemmK % GemmKPerBlock == 0)) + { + throw std::runtime_error("wrong! 
GEMM size no divisible"); + } + + constexpr auto GemmM1 = Number{}; + constexpr auto GemmN1 = Number{}; + + const auto GemmM0 = GemmM / GemmM1; + const auto GemmN0 = GemmN / GemmN1; + + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = + transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto b_k_n_global_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 1, 2>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_gemmn_global_desc), + decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); + + const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + index_t nrepeat = 100; + + 
for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << 
nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; + +template +struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 +{ + template + __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + 
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0)) + { + throw std::runtime_error("wrong! 1x1, stride 1, no padding"); + } + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), + make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + + if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && + GemmK % GemmKPerBlock == 0)) + { + throw std::runtime_error("wrong! 
GEMM size no divisible"); + } + + constexpr auto GemmM1 = Number{}; + constexpr auto GemmN1 = Number{}; + + const auto GemmM0 = GemmM / GemmM1; + const auto GemmN0 = GemmN / GemmN1; + + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = + transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto b_k_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}), + make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 1, 2>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_gemmn_global_desc), + decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<0, 1>, + Sequence<0, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); + + const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << 
nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..7e9cc0af5f --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1341 @@ +#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP +#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = Y * X * C +template +struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad +{ + template + __host__ void Run(const DynamicTensorDescriptor& 
wei_k_y_x_c_global_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); + const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_global_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_global_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_y_ho_x_wo_c_global_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = 
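// -----------------------------------------------------------------------------------
// Illustration: the pad -> embed -> merge chain above is an index-space im2col; no
// im2col matrix is ever materialized. Decoding one (gemmk, gemmn) coordinate back to the
// NHWC input (following the merge tuples: GemmK merges (Y, X, C), GemmN merges
// (N, Ho, Wo)) gives the familiar convolution indexing sketched below. This reference
// routine is a host-side sketch under those assumptions, not code from this patch.
#include <vector>

float im2col_read(const std::vector<float>& in_nhwc, // N x Hi x Wi x C, packed
                  int Hi, int Wi, int C,
                  int X, int Ho, int Wo,
                  int StrideH, int StrideW, int DilationH, int DilationW,
                  int LeftPadH, int LeftPadW,
                  int gemmk, int gemmn)
{
    // GemmK decomposes as (y, x, c), GemmN as (n, ho, wo)
    const int y = gemmk / (X * C);
    const int x = (gemmk / C) % X;
    const int c = gemmk % C;

    const int n  = gemmn / (Ho * Wo);
    const int ho = (gemmn / Wo) % Ho;
    const int wo = gemmn % Wo;

    const int hi = ho * StrideH + y * DilationH - LeftPadH;
    const int wi = wo * StrideW + x * DilationW - LeftPadW;

    // coordinates that land in the pad region behave as zeros
    if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
        return 0.0f;

    return in_nhwc[((static_cast<long>(n) * Hi + hi) * Wi + wi) * C + c];
}
// -----------------------------------------------------------------------------------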
out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + + if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && + GemmK % GemmKPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + constexpr auto GemmM1 = Number{}; + constexpr auto GemmN1 = Number{}; + + const auto GemmM0 = GemmM / GemmM1; + const auto GemmN0 = GemmN / GemmN1; + + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = + transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto b_k_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_gemmn_global_desc), + decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmBBlockTransferSrcScalarPerVector_GemmK, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 1, + GemmCThreadTransferDstScalarPerVector_GemmM1, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (GemmM 
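// -----------------------------------------------------------------------------------
// Illustration: make_unmerge_transform splits the single GemmM coordinate into a
// (GemmM0, GemmM1) pair with GemmM1 as the fast-changing factor, i.e. m = m0 * GemmM1 + m1
// (and likewise GemmN into (GemmN0, GemmN1)), giving the 4-d view of the output matrix
// used by the gridwise GEMM above. Two-line sketch of that decomposition; SplitM / split_m
// are illustrative names only.
struct SplitM { int m0, m1; };

constexpr SplitM split_m(int m, int M1) { return {m / M1, m % M1}; }
constexpr int    fuse_m(SplitM s, int M1) { return s.m0 * M1 + s.m1; }

static_assert(fuse_m(split_m(130, 4), 4) == 130, "unmerge then merge is the identity");
// -----------------------------------------------------------------------------------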
/ GemmMPerBlock) * (GemmN / GemmNPerBlock); + + const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; + + printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize); + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
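// -----------------------------------------------------------------------------------
// Illustration: the two flags above select one of the four kernel instantiations in the
// launch branches that follow. With nb = GemmK / GemmKPerBlock full K blocks
// (divisibility is checked earlier), the expressions reduce to
//   has_main_k_block_loop        <=>  nb >= 3
//   has_double_tail_k_block_loop <=>  nb is even
// i.e. whether the double-buffered main loop executes at all, and whether the pipeline
// drains with two tail K blocks or one. The snippet below just re-evaluates the same
// expressions for a few block counts; it is a sketch, not patch code.
#include <cstdio>

int main()
{
    const int KPerBlock = 8;
    for(int nb = 1; nb <= 4; ++nb)
    {
        const int GemmK            = nb * KPerBlock;
        const bool has_main        = (GemmK + KPerBlock) / (2 * KPerBlock) > 1;
        const bool has_double_tail = (GemmK / KPerBlock) % 2 == 0;
        std::printf("nb=%d main=%d double_tail=%d\n", nb, has_main, has_double_tail);
    }
    return 0; // prints: nb=1 -> 0,0   nb=2 -> 0,1   nb=3 -> 1,0   nb=4 -> 1,1
}
// -----------------------------------------------------------------------------------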
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
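// -----------------------------------------------------------------------------------
// Illustration: the perf expression above counts 2 FLOPs (one multiply, one add) per MAC,
// with C * Y * X MACs for each of the N * K * Ho * Wo output elements. Dividing by 1e9
// and by the elapsed time in milliseconds yields TFLOP/s, since FLOPs / ms / 1e9 equals
// FLOPs / s / 1e12. Minimal sketch; conv_fwd_tflops is an illustrative helper name.
double conv_fwd_tflops(long N, long K, long Ho, long Wo,
                       long C, long Y, long X, double time_ms)
{
    const double flops = 2.0 * N * K * Ho * Wo * C * Y * X;
    return flops / 1.0e9 / time_ms; // == flops / 1.0e12 / (time_ms / 1000.0)
}
// -----------------------------------------------------------------------------------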
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; + +template +struct DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 +{ + template + __host__ void Run(const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + + const auto Y = 
wei_k_y_x_c_global_desc.GetLength(I1); + const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + if(!(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0)) + { + throw std::runtime_error("wrong! 1x1, stride 1, no padding"); + } + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); + const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); + const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); + + if(!(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && + GemmK % GemmKPerBlock == 0)) + { + throw std::runtime_error("wrong! 
GEMM size no divisible"); + } + + constexpr auto GemmM1 = Number{}; + constexpr auto GemmN1 = Number{}; + + const auto GemmM0 = GemmM / GemmM1; + const auto GemmN0 = GemmN / GemmN1; + + const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = + transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), + make_unmerge_transform(make_tuple(GemmN0, GemmN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over b_k_n_global tensor + constexpr auto b_k_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_m0m1n0n1_v1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_gemmn_global_desc), + decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmBBlockTransferSrcScalarPerVector_GemmK, + GemmBBlockTransferDstScalarPerVector_GemmN, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 1, + GemmCThreadTransferDstScalarPerVector_GemmM1, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_m0_m1_n0_n1_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (GemmM / GemmMPerBlock) * (GemmN / GemmNPerBlock); + + const bool has_main_k_block_loop = (GemmK + GemmKPerBlock) / (2 * GemmKPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (GemmK / GemmKPerBlock) % 2 == 0; + + printf("%s: BlockSize %d, GridSize %d \n", __func__, BlockSize, GridSize); + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + index_t nrepeat = 100; + + for(index_t i = 0; i 
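// -----------------------------------------------------------------------------------
// Illustration: in this 1x1 specialization (Y == X == 1, unit stride/dilation, no
// padding, enforced by the check above) Ho == Hi and Wo == Wi, so the NHWC input is
// already a packed (N*Ho*Wo) x C matrix and no pad/embed/merge transforms are needed:
// the convolution is a plain GEMM, out[nhw][k] = sum_c in[nhw][c] * wei[k][c].
// Host-side reference sketch (illustrative, not patch code):
#include <vector>

void conv1x1_nhwc_reference(const std::vector<float>& in,  // (N*Ho*Wo) x C
                            const std::vector<float>& wei, // K x C
                            std::vector<float>& out,       // (N*Ho*Wo) x K
                            int NHW, int C, int K)
{
    for(int nhw = 0; nhw < NHW; ++nhw)
        for(int k = 0; k < K; ++k)
        {
            float acc = 0.0f;
            for(int c = 0; c < C; ++c)
                acc += in[nhw * C + c] * wei[k * C + c];
            out[nhw * K + k] = acc;
        }
}
// -----------------------------------------------------------------------------------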
< 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_gemmn_global_desc, + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + reinterpret_cast( + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer()), + p_wei_global, + reinterpret_cast( + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer()), + p_in_global, + reinterpret_cast( + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer()), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + using ADesc = decltype(wei_gemmk_gemmm_global_desc); + using BDesc = decltype(in_gemmk_gemmn_global_desc); + using CDesc = decltype(out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + DeviceMem wei_gemmk_gemmm_global_desc_device_buf(sizeof(ADesc)); + DeviceMem in_gemmk_gemmn_global_desc_device_buf(sizeof(BDesc)); + DeviceMem out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf(sizeof(CDesc)); + + wei_gemmk_gemmm_global_desc_device_buf.ToDevice(&wei_gemmk_gemmm_global_desc); + in_gemmk_gemmn_global_desc_device_buf.ToDevice(&in_gemmk_gemmn_global_desc); + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf.ToDevice( + &out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc); + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc_device_buf.GetDeviceBuffer(), + p_wei_global, + in_gemmk_gemmn_global_desc_device_buf.GetDeviceBuffer(), + p_in_global, + out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc_desc_device_buf + .GetDeviceBuffer(), + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..54a9370999 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @@ -0,0 +1,353 @@ +#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP +#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +template +struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad +{ + template + __host__ void Run(const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const 
ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const FloatAB* __restrict__ p_wei_global, + const FloatAB* __restrict__ p_in_global, + FloatC* __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple( + make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)), + make_tuple(make_pass_through_transform(K), + make_pass_through_transform(N), + make_pass_through_transform(Ho), + make_pass_through_transform(Wo)), + make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + 
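// -----------------------------------------------------------------------------------
// Illustration: unlike the v4r4 drivers above, v5r1 keeps N, Ho and Wo as separate output
// dimensions. Each workgroup produces a KPerBlock x HoPerBlock x WoPerBlock tile of one
// image and reduces over E = C * Y * X in EPerBlock chunks, which is why the grid size
// computed further down is (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N.
// Minimal sketch of that launch-size arithmetic; v5r1_grid_size is an illustrative name.
constexpr long v5r1_grid_size(long N, long K, long Ho, long Wo,
                              long KPerBlock, long HoPerBlock, long WoPerBlock)
{
    return (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
}
// -----------------------------------------------------------------------------------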
const auto E = C * Y * X; + + if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 && + E % EPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // hack to control index calculation when iterating over a_k_m_global tensor + constexpr auto a_k_m_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); + + constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{}; + + constexpr auto b_k_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto b_k_n_global_move_slice_window_iterator_hack = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + // hack for NKHW format + constexpr auto c_k_n_h_w_global_tensor_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + +#if 1 + // GEMM + using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + InMemoryDataOperation::Set, + decltype(wei_gemmk_gemmm_global_desc), + decltype(in_gemmk_n_ho_wo_global_desc), + decltype(out_gemmm_n_ho_wo_global_desc), + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 3, 1>, + 3, + BThreadTransferSrcScalarPerVector_W, + false, // don't move back src coordinate after threadwise copy, which will be fused with + // MoveSrcSliceWindow() to save addr computation + Sequence<0, 2, 3, 1>, + 3, + CThreadTransferDstScalarPerVector_W, + decltype(a_k_m_global_iterator_hacks), + decltype(b_k_n_global_iterator_hacks), + decltype(c_k_n_h_w_global_tensor_iterator_hacks), + decltype(a_k_m_global_move_slice_window_iterator_hack), + decltype(b_k_n_global_move_slice_window_iterator_hack)>; + + const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N; + + const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0; + + index_t nrepeat = 100; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." 
<< std::endl; + + KernelTimer timer; + timer.Start(); + std::cout << "has_main_k_block_loop: " << has_main_k_block_loop + << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop + << std::endl; + + for(index_t j = 0; j < nrepeat; ++j) + { + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_n_ho_wo_global_desc, + p_in_global, + out_gemmm_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_n_ho_wo_global_desc, + p_in_global, + out_gemmm_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_n_ho_wo_global_desc, + p_in_global, + out_gemmm_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + else + { + const auto kernel = + run_gridwise_operation, + integral_constant>; + + launch_kernel(kernel, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + wei_gemmk_gemmm_global_desc, + p_wei_global, + in_gemmk_n_ho_wo_global_desc, + p_in_global, + out_gemmm_n_ho_wo_global_desc, + p_out_global, + integral_constant{}, + integral_constant{}); + } + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc, + wei_k_c_y_x_global_desc, + out_n_k_ho_wo_global_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } +#endif + } +}; +} // namespace ck +#endif diff --git a/composable_kernel/include/gridwise_operation_wrapper.hpp b/composable_kernel/include/gridwise_operation_wrapper.hpp index 746e41ce33..0a1e07ec57 100644 --- a/composable_kernel/include/gridwise_operation_wrapper.hpp +++ b/composable_kernel/include/gridwise_operation_wrapper.hpp @@ -2,7 +2,11 @@ #define CK_GRIDWISE_OPERATION_KERNEL_WRAPPER template -__global__ void run_gridwise_operation(Xs... xs) +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + run_gridwise_operation(Xs... 
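// -----------------------------------------------------------------------------------
// Illustration: run_gridwise_operation is a variadic __global__ wrapper that
// default-constructs the gridwise functor and forwards every launch argument to its
// Run() method; the new __launch_bounds__ attribute (guarded by CK_USE_LAUNCH_BOUNDS)
// tells the compiler the maximum block size, plus an occupancy hint, so it can budget
// registers accordingly. A stripped-down sketch of the same wrapper pattern follows;
// SaxpyOp, run_op and the literal bound of 256 are illustrative only.
#include <hip/hip_runtime.h>

struct SaxpyOp
{
    __device__ void Run(float a, const float* x, float* y, int n) const
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            y[i] = a * x[i] + y[i];
    }
};

template <typename GridwiseOp, typename... Xs>
__global__ void __launch_bounds__(256) run_op(Xs... xs)
{
    GridwiseOp{}.Run(xs...); // any functor with a matching Run() can be launched this way
}

// usage (hip-clang triple-chevron syntax):
//   run_op<SaxpyOp><<<dim3(grid_size), dim3(256), 0, 0>>>(a, d_x, d_y, n);
// -----------------------------------------------------------------------------------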
xs) { GridwiseOp{}.Run(xs...); } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp index 286a1c995b..05e4c54a61 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -107,8 +107,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - const index_t e_block_data_on_global = block_work_id[0] * EPerBlock; - const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; + const index_t e_block_data_on_global = block_work_id[Number<0>{}] * EPerBlock; + const index_t b_block_data_on_global = block_work_id[Number<1>{}] * BPerBlock; // output tensor // global tensor in global memory, src of blockwise copy @@ -151,7 +151,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, b_block_data_on_global, 0}, {0, 0, 0}); + make_multi_index(0, b_block_data_on_global, 0), make_multi_index(0, 0, 0)); // weight tensor // global tensor in global memory, src of blockwise copy @@ -191,7 +191,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, e_block_data_on_global, 0}, {0, 0, 0}); + make_multi_index(0, e_block_data_on_global, 0), make_multi_index(0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -354,7 +354,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl { #if 1 // debug - // input: register to global memory, atomic add + // input: register to global memory, atomic add constexpr auto in_memory_op = (Y <= ConvStrideH && X <= ConvStrideW) ? 
InMemoryDataOperation::Set : InMemoryDataOperation::AtomicAdd; @@ -434,13 +434,13 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_doubl InThreadCopyDstDataPerWrite_B, AddressSpace::Vgpr, AddressSpace::Global, - in_memory_op>({0, 0, 0, 0, 0, 0}, - {e_thread_data_on_global / E1, - e_thread_data_on_global % E1, - 0, - b_thread_data_on_global / B1, - b_thread_data_on_global % B1, - 0}) + in_memory_op>(make_multi_index(0, 0, 0, 0, 0, 0), + make_multi_index(e_thread_data_on_global / E1, + e_thread_data_on_global % E1, + 0, + b_thread_data_on_global / B1, + b_thread_data_on_global % B1, + 0)) .Run(p_in_thread, p_in_global); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp index 24422daeda..1e8eb7cea1 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp @@ -125,7 +125,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk index_t GemmK1 = XDotSlice; index_t GemmK2 = K; - return Array{GemmM, GemmN, GemmK0, GemmK1, GemmK2}; + return make_multi_index(GemmM, GemmN, GemmK0, GemmK1, GemmK2); } __host__ __device__ static constexpr auto GetGemmSize(index_t gemm_id) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp similarity index 95% rename from composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp rename to composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index 37be0c60c2..d270a24467 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -1,5 +1,5 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP +#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP +#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -49,7 +49,7 @@ template -struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer +struct GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer { __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, @@ -119,8 +119,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - const index_t k_block_data_on_global = block_work_id[0] * KPerBlock; - const index_t b_block_data_on_global = block_work_id[1] * BPerBlock; + const index_t k_block_data_on_global = block_work_id[I0] * KPerBlock; + const index_t b_block_data_on_global 
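// -----------------------------------------------------------------------------------
// Illustration: the in_memory_op selection in the v1r2 backward-data hunk above follows
// from the scatter pattern of backward data: each output-gradient pixel writes a Y x X
// window into the input gradient, and consecutive output pixels are ConvStrideH x
// ConvStrideW apart. When the window fits inside the stride (Y <= ConvStrideH &&
// X <= ConvStrideW) the windows never overlap and a plain store (Set) is safe; otherwise
// different output pixels hit the same input element and the store must be AtomicAdd.
// Sketch of that predicate only; WriteOp is an illustrative enum, not part of the patch.
enum class WriteOp { Set, AtomicAdd };

constexpr WriteOp backward_data_write_op(int Y, int X, int ConvStrideH, int ConvStrideW)
{
    return (Y <= ConvStrideH && X <= ConvStrideW) ? WriteOp::Set : WriteOp::AtomicAdd;
}
// -----------------------------------------------------------------------------------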
= block_work_id[I1] * BPerBlock; // input tensor // global tensor in global memory @@ -183,7 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0}); + make_multi_index(0, 0, b_block_data_on_global, 0), make_multi_index(0, 0, 0, 0)); // weight tensor // global tensor in global memory, src of blockwise copy @@ -226,7 +226,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, k_block_data_on_global}, {0, 0}); + make_multi_index(0, k_block_data_on_global), make_multi_index(0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -439,12 +439,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer 1, AddressSpace::Vgpr, AddressSpace::Global, - InMemoryDataOperation::Set>({0, 0, 0, 0, 0}, - {k_thread_data_on_global / K1, - k_thread_data_on_global % K1, - 0, - b_thread_data_on_global, - 0}) + InMemoryDataOperation::Set>(make_multi_index(0, 0, 0, 0, 0), + make_multi_index(k_thread_data_on_global / K1, + k_thread_data_on_global % K1, + 0, + b_thread_data_on_global, + 0)) .Run(p_out_thread, p_out_global); } } diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 97% rename from composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 5e4f621807..b8090321a9 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -1,5 +1,5 @@ -#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP -#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP +#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -43,7 +43,7 @@ template -struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw +struct GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw { __device__ void Run(const Float* const __restrict__ p_in_global, const Float* const __restrict__ p_wei_global, diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..ac3e35f2db --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,162 @@ +#ifndef CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP +#define CK_GRIDWISE_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +struct 
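// -----------------------------------------------------------------------------------
// Illustration: the hunks above replace brace-initialized index lists such as
// {0, b_block_data_on_global, 0} with make_multi_index(...), and index block_work_id with
// compile-time Number<i> constants (I0, I1) instead of plain ints, matching the new
// multi-index containers added later in this patch. A simplified stand-in for
// make_multi_index built on std::array is sketched below; the real helper returns CK's
// own Array type, so this is only an analogy.
#include <array>
#include <cstdint>

template <typename... Xs>
constexpr auto make_multi_index_sketch(Xs... xs)
{
    return std::array<std::int32_t, sizeof...(Xs)>{static_cast<std::int32_t>(xs)...};
}

constexpr auto origin = make_multi_index_sketch(0, 128, 0); // e.g. {0, b_block_data_on_global, 0}
static_assert(origin[1] == 128, "element access is usable in constant expressions");
// -----------------------------------------------------------------------------------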
GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk +{ + __device__ void Run(const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_n_hi_wi_c_global_desc = InGlobalDesc{}; + constexpr auto wei_k_y_x_c_global_desc = WeiGlobalDesc{}; + constexpr auto out_n_ho_wo_k_global_desc = OutGlobalDesc{}; + + constexpr index_t N = in_n_hi_wi_c_global_desc.GetLengths()[I0]; + constexpr index_t Hi = in_n_hi_wi_c_global_desc.GetLengths()[I1]; + constexpr index_t Wi = in_n_hi_wi_c_global_desc.GetLengths()[I2]; + constexpr index_t C = in_n_hi_wi_c_global_desc.GetLengths()[I3]; + + constexpr index_t K = out_n_ho_wo_k_global_desc.GetLengths()[I3]; + constexpr index_t Ho = out_n_ho_wo_k_global_desc.GetLengths()[I1]; + constexpr index_t Wo = out_n_ho_wo_k_global_desc.GetLengths()[I2]; + + constexpr index_t Y = wei_k_y_x_c_global_desc.GetLengths()[I1]; + constexpr index_t X = wei_k_y_x_c_global_desc.GetLengths()[I2]; + + constexpr index_t ConvStrideH = ConvStrides{}[I0]; + constexpr index_t ConvStrideW = ConvStrides{}[I1]; + + constexpr index_t ConvDilationH = ConvDilations{}[I0]; + constexpr index_t ConvDilationW = ConvDilations{}[I1]; + + // weight tensor + constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower( + unfold_tensor_descriptor(wei_k_y_x_c_global_desc, I1, I3), Sequence<1, 0>{}); + + // input tensor + constexpr auto in_n_hip_wip_c_global_desc = + transform_tensor_descriptor(in_n_hi_wi_c_global_desc, + make_tuple(PassThrough{}, + Pad, InLeftPads, InRightPads>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr index_t Hip = in_n_hip_wip_c_global_desc.GetLengths()[I1]; + constexpr index_t Wip = in_n_hip_wip_c_global_desc.GetLengths()[I2]; + + constexpr auto in_n_y_ho_x_wo_c_global_desc = transform_tensor_descriptor( + in_n_hip_wip_c_global_desc, + make_tuple(PassThrough{}, + Embed, Sequence>{}, + Embed, Sequence>{}, + PassThrough{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_global_desc, + make_tuple(Merge>{}, Merge>{}), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // output tensor + constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor( + unfold_tensor_descriptor(out_n_ho_wo_k_global_desc, I0, I2), + make_tuple(PassThrough{}, Merge>{}), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // GEMM + constexpr auto gridwise_gemm = + GridwiseGemmTransposedANormalBNormalC_v1, + Sequence<1, 0>, + 0, + GemmABlockCopySrcDataPerRead_GemmK, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + Sequence<1, 0>, + Sequence<1, 0>, + 0, + GemmBBlockCopySrcDataPerRead_GemmK, + GemmBBlockCopyDstDataPerWrite_GemmN, + Sequence<2, 3, 0, 1>, + 1, + GemmCThreadCopyDstDataPerWrite_GemmM1>{}; + + gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global); + } +}; + +} // namespace ck +#endif diff --git 
a/composable_kernel/include/tensor_description/array_multi_index.hpp b/composable_kernel/include/tensor_description/array_multi_index.hpp new file mode 100644 index 0000000000..f692fb5143 --- /dev/null +++ b/composable_kernel/include/tensor_description/array_multi_index.hpp @@ -0,0 +1,77 @@ +#ifndef CK_ARRAY_MULTI_INDEX_HPP +#define CK_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +__host__ __device__ constexpr auto operator+=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; }); + return r; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/cluster_descriptor.hpp b/composable_kernel/include/tensor_description/cluster_descriptor.hpp new file mode 100644 index 0000000000..96dbe07073 --- /dev/null +++ b/composable_kernel/include/tensor_description/cluster_descriptor.hpp @@ -0,0 +1,48 @@ +#ifndef CK_CLUSTER_DESCRIPTOR_HPP +#define CK_CLUSTER_DESCRIPTOR_HPP + +#include "common_header.hpp" + +// TODO remove dependency on deprecated tensor descriptor +#include "tensor_descriptor.hpp" + +namespace ck { + +// a cluster map 1d index to N-d index +template +struct ClusterDescriptor +{ + static constexpr index_t nDim = Lengths::Size(); + + static constexpr auto mDesc = transform_tensor_descriptor( + make_native_tensor_descriptor_packed(Lengths{}), + make_tuple(Merge{}), + make_tuple(ArrangeOrder{}), + make_tuple(Sequence<0>{})); + + __host__ __device__ constexpr ClusterDescriptor() + { + static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim, + "wrong! size not the same"); + + static_assert(is_valid_sequence_map{}, "wrong! 
ArrangeOrder is wrong"); + } + + __host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); } + + __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d) + { + return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d}); + } +}; + +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor( + Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + return ClusterDescriptor{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp new file mode 100644 index 0000000000..429473c8f6 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp @@ -0,0 +1,1157 @@ +#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP +#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HPP + +#include "common_header.hpp" +#include "multi_index.hpp" + +namespace ck { + +template +struct DynamicPassThrough +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicPassThrough() = default; + + __host__ __device__ constexpr DynamicPassThrough(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static void CalculateLowerIndex(LowIdx& idx_low, const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicPassThrough, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct DynamicPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{} + RightPad{})); + + UpLengths up_lengths_; + LeftPad left_pad_; + RightPad right_pad_; + + __host__ __device__ constexpr DynamicPad() = default; + + __host__ __device__ constexpr DynamicPad(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad) + : up_lengths_{make_tuple(low_length + left_pad + right_pad)}, + left_pad_{left_pad}, + right_pad_{right_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || ((idx_up[Number<0>{}] >= left_pad_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_)); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_ %d", index_t{left_pad_}); + printf("right_pad_ %d", index_t{right_pad_}); + printf("}"); + } +}; + +template +struct DynamicLeftPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPad{})); + + UpLengths up_lengths_; + LeftPad left_pad_; + + __host__ __device__ constexpr DynamicLeftPad() = default; + + __host__ __device__ constexpr DynamicLeftPad(const LowLength& low_length, + const LeftPad& left_pad) + : up_lengths_{make_tuple(low_length + left_pad)}, left_pad_{left_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicLeftPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_ %d", index_t{left_pad_}); + printf("}"); + } +}; + +template +struct DynamicRightPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPad{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPad right_pad_; + + __host__ __device__ constexpr DynamicRightPad() = default; + + __host__ __device__ constexpr DynamicRightPad(const LowLength& low_length, + const RightPad& right_pad) + : up_lengths_{make_tuple(low_length + right_pad)}, + low_length_{low_length}, + right_pad_{right_pad} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicRightPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("low_length_ %d", index_t{low_length_}); + printf("left_pad_ %d", index_t{right_pad_}); + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of Number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct DynamicEmbed +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + UpLengths up_lengths_; + Coefficients coefficients_; + + __host__ __device__ constexpr DynamicEmbed() = default; + + __host__ __device__ constexpr DynamicEmbed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && + LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! 
inconsistent # of dimension"); + + idx_diff_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicEmbed, "); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("coefficients_ "); + print_multi_index(coefficients_); + printf("}"); + } +}; + +template +struct DynamicMerge +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = decltype( + container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicMerge() = default; + + __host__ __device__ constexpr DynamicMerge(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&idx_low, &tmp, this](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. 
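The decomposition-plus-carry scheme described in this comment is easier to follow on a scalar mock-up. The sketch below (plain C++, hypothetical two-dimensional merge, not CK code) reproduces the Hack == 1 path for a non-negative step: the merged step is split once using the reverse exclusive scan of the lower lengths, then carries ripple from the fastest-varying dimension toward dimension 0.

// Illustrative only: scalar mock-up of DynamicMerge's carry update (Hack == 1),
// assuming the per-call upper-index step is a fixed non-negative value.
#include <cstdio>

int main()
{
    const int ndim       = 2;
    const int low_len[2] = {4, 8}; // merged lengths, e.g. Ho x Wo
    const int scan[2]    = {8, 1}; // reverse exclusive scan of low_len

    int idx_low[2]    = {1, 6};    // current lower index, maps to 1*8 + 6 = 14
    const int diff_up = 3;         // step applied to the merged (upper) index

    // decompose the upper step once: this plays the role of idx_diff_low_const
    int diff_const[2];
    int tmp = diff_up;
    for(int i = 0; i < ndim - 1; ++i)
    {
        diff_const[i] = tmp / scan[i];
        tmp -= diff_const[i] * scan[i];
    }
    diff_const[ndim - 1] = tmp;

    // carry propagation from the fastest dimension toward dimension 0
    int carry = 0;
    for(int i = ndim - 1; i > 0; --i)
    {
        const bool do_carry = idx_low[i] + carry >= low_len[i] - diff_const[i];
        int diff_i          = do_carry ? diff_const[i] - low_len[i] : diff_const[i];
        diff_i += carry;
        idx_low[i] += diff_i;
        carry = do_carry ? 1 : 0;
    }
    idx_low[0] += diff_const[0] + carry;

    printf("new lower index: {%d, %d}\n", idx_low[0], idx_low[1]); // {2, 1}, i.e. 14 + 3 = 17
    return 0;
}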
Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]); + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] - borrow; + + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? 
-1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t negative_idx_low_tmp = borrow - idx_low[i]; + + bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? 
idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at run-time each time this function is called, and can be + // very expensive. + LowerIndex idx_diff_low_const; + +#if !CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + bool do_carry = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] + do_carry; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_carry = idx_low_tmp >= low_lengths_[i]; + +#if 0 + // TODO: use exec-mask inline asm, which use 1 VALU + if(do_carry) + { + idx_diff_low(i) -= low_lengths_[i]; + } +#elif 1 + // this use 2 VALU + idx_diff_low(i) = do_carry ? 
idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i]; +#elif 1 + // this use 2 VALU + index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i]; + idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry; + + idx_low(I0) += idx_diff_low[I0]; + } + else if constexpr(Hack == 2) + { + // do borrow check on each low dimension in reversed order + // do not need to check the first dimension + bool do_borrow = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] - do_borrow; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_borrow = idx_low_tmp < 0; + +#if 0 + // TODO: use exec-mask inline asm + if(do_borrow) + { + idx_diff_low(i) += low_lengths_[i]; + } +#elif 1 + idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i]; +#elif 1 + index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i]; + idx_diff_low(i) = do_borrow ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow; + + idx_low(I0) += idx_diff_low[I0]; + } + else + { + // not implemented + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { +#if 1 + UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#elif 0 + UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#else + UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#endif + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicMerge, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct DynamicUnMerge +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies_v2{}, Number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + __host__ __device__ constexpr DynamicUnMerge() = default; + + __host__ __device__ constexpr DynamicUnMerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies_v2{}, Number<1>{})} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ 
constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(Number<0>{}) = + (0x00ffffff & idx_low[Number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + CalculateLowerIndex(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicUnMerge, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + print_multi_index(up_lengths_scan_); + printf("}"); + } +}; + +template +struct DynamicFreeze +{ + LowerIndex low_idx_; + + __host__ __device__ constexpr DynamicFreeze() = default; + + __host__ __device__ constexpr DynamicFreeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + idx_low = low_idx_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + idx_diff_low(Number<0>{}) = index_t{Number<0>{}}; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("DynamicFreeze"); + printf("low_idx_ %d", index_t{low_idx_}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp new file mode 100644 index 0000000000..f460599ee5 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp @@ -0,0 +1,74 @@ +#ifndef CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP +#define CK_DYNAMIC_MULTI_INDEX_TRANSFORM_HELPER_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return DynamicPassThrough{low_length}; +} + +template +__host__ __device__ constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return DynamicPad{ + low_length, left_pad, right_pad}; +} + +template +__host__ __device__ constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPad& left_pad, + integral_constant = integral_constant{}) +{ + return DynamicLeftPad{low_length, left_pad}; +} + +template +__host__ __device__ constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return DynamicRightPad{low_length, right_pad}; +} + +template ::type = false> +__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return DynamicEmbed{up_lengths, coefficients}; +} + +template +__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ + return DynamicMerge{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return DynamicUnMerge{up_lengths}; +} + +template +__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return DynamicFreeze{low_idx}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp new file mode 100644 index 0000000000..e2121f1f3e --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp @@ -0,0 +1,608 @@ +#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP +#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform.hpp" + 
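Before the descriptor machinery below, a usage sketch of the factory functions just defined may help. Every name here (N, Hi, Wi, C and the pad sizes) is a hypothetical index_t value, and the sketch leans on transform_dynamic_tensor_descriptor and make_dynamic_naive_tensor_descriptor_packed_v2, which are introduced later in this patch: it pads H and W of an NHWC input and then merges (N, Hp, Wp) into a single dimension.

// Usage sketch only, not taken from the patch; N, Hi, Wi, C, LeftPad*, RightPad*
// are assumed runtime index_t values, and the code lives in namespace ck.
const auto in_n_h_w_c = make_dynamic_naive_tensor_descriptor_packed_v2(
    make_tuple(N, Hi, Wi, C));

// pad the spatial dimensions
const auto in_n_hp_wp_c = transform_dynamic_tensor_descriptor(
    in_n_h_w_c,
    make_tuple(make_pass_through_transform(N),
               make_pad_transform(Hi, LeftPadH, RightPadH),
               make_pad_transform(Wi, LeftPadW, RightPadW),
               make_pass_through_transform(C)),
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

// merge (N, Hp, Wp) into one dimension, keep C separate
const auto in_merged = transform_dynamic_tensor_descriptor(
    in_n_hp_wp_c,
    make_tuple(make_merge_transform(make_tuple(
                   N, Hi + LeftPadH + RightPadH, Wi + LeftPadW + RightPadW)),
               make_pass_through_transform(C)),
    make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}));

Chaining small transforms like this is the pattern the convolution drivers in this patch use to turn a convolution problem into a GEMM-shaped descriptor.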
+namespace ck { + +template +struct DynamicTensorCoordinate; + +template +struct DynamicTensorCoordinateIterator; + +template +__host__ __device__ constexpr index_t GetNumOfHiddenDimension(LowerDimensionIdss, + UpperDimensionIdss) +{ + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); +} + +// Transforms: Tuple +// LowerDimensionIdss : Tuple, ...> +// UpperDimensionIdss : Tuple, ...> +// VisibleDimensionIds> : Sequence<...> +template +struct DynamicTensorDescriptor +{ + // TODO make these private + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ static constexpr index_t GetNumOfVisibleDimension() + { + return VisibleDimensionIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = + unpack([](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_visible) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! 
not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_visible = Number{}; + + constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + + using VisibleIndex = MultiIndex; + using HiddenIndex = MultiIndex; + using Coordinate = DynamicTensorCoordinate; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: + __host__ __device__ constexpr DynamicTensorDescriptor() = default; + + __host__ __device__ constexpr DynamicTensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : transforms_{transforms}, + element_size_{InitializeElementSize(transforms)}, + element_space_size_{element_space_size} + + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionIdss::Size() == ntransform_ && + UpperDimensionIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ static constexpr index_t GetNumOfDimension() + { + return GetNumOfVisibleDimension(); + } + + template + __host__ __device__ constexpr auto GetLength(Number) const + { + static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range"); + + constexpr auto tmp = GetTransformAndItsUpperDimension(Number{}); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[Number{}].GetUpperLengths()[Number{}]; + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } + + template + __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + { + static_assert(Idx::Size() == GetNumOfDimension(), "wrong! 
inconsistent # of dimension"); + + return make_dynamic_tensor_coordinate(*this, idx).GetOffset(); + } + + // TODO make these private + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionIdss() + { + return LowerDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionIdss() + { + return UpperDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetVisibleDimensionIds() + { + return VisibleDimensionIds{}; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicTensorDescriptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionIds:"); + LowerDimensionIdss{}.At(i).Print(); + printf("UpperDimensionIds:"); + UpperDimensionIdss{}.At(i).Print(); + }); + printf("}"); + + VisibleDimensionIds::Print(); + } + + // TODO make these private + Transforms transforms_; + ElementSize element_size_; + ElementSpaceSize element_space_size_; +}; + +template +struct DynamicTensorCoordinate +{ + // TODO make these private + static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); + + using HiddenIndex = MultiIndex; + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr DynamicTensorCoordinate() = default; + + __host__ __device__ constexpr DynamicTensorCoordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); } + + __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } + + // TODO make these private + __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + + __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; } + + __host__ __device__ constexpr auto GetVisibleIndex() const + { + return get_container_subset(idx_hidden_, VisibleDimensionIds{}); + } + + // TODO make these private + HiddenIndex idx_hidden_; +}; + +template +struct DynamicTensorCoordinateIterator +{ + // TODO make these private + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr DynamicTensorCoordinateIterator() = default; + + __host__ __device__ constexpr DynamicTensorCoordinateIterator( + const VisibleIndex& idx_diff_visible, const MultiIndex& do_transforms) + : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} + { + } + + __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } + + // TODO make these private + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + { + return idx_diff_visible_; + } + + VisibleIndex idx_diff_visible_; + MultiIndex do_transforms_; + + // HACK: control UpdateLowerIndex() + static constexpr UpdateLowerIndexHack update_lower_index_hack_; +}; + +// TODO: How to fix this? 
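One way to read the coordinate classes above: the hidden index holds the visible index together with every intermediate index produced by the transform chain, and hidden dimension 0 is the resulting memory offset. For a plain packed descriptor, CalculateOffset therefore collapses to the familiar strided dot product, sketched below in standalone C++ (hypothetical NHWC sizes, not CK code).

// Illustrative only: what the coordinate machinery computes for a naive
// packed NHWC descriptor; hidden dimension 0 holds this offset.
#include <cstdio>

int main()
{
    const int lengths[4] = {2, 8, 8, 16};               // N, H, W, C
    const int strides[4] = {8 * 8 * 16, 8 * 16, 16, 1}; // packed NHWC strides
    const int idx[4]     = {1, 3, 5, 7};                // visible index

    int offset = 0; // corresponds to idx_hidden_[Number<0>{}] of the coordinate
    for(int i = 0; i < 4; ++i)
    {
        offset += idx[i] * strides[i];
    }

    printf("offset = %d\n", offset); // 1*1024 + 3*128 + 5*16 + 7 = 1495
    (void)lengths;
    return 0;
}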
It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_dynamic_tensor_descriptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return Number{}; + } +}; + +template +__host__ __device__ constexpr auto +transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + // lower dimension's hidden idss + // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_visible_ids) constexpr { + return transform_sequences( + // convert lower dimension visible id to hidden id + [](auto low_dim_visible_id) constexpr { + return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; + }, + low_dim_visible_ids); + }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr index_t num_new_transform = NewTransforms::Size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, Number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + + constexpr auto up_dim_hidden_idss = + generate_tuple([ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + Number{}); + + // new visible dimension's hidden ids + constexpr auto unordered_new_visible_dim_hidden_ids = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_visible_dim_unordered2ordered = + unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + constexpr auto new_visible_dim_hidden_ids = + unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); + + // put everything together + const auto all_transforms = container_cat(old_tensor_desc.GetTransforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_cat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_cat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + + const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + MultiIndex idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, visible_dim_ids, idx_visible); + + // calculate hidden index + static_for{}([&tensor_desc, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return DynamicTensorCoordinate{idx_hidden}; +} + +// UpdateLowerIndexHack: Sequence<...> +// HACK: control UpdateLowerIndex +template +__host__ __device__ constexpr auto make_dynamic_tensor_coordinate_iterator( + const TensorDesc&, const VisibleIndex& idx_diff_visible, UpdateLowerIndexHack) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!"); + + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + MultiIndex non_zero_diff_pick_visible; + + static_for<0, ndim_visible, 1>{}( + [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); + + set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible); + + static_for{}([&](auto itran) { + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + MultiIndex non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.Size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + + return DynamicTensorCoordinateIterator{ + idx_diff_visible, do_transforms}; +} + +template +__host__ __device__ constexpr auto +make_dynamic_tensor_coordinate_iterator(const TensorDesc&, const VisibleIndex& idx_diff_visible) +{ + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + return make_dynamic_tensor_coordinate_iterator( + TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); +} + +template 
+__host__ __device__ constexpr void move_dynamic_tensor_coordinate( + const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator) +{ + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + using HiddenIndex = MultiIndex; + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize visible index diff + set_container_subset(idx_diff_hidden, + TensorDesc::GetVisibleDimensionIds(), + coord_iterator.GetVisibleIndexDiff()); + + // this is what needs to be updated + auto& idx_hidden = coord.GetHiddenIndex(); + + // update visible index + auto idx_hidden_pick_visible = + get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); + + idx_hidden_pick_visible += coord_iterator.GetIndexDiff(); + + set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(coord_iterator.do_transforms_[itran]) + { + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + MultiIndex idx_diff_low; + + // HACK: control UpdateLowerIndex for DynamicMerge using hack + constexpr index_t Hack = decltype(coord_iterator.update_lower_index_hack_)::At(itran); + + tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); +} + +template +__host__ __device__ constexpr bool +coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + const auto& idx_hidden = coord.GetHiddenIndex(); + + static_for{}([&tensor_desc, &idx_hidden, &valid](auto itran) { + const auto tran = tensor_desc.GetTransforms().At(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) + { + const auto idx_up = + get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); + + // Comment: using valid = valid && .. 
will result in weird control flow in ISA + valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); + } + }); + + return valid; +} + +template +__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + // check visible index + const auto& idx_visible = coord.GetVisibleIndex(); + + bool is_visible_index_valid = true; + + static_for<0, TensorDesc::GetNumOfDimension(), 1>{}( + [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) { + is_visible_index_valid = + is_visible_index_valid && + (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); + }); + + // check other hidden index + return is_visible_index_valid && + coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord); +} + +template +using DynamicTensorCoordinate_t = decltype(make_dynamic_tensor_coordinate( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +template +using DynamicTensorCoordinateIterator_t = decltype(make_dynamic_tensor_coordinate_iterator( + TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp new file mode 100644 index 0000000000..385edab1c0 --- /dev/null +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp @@ -0,0 +1,146 @@ +#ifndef CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP +#define CK_DYNAMIC_TENSOR_DESCRIPTOR_HELPER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_multi_index_transform_helper.hpp" + +namespace ck { + +/* + * These functions create tensor descriptor at runtime. If they are not constexpr, you will + * likely see usage of scratch memory during construction of these tensor descriptors. So + * it's better to call these functions on host and then pass the constructed tensor descritpors + * to GPU. If the tensor descritpors being constructed are constexpr, then you can call these + * functions on GPU without worrying about scratch memory usage. 
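The validity checks above are what make the padding case work with the buffer-load out-of-bounds feature mentioned in the commit message: an upper index that lands in the pad region is reported as invalid, and the corresponding load is expected to return zero rather than touch memory. A minimal sketch of the predicate contributed by DynamicPad (plain C++, hypothetical sizes, not CK code):

// Illustrative only: DynamicPad with SkipIsValidCheck == false treats an upper
// index as valid iff it falls inside [left_pad, up_length - right_pad).
#include <cstdio>

int main()
{
    const int low_length = 8, left_pad = 1, right_pad = 2;
    const int up_length  = low_length + left_pad + right_pad; // 11

    for(int idx_up = 0; idx_up < up_length; ++idx_up)
    {
        const bool valid  = (idx_up >= left_pad) && (idx_up < up_length - right_pad);
        const int idx_low = idx_up - left_pad; // what CalculateLowerIndex produces
        printf("up %2d -> low %2d (%s)\n", idx_up, idx_low, valid ? "valid" : "pad");
    }
    return 0;
}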
+ */ + +#if CK_WORKAROUND_SWDEV_275126 +template +__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + Number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::Size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } +} +#endif + +template ::type = false> +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_v2(const Tuple& lengths, + const Tuple& strides) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + +#if !CK_WORKAROUND_SWDEV_275126 + // rocm-4.1 compiler would crash for recursive labmda + // recursive function for reduction + auto f = [&](auto fs, auto i, auto acc_old) { + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < N - 1) + { + return fs(fs, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } + }; + + const auto element_space_size = f(f, Number<0>{}, Number<1>{}); +#else + const auto element_space_size = + calculate_element_space_size_impl(lengths, strides, Number<0>{}, Number<1>{}); +#endif + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... can be: +// 1) index_t, which is known at run-time +// 2) Number<>, which is known at compile-time +template +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_packed_v2(const Tuple& lengths) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies_v2{}, Number<1>{}); + + return DynamicTensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto +make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths, Align align) +{ + constexpr index_t N = sizeof...(Lengths); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return Number<1>{}; + } + else if constexpr(i.value == N - 2) + { + return math::lcm(lengths[Number{}], align); + } + else + { + return container_reduce(lengths, + math::multiplies_v2{}, + math::lcm(lengths[Number{}], align), + i, + Number{}, + Number<1>{}); + } + }, + Number{}); + + return make_dynamic_naive_tensor_descriptor_v2(lengths, strides); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index.hpp b/composable_kernel/include/tensor_description/multi_index.hpp new file mode 100644 index 0000000000..0bb34fb1e2 --- /dev/null +++ b/composable_kernel/include/tensor_description/multi_index.hpp @@ -0,0 +1,12 @@ +#ifndef CK_MULTI_INDEX_HPP +#define 
CK_MULTI_INDEX_HPP + +#include "common_header.hpp" + +#if CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX +#include "array_multi_index.hpp" +#else +#include "statically_indexed_array_multi_index.hpp" +#endif + +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 15a052ea31..b612e1e52f 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -2,18 +2,10 @@ #define CK_MULTI_INDEX_TRANSFORM_HPP #include "common_header.hpp" +#include "multi_index.hpp" namespace ck { -template -using MultiIndex = Array; - -template -__host__ __device__ constexpr auto make_multi_index(Xs... xs) -{ - return MultiIndex(xs...); -} - template struct PassThrough { @@ -62,7 +54,7 @@ struct Pad using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ explicit constexpr Pad() + __host__ __device__ constexpr Pad() { static_assert(LowerLengths::GetSize() == nDim && LeftPads::GetSize() == nDim && RightPads::GetSize() == nDim, @@ -123,7 +115,7 @@ struct Slice using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ explicit constexpr Slice() + __host__ __device__ constexpr Slice() { static_assert(LowerLengths::GetSize() == nDim && SliceBegins::GetSize() == nDim && SliceEnds::GetSize() == nDim, @@ -197,8 +189,8 @@ struct Merge index_t& itmp; LowerIndex& idx_low; - __host__ __device__ explicit constexpr lambda_CalculateLowerIndex(index_t& itmp_, - LowerIndex& idx_low_) + __host__ __device__ constexpr lambda_CalculateLowerIndex(index_t& itmp_, + LowerIndex& idx_low_) : itmp(itmp_), idx_low(idx_low_) { } @@ -216,7 +208,7 @@ struct Merge { LowerIndex idx_low; - index_t itmp = idx_up[0]; + index_t itmp = idx_up[Number<0>{}]; constexpr auto pseudo_low_strides = reverse_inclusive_scan_sequence( @@ -226,7 +218,7 @@ struct Merge static_for<0, nDimLow - 1, 1>{}( lambda_CalculateLowerIndex(itmp, idx_low)); - idx_low(nDimLow - 1) = itmp / pseudo_low_strides[nDimLow - 1]; + idx_low(Number{}) = itmp / pseudo_low_strides[Number{}]; return idx_low; } @@ -240,9 +232,9 @@ struct Merge const UpperIndex& /* idx_up_old */, const LowerIndex& idx_low_old) { - if(idx_up_diff[0] == 0) + if(idx_up_diff[Number<0>{}] == 0) { - return make_zero_array(); + return make_zero_multi_index(); } else { @@ -265,7 +257,7 @@ struct Merge LowerIndex idx_low_new = idx_low_old + idx_low_diff_tmp; - if(idx_up_diff[0] > 0) + if(idx_up_diff[Number<0>{}] > 0) { // do carry check on each low dimension in reversed order // starting from the first digit that changed @@ -293,7 +285,7 @@ struct Merge // highest dimension, no out-of-bound check if(carry) { - ++idx_low_new(0); + ++idx_low_new(Number<0>{}); } } else @@ -324,7 +316,7 @@ struct Merge // highest dimension, no out-of-bound check if(borrow) { - --idx_low_new(0); + --idx_low_new(Number<0>{}); } } @@ -358,7 +350,7 @@ struct UnMerge __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) { - LowerIndex idx_low{0}; + LowerIndex idx_low = make_multi_index(0); constexpr auto pseudo_up_strides = reverse_inclusive_scan_sequence( @@ -366,7 +358,7 @@ struct UnMerge .PushBack(Number<1>{}); static_for<0, nDimUp, 1>{}( - [&](auto idim) { idx_low(0) += idx_up[idim] * pseudo_up_strides[idim]; }); + [&](auto idim) { idx_low(Number<0>{}) += idx_up[idim] * pseudo_up_strides[idim]; }); return idx_low; } @@ -405,7 +397,7 @@ struct Embed 
using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ explicit constexpr Embed() + __host__ __device__ constexpr Embed() { static_assert(UpperLengths::GetSize() == nDimUp && Coefficients::GetSize() == nDimUp + 1, "wrong! # of dimensions not consistent"); @@ -419,12 +411,10 @@ struct Embed __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) { - LowerIndex idx_low(Coefficients{}[nDimUp]); + LowerIndex idx_low = make_multi_index(Coefficients{}[Number{}]); - for(index_t i = 0; i < nDimUp; ++i) - { - idx_low(0) += idx_up[i] * Coefficients{}[i]; - } + static_for<0, nDimUp, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * Coefficients{}[i]; }); return idx_low; } @@ -434,12 +424,10 @@ struct Embed const UpperIndex& /* idx_up_old */, const LowerIndex& /* idx_low_old */) { - LowerIndex idx_low_diff{0}; + LowerIndex idx_low_diff = make_multi_index(0); - for(index_t i = 0; i < nDimUp; ++i) - { - idx_low_diff(0) += idx_up_diff[i] * Coefficients{}[i]; - } + static_for<0, nDimUp, 1>{}( + [&](auto i) { idx_low_diff(Number<0>{}) += idx_up_diff[i] * Coefficients{}[i]; }); return idx_low_diff; } @@ -467,21 +455,21 @@ struct Embed for(index_t icorner = 0; icorner < ncorner; ++icorner) { // generate upper index for each corner - auto idx_up = make_zero_array(); + auto idx_up = make_zero_multi_index(); index_t itmp = icorner; - for(index_t idim = nDimUp - 1; idim >= 0; --idim) - { - idx_up(idim) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim) - 1; + static_for{}([&](auto idim) { + auto idim_m1 = idim - Number<1>{}; + idx_up(idim_m1) = itmp % 2 == 0 ? 0 : UpperLengths::At(idim_m1) - 1; itmp /= 2; - } + }); // calculate lower index auto idx_low = CalculateLowerIndex(idx_up); // judge if lower index is valid - flag = flag && idx_low[0] >= 0 && idx_low[0] < LowerLength; + flag = flag && idx_low[Number<0>{}] >= 0 && idx_low[Number<0>{}] < LowerLength; } return flag; @@ -499,7 +487,7 @@ struct Freeze using LowerIndex = MultiIndex; using UpperIndex = MultiIndex; - __host__ __device__ explicit constexpr Freeze() + __host__ __device__ constexpr Freeze() { // TODO: sanity check: LowerFreezePoint should be within range of LowerLengths } @@ -512,7 +500,7 @@ struct Freeze __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& /*idx_up*/) { - return to_array(LowerFreezePoint{}); + return to_multi_index(LowerFreezePoint{}); } __host__ __device__ static constexpr auto @@ -520,49 +508,7 @@ struct Freeze const UpperIndex& /* idx_up_old */, const LowerIndex& /* idx_low_old */) { - return make_zero_array(); - } - - __host__ __device__ static constexpr bool IsLinearTransform() { return true; } - - __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() - { - return true; - } -}; - -template -struct Vectorize -{ - using LowerIndex = MultiIndex<1>; - using UpperIndex = MultiIndex<1>; - - __host__ __device__ constexpr Vectorize() - { - static_assert(VectorSize > 0 && LowerLength % VectorSize == 0, - "wrong! 
cannot evenly divide"); - } - - __host__ __device__ static constexpr auto GetNumOfLowerDimension() { return Number<1>{}; } - - __host__ __device__ static constexpr auto GetNumOfUpperDimension() { return Number<1>{}; } - - __host__ __device__ static constexpr auto GetUpperLengths() - { - return Sequence{}; - } - - __host__ __device__ static constexpr auto CalculateLowerIndex(const UpperIndex& idx_up) - { - return VectorSize * idx_up; - } - - __host__ __device__ static constexpr auto - CalculateLowerIndexDiff(const UpperIndex& idx_up_diff, - const UpperIndex& /* idx_up_old */, - const LowerIndex& /* idx_low_old */) - { - return VectorSize * idx_up_diff; + return make_zero_multi_index(); } __host__ __device__ static constexpr bool IsLinearTransform() { return true; } diff --git a/composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp b/composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp new file mode 100644 index 0000000000..ff1df4bd10 --- /dev/null +++ b/composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp @@ -0,0 +1,107 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP +#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = StaticallyIndexedArray; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_statically_indexed_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +// Here should use MultiIndex, instead of Tuple, although the former +// is the alias of the latter. This is because compiler cannot infer the NSize if +// using MultiIndex +// TODO: how to fix this? +template +__host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) +{ + static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Ys); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! 
size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; }); + return r; +} + +// MultiIndex = index_t * MultiIndex +template +__host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; }); + return r; +} + +template +__host__ __device__ void print_multi_index(const Tuple& x) +{ + printf("{"); + printf("MultiIndex, "); + printf("size %d,", index_t{sizeof...(Xs)}); + static_for<0, sizeof...(Xs), 1>{}([&](auto i) { printf("%d ", index_t{x.At(i)}); }); + printf("}"); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index a2d6bb3fb1..efd80beaf8 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -41,13 +41,13 @@ struct NativeTensorCoordinate template __host__ __device__ constexpr NativeTensorCoordinate(Xs... xs) - : NativeTensorCoordinate(Index{xs...}) + : NativeTensorCoordinate(make_multi_index(xs...)) { } template __host__ __device__ constexpr NativeTensorCoordinate(Sequence) - : NativeTensorCoordinate(Index{Xs...}) + : NativeTensorCoordinate(make_mutli_index(Xs...)) { } @@ -267,18 +267,18 @@ struct TensorCoordinate private: template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(NativeTensorDescriptor) + MakeDummyTensorCoordinate(NativeTensorDescriptor) { return NativeTensorCoordinate>( - make_zero_array()); + make_zero_multi_index()); } template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(TransformedTensorDescriptor) + MakeDummyTensorCoordinate(TransformedTensorDescriptor) { return TransformedTensorCoordinate>( - make_zero_array()); + make_zero_multi_index()); } public: diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index de525748c7..7b57723341 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -311,13 +311,13 @@ struct TransformedTensorDescriptor static_for<0, nTransform, 1>{}([&](auto itran) { constexpr auto tran = Transforms{}.At(itran); - const auto idx_up_part = pick_array_element(idx_up, UpDimensionIds{}.At(itran)); - auto idx_low_part = pick_array_element(idx_low, LowDimensionIds{}.At(itran)); + const auto idx_up_part = pick_container_element(idx_up, UpDimensionIds{}.At(itran)); + auto idx_low_part = pick_container_element(idx_low, LowDimensionIds{}.At(itran)); // this assume each lower (single) index is only assocaited with one transformation, // which is required for index transformation, and has been checked during constructor // of TransformedTensorDescriptor - idx_low_part = tran.CalculateLowerIndex(to_array(idx_up_part)); + idx_low_part = tran.CalculateLowerIndex(to_multi_index(idx_up_part)); }); return idx_low; @@ -333,20 +333,23 @@ struct TransformedTensorDescriptor constexpr auto tran = Transforms{}.At(itran); const auto idx_up_diff_part = - pick_array_element(idx_up_diff, UpDimensionIds{}.At(itran)); + pick_container_element(idx_up_diff, UpDimensionIds{}.At(itran)); - const auto idx_up_old_part = pick_array_element(idx_up_old, UpDimensionIds{}.At(itran)); + const auto 
idx_up_old_part = + pick_container_element(idx_up_old, UpDimensionIds{}.At(itran)); const auto idx_low_old_part = - pick_array_element(idx_low_old, LowDimensionIds{}.At(itran)); + pick_container_element(idx_low_old, LowDimensionIds{}.At(itran)); - auto idx_low_diff_part = pick_array_element(idx_low_diff, LowDimensionIds{}.At(itran)); + auto idx_low_diff_part = + pick_container_element(idx_low_diff, LowDimensionIds{}.At(itran)); // this assume each lower (single) index is associated with only one transformation, // which is required for index transformation, and has been checked during constructor // of TransformedTensorDescriptor - idx_low_diff_part = tran.CalculateLowerIndexDiff( - to_array(idx_up_diff_part), to_array(idx_up_old_part), to_array(idx_low_old_part)); + idx_low_diff_part = tran.CalculateLowerIndexDiff(to_multi_index(idx_up_diff_part), + to_multi_index(idx_up_old_part), + to_multi_index(idx_low_old_part)); }); return idx_low_diff; @@ -506,12 +509,12 @@ struct TransformedTensorDescriptor constexpr auto low_dims_part = LowDimensionIds{}.At(itran); constexpr auto low_lengths_part = GetLowerTensorDescriptor().GetLengths(low_dims_part); - const auto idx_low_part = to_array(pick_array_element(idx_low, low_dims_part)); + const auto idx_low_part = + to_multi_index(pick_container_element(idx_low, low_dims_part)); - for(index_t i = 0; i < low_dims_part.Size(); ++i) - { + static_for<0, decltype(low_dims_part)::Size(), 1>{}([&](auto i) { flag = flag && idx_low_part[i] >= 0 && idx_low_part[i] < low_lengths_part[i]; - } + }); } }); diff --git a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp index b65edf5d44..bed6de6d1e 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp @@ -64,10 +64,10 @@ template __host__ __device__ constexpr auto - reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, - Sequence, - Sequence, - Sequence) +reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor, + Sequence, + Sequence, + Sequence) { return TransformedTensorDescriptor...>, @@ -78,7 +78,7 @@ __host__ __device__ constexpr auto // reorder a NativeTensorDescriptor template __host__ __device__ constexpr auto - reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLower2Upper) +reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor, MapLower2Upper) { static_assert(is_valid_sequence_map{}, "wrong! MapLower2Upper is not a valid map"); @@ -96,7 +96,7 @@ __host__ __device__ constexpr auto // reorder a TransformedTensorDescriptor template __host__ __device__ constexpr auto - reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor, MapLower2Upper) +reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor, MapLower2Upper) { static_assert(is_valid_sequence_map{}, "wrong! 
MapLower2Upper is not a valid map"); @@ -172,41 +172,5 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript return make_native_tensor_descriptor(new_lengths, new_strides); } -// a cluster map 1d index to N-d index -template -struct ClusterDescriptor -{ - static constexpr index_t nDim = Lengths::Size(); - - static constexpr auto mDesc = transform_tensor_descriptor( - make_native_tensor_descriptor_packed(Lengths{}), - make_tuple(Merge{}), - make_tuple(ArrangeOrder{}), - make_tuple(Sequence<0>{})); - - __host__ __device__ constexpr ClusterDescriptor() - { - static_assert(Lengths::Size() == nDim && ArrangeOrder::Size() == nDim, - "wrong! size not the same"); - - static_assert(is_valid_sequence_map{}, "wrong! ArrangeOrder is wrong"); - } - - __host__ __device__ static constexpr index_t GetElementSize() { return mDesc.GetElementSize(); } - - __host__ __device__ static constexpr auto CalculateClusterIndex(index_t idx_1d) - { - return mDesc.CalculateLowerIndex(MultiIndex<1>{idx_1d}); - } -}; - -template ::type> -__host__ __device__ constexpr auto make_cluster_descriptor( - Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) -{ - return ClusterDescriptor{}; -} - } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp index a64f169a6d..f5c0df4d7d 100644 --- a/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_batched_gemm.hpp @@ -210,17 +210,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 #pragma unroll for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) { - threadwise_matrix_copy( - a_block_mtx, - p_a_block + - a_block_mtx.GetOffsetFromMultiIndex(k_begin, - m_repeat * MPerLevel1Cluster) + - ib * BlockMatrixStrideA + mMyThreadOffsetA, - a_thread_mtx, - p_a_thread + - a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC), - a_thread_sub_mtx.GetLengths(), - Number{}); + threadwise_matrix_copy(a_block_mtx, + p_a_block + + a_block_mtx.GetOffsetFromMultiIndex( + k_begin, m_repeat * MPerLevel1Cluster) + + ib * BlockMatrixStrideA + mMyThreadOffsetA, + a_thread_mtx, + p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex( + 0, m_repeat * MPerThreadSubC), + a_thread_sub_mtx.GetLengths(), + Number{}); } } @@ -229,17 +228,16 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 #pragma unroll for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) { - threadwise_matrix_copy( - b_block_mtx, - p_b_block + - b_block_mtx.GetOffsetFromMultiIndex(k_begin, - n_repeat * NPerLevel1Cluster) + - ib * BlockMatrixStrideB + mMyThreadOffsetB, - b_thread_mtx, - p_b_thread + - b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC), - b_thread_sub_mtx.GetLengths(), - Number{}); + threadwise_matrix_copy(b_block_mtx, + p_b_block + + b_block_mtx.GetOffsetFromMultiIndex( + k_begin, n_repeat * NPerLevel1Cluster) + + ib * BlockMatrixStrideB + mMyThreadOffsetB, + b_thread_mtx, + p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex( + 0, n_repeat * NPerThreadSubC), + b_thread_sub_mtx.GetLengths(), + Number{}); } } @@ -307,7 +305,7 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 "Run_amd_asm can only deal with BlockMatrixStrideA == 0 && BatchPerThread == " "1 for now\n"); - using Float4 = vector_type::MemoryType; + using Float4 = vector_type::type; Float4* reg_a = 
(Float4*)(p_a_thread); Float4* reg_b = (Float4*)(p_b_thread); @@ -391,9 +389,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 { threadwise_matrix_copy( c_thread_sub_mtx, - p_c_thread + - c_thread_sub_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster, - n_repeat * NPerLevel1Cluster), + p_c_thread + c_thread_sub_mtx.GetOffsetFromMultiIndex( + m_repeat * MPerLevel1Cluster, n_repeat * NPerLevel1Cluster), c_block_mtx, p_c_block + c_block_mtx.GetOffsetFromMultiIndex(m_repeat * MPerLevel1Cluster, @@ -405,5 +402,5 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2 } }; -} // namespace +} // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp new file mode 100644 index 0000000000..5aac3f9d19 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -0,0 +1,171 @@ +#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseDynamicTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! 
BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_id = + thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); + + const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_id_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_id_begin); + } + } + + __device__ static constexpr auto CalculateThreadDataBegin() + { + const auto thread_cluster_id = + thread_cluster_desc_.CalculateClusterIndex(get_thread_local_1d_id()); + + return thread_cluster_id * ThreadSliceLengths{}; + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcData* p_src, + const SrcIteratorHacks& src_iterator_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, p_src, src_iterator_hacks); + } + } + + __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, p_dst); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowIteratorHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_iterator_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseDynamicTensorSliceTransfer_v3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp index 2e21f7141b..3ffeb3f16f 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm.hpp @@ -95,26 +95,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; } - __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c, - index_t n_in_c) - { - constexpr auto c_thread_mtx = ThreadMatrixC{}; - - constexpr index_t MPerLevel1Cluster = - MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; - constexpr index_t NPerLevel1Cluster = - NPerThreadSubC * NLevel0ThreadCluster * 
NLevel1ThreadCluster; - - index_t m_repeat = m_in_c / MPerThreadSubC; - index_t n_repeat = n_in_c / NPerThreadSubC; - - index_t m_in_sub_c = m_in_c % MPerThreadSubC; - index_t n_in_sub_c = n_in_c % NPerThreadSubC; - - return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c, - n_repeat * NPerLevel1Cluster + n_in_sub_c}; - } - template __device__ void Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const @@ -336,9 +316,14 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2 constexpr index_t MRepeat = MPerThread / MPerThreadSubC; constexpr index_t NRepeat = NPerThread / NPerThreadSubC; - static_if{}([&](auto) { + if constexpr(MRepeat == 2 && NRepeat == 2) + { Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread); - }).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); }); + } + else + { + Run_naive(p_a_block, p_b_block, p_c_thread); + } #else Run_naive(p_a_block, p_b_block, p_c_thread); #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp new file mode 100644 index 0000000000..e19bc5093e --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp @@ -0,0 +1,370 @@ +#ifndef CK_BLOCKWISE_GEMM_V2_HPP +#define CK_BLOCKWISE_GEMM_V2_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_v2.hpp" + +namespace ck { + +// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N] +// A and B are visable to the whole block, C is distributed among each thread +// If following number are power of 2, index calculation shall be greatly reduced: +// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster, +// MLevel1ThreadCluster, NLevel1ThreadCluster +template +struct BlockwiseGemm_km_kn_m0m1n0n1_v1 +{ + struct MatrixIndex + { + index_t row; + index_t col; + }; + + index_t mMyThreadOffsetA; + index_t mMyThreadOffsetB; + + __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1() + { + static_assert(BlockMatrixA::IsKnownAtCompileTime() && + BlockMatrixB::IsKnownAtCompileTime() && + ThreadMatrixC::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster * + MLevel1ThreadCluster * NLevel1ThreadCluster; + + static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n"); + + static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), + "wrong! K dimension not consistent\n"); + + constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed + constexpr index_t N = BlockMatrixB{}.GetLength(I1); + + static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 && + N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0, + "wrong! Cannot evenly divide work among\n"); + + static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] && + ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1], + "wrong! 
ThreadMatrixC lengths is wrong"); + + auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row)); + mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col)); + } + + __device__ static constexpr auto GetThreadMatrixCLengths() + { + constexpr auto I1 = Number<1>{}; + + constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed + constexpr index_t N = BlockMatrixB{}.GetLength(I1); + + constexpr index_t MRepeat = + M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster); + constexpr index_t NRepeat = + N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster); + + return Sequence{}; + } + + __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) + { + constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster; + + index_t level1_id = thread_id / ThreadPerLevel0Cluster; + index_t level1_m_id = level1_id / NLevel1ThreadCluster; + index_t level1_n_id = level1_id % NLevel1ThreadCluster; + + index_t level0_id = thread_id % ThreadPerLevel0Cluster; + index_t level0_m_id = level0_id / NLevel0ThreadCluster; + index_t level0_n_id = level0_id % NLevel0ThreadCluster; + + constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster; + constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster; + + return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC, + level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC}; + } + + template + __device__ void + Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr auto K = a_block_mtx.GetLength(I0); + + constexpr auto MPerThread = c_thread_mtx.GetLength(I0); + constexpr auto NPerThread = c_thread_mtx.GetLength(I1); + + constexpr index_t MPerLevel1Cluster = + MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; + constexpr index_t NPerLevel1Cluster = + NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; + + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + + // thread A, B for GEMM + constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( + Number{}, Number{}); + + constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( + Number{}, Number{}); + + FloatA p_a_thread[a_thread_mtx.GetElementSpace()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpace()]; + + constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; + + constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; + + constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1{}; +#pragma unroll + // loop over k + for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop) + { +#pragma unroll + // read A + for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat) + { + a_thread_copy.Run(p_a_block + + a_block_mtx.CalculateOffset( + make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) + + mMyThreadOffsetA, + p_a_thread + a_thread_mtx.CalculateOffset( + make_tuple(0, m_repeat * MPerThreadSubC))); + } + +#pragma unroll + // read B + for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat) + { + b_thread_copy.Run(p_b_block + + 
b_block_mtx.CalculateOffset( + make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) + + mMyThreadOffsetB, + p_b_thread + b_thread_mtx.CalculateOffset( + make_tuple(0, n_repeat * NPerThreadSubC))); + } + + // C += A * B + threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); + } + } + + template + __device__ void + Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + constexpr auto b_block_mtx = BlockMatrixB{}; + constexpr auto c_thread_mtx = ThreadMatrixC{}; + + constexpr auto K = a_block_mtx.GetLength(I0); + + constexpr auto MPerThread = c_thread_mtx.GetLength(I0); + constexpr auto NPerThread = c_thread_mtx.GetLength(I1); + + constexpr index_t MPerLevel1Cluster = + MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster; + + constexpr index_t NPerLevel1Cluster = + NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster; + + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + + static_assert(MRepeat == 2 && NRepeat == 2, + "wrong! inline asm cannot deal with this GEMM config yet"); + + // thread A, B + constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + // thread A-sub, B-sub + constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( + make_tuple(Number{}, Number{}), + make_tuple(Number{}, Number<1>{})); + + constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( + make_tuple(Number{}, Number{}), + make_tuple(Number{}, Number<1>{})); + + constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2( + make_tuple(Number{}, Number{}), + make_tuple(Number{}, Number<1>{})); + + FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; + FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()]; + + constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; + + constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2{}; + + constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1{}; + + const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA; + const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB; + + // read A_sub_0 + a_thread_copy.Run(p_a_block_off, p_a_thread); + + // read B_sub_0 + b_thread_copy.Run(p_b_block_off, p_b_thread); + + // read B_sub_1 + b_thread_copy.Run(p_b_block_off + + b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)), + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + + // read A_sub_1 + a_thread_copy.Run(p_a_block_off + + a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)), + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC))); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run( + p_a_thread, + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + +#pragma unroll + // loop over rest of k + for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop) + { + // read A_sub_0 + a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), + p_a_thread); + + // C_sub_10 += transpose(A_sub_1) * 
B_sub_0 + threadwise_gemm.Run( + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), + p_b_thread, + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0))); + + // read B_sub_0 + b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)), + p_b_thread); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run( + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), + p_c_thread + + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC))); + + // read B_sub_1 + b_thread_copy.Run( + p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)), + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + + // read A_sub_1 + a_thread_copy.Run( + p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)), + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC))); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run( + p_a_thread, + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC))); + } + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run( + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), + p_b_thread, + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0))); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run( + p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)), + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)), + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC))); + } + + template + __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const + { +#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0); + constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1); + + constexpr index_t MRepeat = MPerThread / MPerThreadSubC; + constexpr index_t NRepeat = NPerThread / NPerThreadSubC; + + if constexpr(MRepeat == 2 && NRepeat == 2) + { + Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread); + } + else + { + Run_naive(p_a_block, p_b_block, p_c_thread); + } +#else + Run_naive(p_a_block, p_b_block, p_c_thread); +#endif + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp new file mode 100644 index 0000000000..76f50bc811 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp @@ -0,0 +1,198 @@ +#ifndef CK_BLOCKWISE_GEMM_V3_HPP +#define CK_BLOCKWISE_GEMM_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_v3.hpp" + +namespace ck { + +// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N] +// A and B are visable to the whole block, C is distributed among each thread +// If following number are power of 2, index calculation shall be greatly reduced: +// KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster, +// MLevel1ThreadCluster, NLevel1ThreadCluster +template +struct BlockwiseGemm_km_kn_m0m1n0n1_v3 +{ + struct 
MatrixIndex + { + index_t k; + index_t h; + index_t w; + }; + + index_t mMyThreadOffsetA; + + __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3() + { + static_assert(BlockMatrixA::IsKnownAtCompileTime() && + BlockMatrixB::IsKnownAtCompileTime() && + ThreadMatrixC::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0), + "wrong! K dimension not consistent\n"); + + constexpr index_t K = BlockMatrixA{}.GetLength(I1); // A is transposed + constexpr index_t N = BlockMatrixB{}.GetLength(I1); + constexpr index_t H = BlockMatrixB{}.GetLength(I2); + constexpr index_t W = BlockMatrixB{}.GetLength(I3); + + static_assert(K % KPerThread == 0 && H % HPerThread == 0 && W % WPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = K / KPerThread; + constexpr auto HThreadCluster = H / HPerThread; + constexpr auto WThreadCluster = W / WPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! wrong blocksize\n"); + + auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + mMyThreadOffsetA = + BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread)); + } + + __device__ static constexpr auto GetThreadMatrixCLengths() + { + return Sequence{}; + } + + __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id) + { + constexpr index_t H = BlockMatrixB{}.GetLength(Number<2>{}); + constexpr index_t W = BlockMatrixB{}.GetLength(Number<3>{}); + + constexpr auto num_w_threads = W / WPerThread; + constexpr auto num_h_threads = H / HPerThread; + constexpr auto num_hw_threads = num_w_threads * num_h_threads; + + index_t k_thread_id = thread_id / num_hw_threads; + index_t hw_thread_id = thread_id % num_hw_threads; + + index_t h_thread_id = hw_thread_id / num_w_threads; + index_t w_thread_id = hw_thread_id % num_w_threads; + + return MatrixIndex{k_thread_id, h_thread_id, w_thread_id}; + } + + template + struct ThreadwiseSliceCopy_a + { + template + __device__ static void Run(const Data* p_src, Data* p_dst) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + using vector_t = typename vector_type::type; + + static_for<0, NSliceRow, 1>{}([&](auto i) { + static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { + constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j)); + constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j)); + + *reinterpret_cast(&p_dst[dst_offset]) = + *reinterpret_cast(&p_src[src_offset]); + }); + }); + } + }; + + template + __device__ void + Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const + { + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto a_block_mtx = BlockMatrixA{}; + + constexpr auto EPerBlock = a_block_mtx.GetLength(I0); + + constexpr auto KPerThreadSubC = 4; + + static_assert(KPerThread % KPerThreadSubC == 0, ""); + static_assert(HPerThread % 2 == 0, ""); + static_assert(WPerThread % 2 == 0, ""); + + // thread A, B for GEMM + constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()]; + + constexpr auto a_thread_copy = ThreadwiseSliceCopy_a{}; + + constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3{}; + // loop over k +#pragma unroll + for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop) + { +#pragma unroll + for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC) + { + a_thread_copy.Run(p_a_block + + a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) + + mMyThreadOffsetA, + p_a_thread); + + for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2) + { + + for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2) + { + threadwise_gemm.Run(p_a_thread, + p_b_thread + b_thread_mtx.CalculateOffset(make_tuple( + e_begin, 0, h_begin, w_begin)), + p_c_thread + c_thread_mtx.CalculateOffset(make_tuple( + k_begin, 0, h_begin, w_begin))); + } + } + } + } + } + + template + __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const + { + Run_naive(p_a_block, p_b_thread, p_c_thread); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index a63ebd27bc..d67101a935 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -5,6 +5,7 @@ #include "tensor_descriptor.hpp" #include "tensor_descriptor_helper.hpp" #include "tensor_coordinate.hpp" +#include "cluster_descriptor.hpp" #include "threadwise_generic_tensor_slice_copy.hpp" namespace ck { @@ -68,9 +69,9 @@ struct BlockwiseGenericTensorSliceCopy_v4 const auto thread_data_id_begin = thread_cluster_id * ThreadSliceLengths{}; mThreadwiseLoad.SetSrcSliceOrigin(src_block_slice_origin + thread_data_id_begin); - mThreadwiseLoad.SetDstSliceOrigin(make_zero_array()); + mThreadwiseLoad.SetDstSliceOrigin(make_zero_multi_index()); - mThreadwiseStore.SetSrcSliceOrigin(make_zero_array()); + 
mThreadwiseStore.SetSrcSliceOrigin(make_zero_multi_index()); mThreadwiseStore.SetDstSliceOrigin(dst_block_slice_origin + thread_data_id_begin); } } diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp new file mode 100644 index 0000000000..b0674debfa --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp @@ -0,0 +1,509 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "blockwise_gemm_v2.hpp" + +namespace ck { + +template +struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + template + __device__ void Run(const AGlobalDesc& a_k_m_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_k_n_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + integral_constant, + integral_constant) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const auto K = a_k_m_global_desc.GetLength(I0); + const auto M = a_k_m_global_desc.GetLength(I1); + const auto N = b_k_n_global_desc.GetLength(I1); + + // divide block work by [M, N] +#if 0 + const auto m_block_work_num = M / Number{}; + const auto n_block_work_num = N / Number{}; + + const index_t m_block_work_id = get_block_1d_id() / n_block_work_num; + const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; + +#else + // Hack: this force result into SGPR + const index_t m_block_work_num = __builtin_amdgcn_readfirstlane(M / MPerBlock); + const index_t n_block_work_num = __builtin_amdgcn_readfirstlane(N / NPerBlock); + + const index_t m_block_work_id = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / n_block_work_num); + const index_t n_block_work_id = get_block_1d_id() - m_block_work_id * n_block_work_num; +#endif + + const index_t m_block_data_on_global = m_block_work_id * MPerBlock; + const index_t n_block_data_on_global = n_block_work_id * NPerBlock; + + // lds max alignment + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise 
copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K_M, + ABlockTransferThreadClusterLengths_K_M, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m_global_desc), + decltype(a_k_m_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M, + AddressSpace::Global, + AddressSpace::Lds, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_k_m_global_desc, + make_multi_index(0, m_block_data_on_global), + a_k_m_block_desc, + make_multi_index(0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N, + BBlockTransferThreadClusterLengths_K_N, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n_global_desc), + decltype(b_k_n_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1>, + BBlockTransferSrcVectorDim, + 1, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_N, + AddressSpace::Global, + AddressSpace::Lds, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k_n_global_desc, + make_multi_index(0, n_block_data_on_global), + b_k_n_block_desc, + make_multi_index(0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && + NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, + "wrong!"); + + constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); + constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); + + // c_thread_mtx definition: this is a mess + // TODO:: more elegent way of defining c_thread_mtx + constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + const auto blockwise_gemm = + BlockwiseGemm_km_kn_m0m1n0n1_v1{}; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size; + + // register allocation for output + FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()]; + + // zero out threadwise output + threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0); + + // hack to control index calculation when iterating over A and B matrix for 
threadwise copy + constexpr auto a_k_m_global_iterator_hacks = AGlobalIteratorHacks{}; + constexpr auto b_k_n_global_iterator_hacks = BGlobalIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k_m_global_move_slice_window_iterator_hack = + AGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto b_k_n_global_move_slice_window_iterator_hack = + BGlobalMoveSliceWindowIteratorHacks{}; + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); + b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double); + b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double); + } + + if constexpr(HasMainKBlockLoop) + { + FloatAB* p_a_block_even = p_a_block_double; + FloatAB* p_b_block_even = p_b_block_double; + + FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size; + FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size; + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, + a_block_slice_copy_step, + a_k_m_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, + b_block_slice_copy_step, + b_k_n_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd); + b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_odd); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, + a_block_slice_copy_step, + a_k_m_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, + b_block_slice_copy_step, + b_k_n_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even); + b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_even); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k_m_global_desc, + a_block_slice_copy_step, + a_k_m_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n_global_desc, + b_block_slice_copy_step, + b_k_n_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks); + 
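+            // note: RunRead only stages this last K slice from global memory into each
+            // thread's register buffer; the RunWrite calls further down commit it to the
+            // free half of the LDS double buffer only after the GEMM on the 2nd-last slice
+            // has consumed the other half.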
b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size); + b_blockwise_copy.RunWrite(b_k_n_block_desc, p_b_block_double + b_block_space_size); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(p_a_block_double + a_block_space_size, + p_b_block_double + b_block_space_size, + p_c_thread); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); + } + + // output: register to global memory + { + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + // define input tensor descriptor for threadwise copy + // thread input tensor, src of threadwise copy + constexpr auto c_m0_m1_n0_n1_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + + // calculate origin of thread input tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const index_t m_thread_data_on_global = + m_block_data_on_global + c_thread_mtx_on_block.row; + + const index_t n_thread_data_on_global = + n_block_data_on_global + c_thread_mtx_on_block.col; + + // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor + constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; + + constexpr auto tmp = make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m1_n0_n1_thread_desc), + decltype(c_m0_m1_n0_n1_global_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AddressSpace::Vgpr, + AddressSpace::Global, + CGlobalMemoryDataOperation, + 1, + true>(c_m0_m1_n0_n1_global_desc, + make_multi_index(m_thread_data_on_global / M1, + m_thread_data_on_global % M1, + n_thread_data_on_global / N1, + n_thread_data_on_global % N1)) + .Run(c_m0_m1_n0_n1_thread_desc, + make_tuple(I0, I0, I0, I0), + p_c_thread, + c_m0_m1_n0_n1_global_desc, + p_c_global, + c_m0_m1_n0_n1_global_tensor_iterator_hacks); + } + } + + // pass tensor descriptor by reference + template + __device__ void Run(const AGlobalDesc& a_k_m_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_k_n_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + p_shared_block, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by pointers + template + __device__ void Run(const AGlobalDesc* p_a_k_m_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc* p_b_k_n_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc* p_c_m0_m1_n0_n1_global_desc, + 
FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_k_m_global_desc = *p_a_k_m_global_desc; + const auto b_k_n_global_desc = *p_b_k_n_global_desc; + const auto c_m0_m1_n0_n1_global_desc = *p_c_m0_m1_n0_n1_global_desc; + + Run(a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by void* + template + __device__ void Run(const void* p_a_k_m_global_desc, + const FloatAB* __restrict__ p_a_global, + const void* p_b_k_n_global_desc, + const FloatAB* __restrict__ p_b_global, + const void* p_c_m0_m1_n0_n1_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_k_m_global_desc = *reinterpret_cast(p_a_k_m_global_desc); + const auto b_k_n_global_desc = *reinterpret_cast(p_b_k_n_global_desc); + const auto c_m0_m1_n0_n1_global_desc = + *reinterpret_cast(p_c_m0_m1_n0_n1_global_desc); + + Run(a_k_m_global_desc, + p_a_global, + b_k_n_global_desc, + p_b_global, + c_m0_m1_n0_n1_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp new file mode 100644 index 0000000000..81a3a0674f --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp @@ -0,0 +1,471 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_V2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "blockwise_gemm_v3.hpp" + +namespace ck { + +template +struct GridwiseDynamicGemm_km_kn_mn_v3 +{ + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto E = EPerBlock * 3 * 3; + + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_e_k_desc.GetElementSpaceSize(), max_lds_align); + + return a_block_space_size * sizeof(FloatAB); + } + + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + FloatAB* __restrict__ p_shared_block, + integral_constant, + integral_constant) const + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto E = EPerBlock * 3 * 3; + + // const auto E = a_e_k_global_desc.GetLength(I0); + const auto K = a_e_k_global_desc.GetLength(I1); + + const auto N = b_e_n_ho_wo_global_desc.GetLength(I1); + const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2); + const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3); + + // divide block work by [M, N] +#if 0 + const auto 
k_block_work_num = K / Number{}; + const auto ho_block_work_num = Ho / Number{}; + const auto wo_block_work_num = Wo / Number{}; + const auto hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = get_block_1d_id() / hwo_block_work_num; + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = hwo_block_work_id / wo_block_work_num; + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#else + // Hack: this forces the result into an SGPR + const index_t k_block_work_num = __builtin_amdgcn_readfirstlane(K / KPerBlock); + const index_t ho_block_work_num = __builtin_amdgcn_readfirstlane(Ho / HoPerBlock); + const index_t wo_block_work_num = __builtin_amdgcn_readfirstlane(Wo / WoPerBlock); + const index_t hwo_block_work_num = ho_block_work_num * wo_block_work_num; + + const index_t k_block_work_id = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / hwo_block_work_num); + const index_t hwo_block_work_id = get_block_1d_id() - k_block_work_id * hwo_block_work_num; + + const index_t ho_block_work_id = + __builtin_amdgcn_readfirstlane(hwo_block_work_id / wo_block_work_num); + const index_t wo_block_work_id = hwo_block_work_id - ho_block_work_id * wo_block_work_num; +#endif + + // lds max alignment + constexpr auto max_lds_align = + math::lcm(Number{}, Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_e_k_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + constexpr auto a_e_k_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_e_n_ho_wo_block_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + // c_thread_mtx definition: this is a mess + // TODO: more elegant way of defining c_thread_mtx + constexpr auto c_k_n_ho_wo_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + const auto blockwise_gemm = + BlockwiseGemm_km_kn_m0m1n0n1_v3{}; + + auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); + + const auto k_thread_id = c_thread_mtx_index.k; + const auto ho_thread_id = c_thread_mtx_index.h; + const auto wo_thread_id = c_thread_mtx_index.w; + + const index_t k_block_data_on_global = k_block_work_id * KPerBlock; + const index_t ho_block_data_on_global = ho_block_work_id * HoPerBlock; + const index_t wo_block_data_on_global = wo_block_work_id * WoPerBlock; + + const index_t ho_thread_data_on_global = + ho_block_data_on_global + ho_thread_id * HoPerThread; + const index_t wo_thread_data_on_global = + wo_block_data_on_global + wo_thread_id * WoPerThread; + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_e_k_global_desc), + decltype(a_e_k_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1>, + ABlockTransferSrcVectorDim, + 1, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K, + AddressSpace::Global, + AddressSpace::Lds, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun,
+ true>( + a_e_k_global_desc, + make_multi_index(0, k_block_data_on_global), + a_e_k_desc, + make_multi_index(0, 0)); + + constexpr auto b_e_n_ho_wo_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + auto b_threadwise_transfer = ThreadwiseDynamicTensorSliceTransfer_v2< + FloatAB, + FloatAB, + decltype(b_e_n_ho_wo_global_desc), + decltype(b_e_n_ho_wo_thread_desc), + Sequence, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + AddressSpace::Global, + AddressSpace::Vgpr, + InMemoryDataOperation::Set, + 1, + true>(b_e_n_ho_wo_global_desc, + make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global)); + + FloatAB* p_a_block = p_shared_block; + + // register allocation for output + FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()]; + + // zero out threadwise output + threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread); + + constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_e_k_global_iterator_hacks = AGlobalIteratorHacks{}; + constexpr auto b_e_n_ho_wo_global_iterator_hacks = BGlobalIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_e_k_global_move_slice_window_iterator_hack = + AGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack = + BGlobalMoveSliceWindowIteratorHacks{}; + + constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize(); + FloatAB p_b_thread[b_thread_space_size * 2]; + + FloatAB* p_b_thread_double = p_b_thread; + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + p_b_global, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + p_b_thread_double, + b_e_n_ho_wo_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block); + } + + __syncthreads(); + + index_t b_block_data_begin = 0; + +#if 1 + if constexpr(HasMainKBlockLoop) + { + FloatAB* p_b_thread_even = p_b_thread_double; + FloatAB* p_b_thread_odd = p_b_thread_double + b_thread_space_size; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + p_b_global, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + p_b_thread_odd, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), + p_b_thread_even, + p_c_thread); + + b_block_data_begin += EPerBlock; + + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + p_b_global, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + p_b_thread_even, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), + p_b_thread_odd, + p_c_thread); + + 
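+ // at this point the register double buffers have swapped roles again: p_b_thread_even + // holds the tile just fetched during the odd iteration and p_b_thread_odd has been + // consumed, so after advancing the E position below the loop either starts another + // even/odd pair or falls through to the tail +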
b_block_data_begin += EPerBlock; + + } while(b_block_data_begin < E - 2 * EPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc, + b_thread_slice_copy_step); + + b_threadwise_transfer.Run(b_e_n_ho_wo_global_desc, + p_b_global, + b_e_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + p_b_thread_double + b_thread_space_size, + b_e_n_ho_wo_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), + p_b_thread_double, + p_c_thread); + + b_block_data_begin += EPerBlock; + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), + p_b_thread_double + b_thread_space_size, + p_c_thread); + } + else // if has 1 iteration left + { + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)), + p_b_thread_double, + p_c_thread); + } +#endif + +#if 1 + // output: register to global memory + { + // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor + constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; + + const index_t k_thread_data_on_global = + k_block_data_on_global + k_thread_id * KPerThread; + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_k_n_ho_wo_thread_desc), + decltype(c_k_n_ho_wo_global_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AddressSpace::Vgpr, + AddressSpace::Global, + CGlobalMemoryDataOperation, + 1, + true>( + c_k_n_ho_wo_global_desc, + make_multi_index( + k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global)) + .Run(c_k_n_ho_wo_thread_desc, + make_tuple(I0, I0, I0, I0), + p_c_thread, + c_k_n_ho_wo_global_desc, + p_c_global, + c_k_n_ho_wo_global_tensor_iterator_hacks); + } +#endif + } + + // pass tensor descriptor by reference + template + __device__ void Run(const AGlobalDesc& a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc& b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc& c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + p_shared_block, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by their pointers + template + __device__ void Run(const AGlobalDesc* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const BGlobalDesc* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const CGlobalDesc* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *p_a_e_k_global_desc; + const auto b_e_n_ho_wo_global_desc = *p_b_e_n_ho_wo_global_desc; + const auto c_k_n_ho_wo_global_desc = *p_c_k_n_ho_wo_global_desc; + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + 
p_c_global, + integral_constant{}, + integral_constant{}); + } + + // pass tensor descriptors by void* + template + __device__ void Run(const void* p_a_e_k_global_desc, + const FloatAB* __restrict__ p_a_global, + const void* p_b_e_n_ho_wo_global_desc, + const FloatAB* __restrict__ p_b_global, + const void* p_c_k_n_ho_wo_global_desc, + FloatC* __restrict__ p_c_global, + integral_constant, + integral_constant) const + { + const auto a_e_k_global_desc = *reinterpret_cast(p_a_e_k_global_desc); + const auto b_e_n_ho_wo_global_desc = + *reinterpret_cast(p_b_e_n_ho_wo_global_desc); + const auto c_k_n_ho_wo_global_desc = + *reinterpret_cast(p_c_k_n_ho_wo_global_desc); + + Run(a_e_k_global_desc, + p_a_global, + b_e_n_ho_wo_global_desc, + p_b_global, + c_k_n_ho_wo_global_desc, + p_c_global, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp index fbf2bfe911..7f9936bcd9 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm.hpp @@ -68,13 +68,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 Sequence{}, Number{}); // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = + constexpr index_t a_block_space_size = math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - constexpr index_t b_block_space = + constexpr index_t b_block_space_size = math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - return 2 * (a_block_space + b_block_space) * sizeof(Float); + return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); } __device__ void Run(const Float* __restrict__ p_a_global, @@ -116,8 +116,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - const index_t m_block_data_on_global = block_work_id[0] * MPerBlock; - const index_t n_block_data_on_global = block_work_id[1] * NPerBlock; + const index_t m_block_data_on_global = block_work_id[Number<0>{}] * MPerBlock; + const index_t n_block_data_on_global = block_work_id[Number<1>{}] * NPerBlock; // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -143,7 +143,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, m_block_data_on_global}, {0, 0}); + make_multi_index(0, m_block_data_on_global), make_multi_index(0, 0)); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -169,7 +169,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, n_block_data_on_global}, {0, 0}); + make_multi_index(0, n_block_data_on_global), make_multi_index(0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -209,14 +209,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 ThreadGemmBThreadCopySrcDataPerRead_N>{}; // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = + constexpr index_t a_block_space_size = math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - constexpr index_t b_block_space = + constexpr index_t b_block_space_size = math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); Float* p_a_block_double = p_shared_block; - Float* 
p_b_block_double = p_shared_block + 2 * a_block_space; + Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; // register allocation for output AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; @@ -230,47 +230,55 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 b_blockwise_copy.Run(p_b_global, p_b_block_double); } - constexpr auto a_block_slice_copy_steps = Sequence{}; - constexpr auto b_block_slice_copy_steps = Sequence{}; + constexpr auto a_block_slice_copy_step = Sequence{}; + constexpr auto b_block_slice_copy_step = Sequence{}; + + Float* p_a_block_even = p_a_block_double; + Float* p_b_block_even = p_b_block_double; + + Float* p_a_block_odd = p_a_block_double + a_block_space_size; + Float* p_b_block_odd = p_b_block_double + b_block_space_size; // LDS double buffer: main body - for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; + for(index_t k_block_data_begin = 0; k_block_data_begin < K - 2 * KPerBlock; k_block_data_begin += 2 * KPerBlock) { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); + Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; + Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - Float* p_a_block_now = - even_loop ? p_a_block_double : p_a_block_double + a_block_space; - Float* p_b_block_now = - even_loop ? p_b_block_double : p_b_block_double + b_block_space; + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); - Float* p_a_block_next = - even_loop ? p_a_block_double + a_block_space : p_a_block_double; - Float* p_b_block_next = - even_loop ? p_b_block_double + b_block_space : p_b_block_double; + __syncthreads(); - Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; - Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); + b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread); - __syncthreads(); + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_odd); + b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_odd); - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); - b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread); + __syncthreads(); - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next); - } + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); + b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); + + // LDS double buffer: GEMM on current data + 
blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_even); + b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_even); } // LDS double buffer: tail @@ -282,8 +290,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); __syncthreads(); @@ -296,15 +304,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 // LDS double buffer: store last data to LDS a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, - p_a_block_double + a_block_space); + p_a_block_double + a_block_space_size); b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, - p_b_block_double + b_block_space); + p_b_block_double + b_block_space_size); __syncthreads(); // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread); + blockwise_gemm.Run(p_a_block_double + a_block_space_size, + p_b_block_double + b_block_space_size, + p_c_thread); } else // if has 1 iteration left { @@ -355,11 +364,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1 AddressSpace::Vgpr, AddressSpace::Global, CGlobalMemoryDataOperation>( - {0, 0, 0, 0}, - {m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1}) + make_multi_index(0, 0, 0, 0), + make_multi_index(m_thread_data_on_global / M1, + m_thread_data_on_global % M1, + n_thread_data_on_global / N1, + n_thread_data_on_global % N1)) .Run(p_c_thread, p_c_global); } } @@ -433,13 +442,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 Sequence{}, Number{}); // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = + constexpr index_t a_block_space_size = math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - constexpr index_t b_block_space = + constexpr index_t b_block_space_size = math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - return 2 * (a_block_space + b_block_space) * sizeof(Float); + return 2 * (a_block_space_size + b_block_space_size) * sizeof(Float); } __device__ void Run(const Float* __restrict__ p_a_global, @@ -447,21 +456,23 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 Float* __restrict__ p_c_global, Float* __restrict__ p_shared_block) const { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto True = integral_constant{}; constexpr auto False = integral_constant{}; - constexpr auto I0 = Number<0>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto a_k0_k1_k2_m_global_desc = AGlobalDesc{}; constexpr auto b_k0_k1_k2_n_global_desc = BGlobalDesc{}; constexpr auto c_m_n_global_desc = CGlobalDesc{}; - constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[0]; - constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[1]; - constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[2]; - constexpr auto M = c_m_n_global_desc.GetLengths()[0]; - constexpr auto N = 
c_m_n_global_desc.GetLengths()[1]; + constexpr auto K0 = a_k0_k1_k2_m_global_desc.GetLengths()[I0]; + constexpr auto K1 = a_k0_k1_k2_m_global_desc.GetLengths()[I1]; + constexpr auto K = a_k0_k1_k2_m_global_desc.GetLengths()[I2]; + constexpr auto M = c_m_n_global_desc.GetLengths()[I0]; + constexpr auto N = c_m_n_global_desc.GetLengths()[I1]; // don't do anything if K == 0 if(K == 0) @@ -487,8 +498,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - const index_t m_block_data_on_global = block_work_id[0] * MPerBlock; - const index_t n_block_data_on_global = block_work_id[1] * NPerBlock; + const index_t m_block_data_on_global = block_work_id[I0] * MPerBlock; + const index_t n_block_data_on_global = block_work_id[I1] * NPerBlock; // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -514,7 +525,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, 0, 0, m_block_data_on_global}, {0, 0, 0, 0}); + make_multi_index(0, 0, 0, m_block_data_on_global), make_multi_index(0, 0, 0, 0)); // B matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -540,7 +551,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 AddressSpace::Vgpr, AddressSpace::Lds, InMemoryDataOperation::Set>( - {0, 0, 0, n_block_data_on_global}, {0, 0, 0, 0}); + make_multi_index(0, 0, 0, n_block_data_on_global), make_multi_index(0, 0, 0, 0)); // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx @@ -582,14 +593,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 ThreadGemmBThreadCopySrcDataPerRead_N>{}; // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = + constexpr index_t a_block_space_size = math::integer_least_multiple(a_k0_k1_k2_m_block_desc.GetElementSpace(), max_lds_align); - constexpr index_t b_block_space = + constexpr index_t b_block_space_size = math::integer_least_multiple(b_k0_k1_k2_n_block_desc.GetElementSpace(), max_lds_align); Float* p_a_block_double = p_shared_block; - Float* p_b_block_double = p_shared_block + 2 * a_block_space; + Float* p_b_block_double = p_shared_block + 2 * a_block_space_size; // register allocation for output AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; @@ -601,15 +612,14 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 { for(index_t k1 = 0; k1 < K1; ++k1) { - // LDS double buffer: preload data into LDS { a_blockwise_copy.Run(p_a_global, p_a_block_double); b_blockwise_copy.Run(p_b_global, p_b_block_double); } - constexpr auto a_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{}; - constexpr auto b_block_slice_copy_steps = Sequence<0, 0, KPerBlock, 0>{}; + constexpr auto a_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{}; + constexpr auto b_block_slice_copy_step = Sequence<0, 0, KPerBlock, 0>{}; // LDS double buffer: main body for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; @@ -621,20 +631,20 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 const bool even_loop = (iloop % 2 == 0); Float* p_a_block_now = - even_loop ? p_a_block_double : p_a_block_double + a_block_space; + even_loop ? p_a_block_double : p_a_block_double + a_block_space_size; Float* p_b_block_now = - even_loop ? p_b_block_double : p_b_block_double + b_block_space; + even_loop ? p_b_block_double : p_b_block_double + b_block_space_size; Float* p_a_block_next = - even_loop ? 
p_a_block_double + a_block_space : p_a_block_double; + even_loop ? p_a_block_double + a_block_space_size : p_a_block_double; Float* p_b_block_next = - even_loop ? p_b_block_double + b_block_space : p_b_block_double; + even_loop ? p_b_block_double + b_block_space_size : p_b_block_double; Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); __syncthreads(); @@ -660,8 +670,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); + a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_step, True); + b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_step, True); __syncthreads(); @@ -673,16 +683,16 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); // LDS double buffer: store last data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, - p_a_block_double + a_block_space); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, - p_b_block_double + b_block_space); + a_blockwise_copy.RunStoreThreadBuffer( + p_a_thread_buffer, p_a_block_double + a_block_space_size); + b_blockwise_copy.RunStoreThreadBuffer( + p_b_thread_buffer, p_b_block_double + b_block_space_size); __syncthreads(); // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double + a_block_space, - p_b_block_double + b_block_space, + blockwise_gemm.Run(p_a_block_double + a_block_space_size, + p_b_block_double + b_block_space_size, p_c_thread); } else // if has 1 iteration left @@ -750,11 +760,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v2 AddressSpace::Vgpr, AddressSpace::Global, CGlobalMemoryDataOperation>( - {0, 0, 0, 0}, - {m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1}) + make_multi_index(0, 0, 0, 0), + make_multi_index(m_thread_data_on_global / M1, + m_thread_data_on_global % M1, + n_thread_data_on_global / N1, + n_thread_data_on_global % N1)) .Run(p_c_thread, p_c_global); } } diff --git a/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp b/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp deleted file mode 100644 index 3a3960863f..0000000000 --- a/composable_kernel/include/tensor_operation/gridwise_tensor_contraction.hpp +++ /dev/null @@ -1,330 +0,0 @@ -#ifndef CK_GRIDWISE_TENSOR_CONTRACTION_HPP -#define CK_GRIDWISE_TENSOR_CONTRACTION_HPP - -#include "common_header.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "ConstantMatrixDescriptor.hpp" -#include "blockwise_generic_tensor_slice_copy.hpp" -#include "threadwise_generic_tensor_slice_copy.hpp" -#include "blockwise_gemm.hpp" - -namespace ck { - -template -struct GridwiseTensorContraction_v1 -{ - __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() {} - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* 
__restrict__ p_c_global, - Float* __restrict__ p_shared_block) const - { - /// \todo sanity-check on AGlobalDesc, BGlboalDesc, CGlobalDesc length consisitency - /// \todo santiy-check on CBlockLengtsh - - constexpr auto True = integral_constant{}; - - constexpr auto a_global_desc = AGlobalDesc{}; - constexpr auto b_global_desc = BGlobalDesc{}; - constexpr auto c_global_desc = CGlobalDesc{}; - - constexpr auto K = a_global_desc.GetLengths()[0]; - - // don't do anything if K == 0 - if(K == 0) - { - return; - } - - // lds max alignment - constexpr index_t max_lds_align = math::lcm(ABlockCopyDstDataPerWrite_M, - BBlockCopyDstDataPerWrite_N, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N); - - // divide block work by [M, N] - static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0, - "wrong! cannot divide work evenly among block"); - - constexpr index_t MBlockWork = M / MPerBlock; - constexpr index_t NBlockWork = N / NPerBlock; - - constexpr auto block_work_desc = - make_cluster_descriptor(Sequence{}); - - const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id()); - - const index_t m_block_data_on_global = block_work_id[0] * MPerBlock; - const index_t n_block_data_on_global = block_work_id[1] * NPerBlock; - - // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k_m_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // A matrix blockwise copy - auto a_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - ABlockCopySrcVectorReadDim, - 1, - ABlockCopySrcDataPerRead, - ABlockCopyDstDataPerWrite_M, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - {0, m_block_data_on_global}, {0, 0}); - - // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k_n_block_desc = make_native_tensor_descriptor_aligned( - Sequence{}, Number{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - BlockwiseGenericTensorSliceCopy_v4, - BBlockCopySrcVectorReadDim, - 1, - BBlockCopySrcDataPerRead, - BBlockCopyDstDataPerWrite_N, - AddressSpace::Global, - AddressSpace::Vgpr, - AddressSpace::Lds, - InMemoryDataOperation::Set>( - {0, n_block_data_on_global}, {0, 0}); - - // GEMM definition - // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlocl, NPerBlock] is in LDS - // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in - // register - constexpr auto a_k_m_block_mtx_desc = make_ConstantMatrixDescriptor(a_k_m_block_desc); - constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(b_k_n_block_desc); - - // sanity check - static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && - NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, - "wrong!"); - - constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); - - constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); - - // c_thread_mtx definition: this is a mess - // TODO:: more elegent way of defining c_thread_mtx - constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed( - Number{}, Number{}); - - const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2< - BlockSize, - decltype(a_k_m_block_mtx_desc), - decltype(b_k_n_block_mtx_desc), - decltype(c_m0m1_n0n1_thread_mtx_desc), - MPerThread, - 
NPerThread, - MLevel0Cluster, - NLevel0Cluster, - MLevel1Cluster, - NLevel1Cluster, - KPerThread, - ThreadGemmAThreadCopySrcDataPerRead_M, - ThreadGemmBThreadCopySrcDataPerRead_N>{}; - - // LDS allocation for A and B: be careful of alignment - constexpr index_t a_block_space = - math::integer_least_multiple(a_k_m_block_desc.GetElementSpace(), max_lds_align); - - constexpr index_t b_block_space = - math::integer_least_multiple(b_k_n_block_desc.GetElementSpace(), max_lds_align); - - Float* p_a_block_double = p_shared_block; - Float* p_b_block_double = p_shared_block + 2 * a_block_space; - - // register allocation for output - AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()]; - - // zero out threadwise output - threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread); - - // LDS double buffer: preload data into LDS - { - a_blockwise_copy.Run(p_a_global, p_a_block_double); - b_blockwise_copy.Run(p_b_global, p_b_block_double); - } - - constexpr auto a_block_slice_copy_steps = Sequence{}; - constexpr auto b_block_slice_copy_steps = Sequence{}; - - // LDS double buffer: main body - for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K; - k_block_data_begin += 2 * KPerBlock) - { -#pragma unroll - for(index_t iloop = 0; iloop < 2; ++iloop) - { - const bool even_loop = (iloop % 2 == 0); - - Float* p_a_block_now = - even_loop ? p_a_block_double : p_a_block_double + a_block_space; - Float* p_b_block_now = - even_loop ? p_b_block_double : p_b_block_double + b_block_space; - - Float* p_a_block_next = - even_loop ? p_a_block_double + a_block_space : p_a_block_double; - Float* p_b_block_next = - even_loop ? p_b_block_double + b_block_space : p_b_block_double; - - Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; - Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); - - __syncthreads(); - - // LDS doubel buffer: load next data from device mem - a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); - b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - - // LDS double buffer: GEMM on current data - blockwise_gemm.Run(p_a_block_now, p_b_block_now, p_c_thread); - - // LDS double buffer: store next data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next); - } - } - - // LDS double buffer: tail - { - constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0); - - if(has_two_iteration_left) // if has 2 iteration left - { - Float p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()]; - Float p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()]; - - a_blockwise_copy.MoveSrcSliceWindow(a_block_slice_copy_steps, True); - b_blockwise_copy.MoveSrcSliceWindow(b_block_slice_copy_steps, True); - - __syncthreads(); - - // LDS double buffer: load last data from device mem - a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer); - b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer); - - // LDS double buffer: GEMM on 2nd-last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - - // LDS double buffer: store last data to LDS - a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, - p_a_block_double + a_block_space); - b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, - p_b_block_double + 
b_block_space); - - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run( - p_a_block_double + a_block_space, p_b_block_double + b_block_space, p_c_thread); - } - else // if has 1 iteration left - { - __syncthreads(); - - // LDS double buffer: GEMM on last data - blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread); - } - } - - // input: register to global memory - { - constexpr index_t M1 = MPerThread * MLevel0Cluster * MLevel1Cluster; - constexpr index_t M0 = M / M1; - - constexpr index_t N1 = NPerThread * NLevel0Cluster * NLevel1Cluster; - constexpr index_t N0 = N / N1; - - // define input tensor descriptor for threadwise copy - // thread input tensor, src of threadwise copy - constexpr auto c_m0_m1_n0_n1_thread_desc = make_native_tensor_descriptor_packed( - Sequence{}); - - constexpr auto c_m0_m1_n0_n1_global_desc = transform_tensor_descriptor( - c_m_n_global_desc, - make_tuple(UnMerge>{}, UnMerge>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // calculate origin of thread input tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id()); - - const index_t m_thread_data_on_global = - m_block_data_on_global + c_thread_mtx_on_block.row; - - const index_t n_thread_data_on_global = - n_block_data_on_global + c_thread_mtx_on_block.col; - - ThreadwiseGenericTensorSliceCopy_v4r2( - {0, 0, 0, 0}, - {m_thread_data_on_global / M1, - m_thread_data_on_global % M1, - n_thread_data_on_global / N1, - n_thread_data_on_global % N1}) - .Run(p_c_thread, p_c_global); - } - } - - __device__ void Run(const Float* __restrict__ p_a_global, - const Float* __restrict__ p_b_global, - Float* __restrict__ p_c_global) const - { - constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(Float); - - __shared__ Float p_shared_block[shared_block_size]; - - Run(p_a_global, p_b_global, p_c_global, p_shared_block); - } -}; - -} // namespace ck -#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp new file mode 100644 index 0000000000..9e2f0b472f --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -0,0 +1,1298 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; + +// this version is less likely to have scratch memory issue, due to: +// 1. It does not keep reference to tensor descriptor +// 2. 
It does not construct new tensor coordinate for this->Run() +// Assume src_slice_origin_idx is 0 +// TODO: support non-zero src_slice_oring_idx +template ::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v1r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3( + const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + : dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcData* p_src, + const DstDesc& dst_desc, + DstData* p_dst, + const DstIteratorHacks& dst_iterator_hacks) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{}; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + + return dst_data_idx; + }(); + + // copy data + vector_type dst_vector; + + using dst_vector_t = typename vector_type::type; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = + src_desc.CalculateOffset(to_multi_index(src_slice_origin_idx) + dst_data_idx + + i * dst_scalar_step_in_vector); + + dst_vector.Scalars()(i) = type_convert{}(p_src[Number{}]); + }); + + const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + dst_desc, dst_slice_origin_coord_); + + if constexpr(SrcAddressSpace == AddressSpace::Vgpr && + DstAddressSpace == AddressSpace::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + amd_buffer_store_v2( + dst_vector.Vector(), + p_dst, + dst_slice_origin_coord_.GetOffset(), + is_dst_valid, + dst_desc.GetElementSpaceSize()); +#else + if(is_dst_valid) + { + *reinterpret_cast( + &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); + } +#endif + } + else + { + if(is_dst_valid) + { + *reinterpret_cast( + &(p_dst[dst_slice_origin_coord_.GetOffset()])) = dst_vector.Vector(); + } + } + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate(dst_desc, + dst_slice_origin_coord_, + dst_forward_iterators[dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate(dst_desc, + dst_slice_origin_coord_, + dst_backward_iterators[dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator); + } + } + + __device__ void Run(const SrcData* p_src, const DstDesc& dst_desc, DstData* p_dst) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + 
generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(p_src, dst_desc, p_dst, dst_iterator_hacks); + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index after last iteration in Run(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * + dst_scalar_per_access; + + return dst_data_idx; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step; + }(); + + return reset_dst_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step); + } + + private: + DstCoord dst_slice_origin_coord_; +}; // namespace ck + +// this version is less likely to have scratch memory issue, due to: +// 1. It does not keep reference to tensor descriptor +// 2. It does not construct new tensor coordinate for this->Run() +// Assume dst_slice_origin_idx is 0 +template ::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx)) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc need to known at compile-time"); + } + + __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcData* p_src, + const DstDesc&, + const DstSliceOriginIdx&, + DstData* p_dst, + const SrcIteratorHacks& src_iterator_hacks) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert( + is_known_at_compile_time>>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? ordered_access_idx[i] + : ordered_access_lengths[i] - 1 - ordered_access_idx[i]; + }); + + auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + + return src_data_idx; + }(); + + // copy data + static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! 
hardcode for vgpr dst"); + + vector_type src_vector; + + using src_vector_t = typename vector_type::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_slice_origin_coord_); + + if constexpr(SrcAddressSpace == AddressSpace::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + src_vector.Vector() = amd_buffer_load_v2( + p_src, + src_slice_origin_coord_.GetOffset(), + is_src_valid, + src_desc.GetElementSpaceSize()); +#else + src_vector.Vector() = is_src_valid + ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; +#endif + } + else + { + src_vector.Vector() = is_src_valid + ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; + } + + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + p_dst[Number{}] = src_vector.Scalars()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate(src_desc, + src_slice_origin_coord_, + src_forward_iterators[dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate(src_desc, + src_slice_origin_coord_, + src_backward_iterators[dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator); + } + } + + __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + Run(src_desc, p_src, p_dst, src_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // 
calculate src data index after last iteration in Run(), if it has not being reset by + // RunWrite() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0; + }); + + auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) * + src_scalar_per_access; + + return src_data_idx; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + + return reset_src_data_step; + }(); + + return reset_src_data_step; + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); + } + + private: + SrcCoord src_slice_origin_coord_; +}; // namespace ck + +// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions +// 1. It does not keep reference to tensor descriptor +// 2. It does not construct new tensor coordinate for this->Run() +// 3. It does not use pointer for VGPR thread buffer +// 4. It calculate offset for thread buffer directly, instead of moving the coordinate +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseDynamicTensorSliceTransfer_v3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), + dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + { + static_assert(SrcAddressSpace == AddressSpace::Global or + SrcAddressSpace == AddressSpace::Lds, + "wrong!"); + static_assert(DstAddressSpace == AddressSpace::Global or + DstAddressSpace == AddressSpace::Lds, + "wrong!"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcData* p_src, + const 
SrcIteratorHacks& src_iterator_hacks) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + auto src_data_idx = + container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + + return src_data_idx; + }(); + + // copy data + vector_type src_vector; + + using src_vector_t = typename vector_type::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_slice_origin_coord_); + + if constexpr(SrcAddressSpace == AddressSpace::Global) + { +#if CK_USE_AMD_BUFFER_ADDRESSING + src_vector.Vector() = amd_buffer_load_v2( + p_src, + src_slice_origin_coord_.GetOffset(), + is_src_valid, + src_desc.GetElementSpaceSize()); +#else + src_vector.Vector() = is_src_valid + ? *reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; +#endif + } + else + { + src_vector.Vector() = is_src_valid + ? 
*reinterpret_cast( + &p_src[src_slice_origin_coord_.GetOffset()]) + : src_vector_t{0}; + } + + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); + + buffer_(Number{}) = src_vector.Scalars()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + src_desc, + src_slice_origin_coord_, + src_forward_iterators[src_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + src_desc, + src_slice_origin_coord_, + src_backward_iterators[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstData* p_dst, const DstIteratorHacks& dst_iterator_hacks) + { + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + + return forward_iterator; + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; + }); + + const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + + return backward_iterator; + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + auto dst_data_idx = + container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + + return dst_data_idx; + }(); + + // copy data + // hardcoding for ds_write + // TODO refactor transfer_data() to encapsulate this + static_assert(DstAddressSpace == AddressSpace::Lds && + DstInMemOp == InMemoryDataOperation::Set, + "wrong! hardcoded for ds_write"); + + vector_type dst_vector; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); + + dst_vector.Scalars()(i) = buffer_[Number{}]; + }); + + using DstVectorType = typename vector_type::type; + + *reinterpret_cast(p_dst + dst_slice_origin_coord_.GetOffset()) = + dst_vector.Vector(); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + dst_desc, + dst_slice_origin_coord_, + dst_forward_iterators[dst_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + dst_desc, + dst_slice_origin_coord_, + dst_backward_iterators[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, dst_reset_iterator); + } + } + + __device__ void RunRead(const SrcDesc& src_desc, const SrcData* p_src) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, p_src, src_iterator_hacks); + } + + __device__ void RunWrite(const DstDesc& dst_desc, DstData* p_dst) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + 
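+        // Here the iterator hack tuple is all zeros: one Sequence of ntransform_dst
+        // zeros per dimension, for both the forward and the backward iterators, i.e.
+        // no transform-specific hints are given for the coordinate update.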
constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, p_dst, dst_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + + return src_data_idx; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + + return reset_src_data_step; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; + }); + + auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + + return dst_data_idx; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( + src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + + move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
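+        // Worked example (a sketch, assuming SliceLengths = {2, 4}, an identity
+        // DstDimAccessOrder and a scalar-per-access of 1 on every dimension): the
+        // zig-zag sweep in RunWrite() ends at data index {1, 0}, so
+        // GetDstCoordinateResetStep() returns {-1, 0}; a requested window step of
+        // {0, 4} is then adjusted to {-1, 4} when DstResetCoordinateAfterRun is false.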
+ const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticallyIndexedArray buffer_; + + SrcCoord src_slice_origin_coord_; + DstCoord dst_slice_origin_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm.hpp index 7cfd54e050..56440bc2b7 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm.hpp @@ -39,7 +39,7 @@ struct ThreadwiseMatrixSliceCopy template __device__ static void Run(const Data* p_src, Data* p_dst) { - using vector_t = typename vector_type::MemoryType; + using vector_t = typename vector_type::type; for(index_t i = 0; i < NSliceRow; ++i) { @@ -153,9 +153,8 @@ struct ThreadwiseGemmTransANormalBNormalC (is_same{} && is_same{}) || (is_same{} && is_same{})); - static_if{}([&](auto fwd) { - Run_amd_asm(p_a, p_b, fwd(p_c)); - }).Else([&](auto) { Run_source(p_a, p_b, p_c); }); + static_if{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); }) + .Else([&](auto) { Run_source(p_a, p_b, p_c); }); #else Run_source(p_a, p_b, p_c); #endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp new file mode 100644 index 0000000000..1af88e5cbb --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp @@ -0,0 +1,172 @@ +#ifndef CK_THREADWISE_GEMM_V2_HPP +#define CK_THREADWISE_GEMM_V2_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +template +__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread) +{ + static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto desc = Desc{}; + + constexpr auto M = desc.GetLength(I0); + constexpr auto N = desc.GetLength(I1); + + static_for<0, M, 1>{}([&](auto i) { + static_for<0, N, 1>{}([&](auto j) { + constexpr auto offset = desc.CalculateOffset(make_tuple(i, j)); + + p_thread[offset] = Float(0); + }); + }); +} + +template +struct ThreadwiseMatrixSliceCopy_v2 +{ + template + __device__ static void Run(const Data* p_src, Data* p_dst) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + using vector_t = typename vector_type::type; + + static_for<0, NSliceRow, 1>{}([&](auto i) { + static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) { + constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j)); + constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j)); + + *reinterpret_cast(&p_dst[dst_offset]) = + *reinterpret_cast(&p_src[src_offset]); + }); + }); + } +}; + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +template ::type = false> +struct ThreadwiseGemm_km_kn_mn_v1 +{ + template + __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto M = CDesc{}.GetLength(I0); + constexpr auto N = CDesc{}.GetLength(I1); + constexpr auto K = ADesc{}.GetLength(I0); + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, M, 1>{}([&](auto m) { + static_for<0, N, 1>{}([&](auto n) { + constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); + constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n)); + constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n)); + + p_c[c_offset] += + inner_product_with_conversion{}(p_a[a_offset], p_b[b_offset]); + }); + }); + }); + } + +#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM + template + __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto M = CDesc{}.GetLength(I0); + constexpr auto N = CDesc{}.GetLength(I1); + constexpr auto K = ADesc{}.GetLength(I0); + + static_assert(N == 4 || N == 2, "wrong! 
this config not supported by asm yet"); + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, M, 1>{}([&](auto m) { + constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m)); + + if constexpr(N == 2) + { + constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); + constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); + + constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); + constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1)); + + amd_assembly_outer_product_1x2(p_a[a_offset], + p_b[b_offset_0], + p_b[b_offset_1], + p_c[c_offset_0], + p_c[c_offset_1]); + } + else if constexpr(N == 4) + { + constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0)); + constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1)); + constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2)); + constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3)); + + constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0)); + constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1)); + constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2)); + constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3)); + + amd_assembly_outer_product_1x4(p_a[a_offset], + p_b[b_offset_0], + p_b[b_offset_1], + p_b[b_offset_2], + p_b[b_offset_3], + p_c[c_offset_0], + p_c[c_offset_1], + p_c[c_offset_2], + p_c[c_offset_3]); + } + }); + }); + } +#endif + + template + __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + { +#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM + Run_amd_asm(p_a, p_b, p_c); +#else + Run_source(p_a, p_b, p_c); +#endif + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp new file mode 100644 index 0000000000..96d0afa892 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp @@ -0,0 +1,141 @@ +#ifndef CK_THREADWISE_GEMM_V3_HPP +#define CK_THREADWISE_GEMM_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +template +__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread) +{ + static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto desc = Desc{}; + + constexpr auto K = desc.GetLength(I0); + constexpr auto H = desc.GetLength(I2); + constexpr auto W = desc.GetLength(I3); + + static_for<0, K, 1>{}([&](auto i) { + static_for<0, H, 1>{}([&](auto j) { + static_for<0, W, 1>{}([&](auto k) { + constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k)); + + p_thread[offset] = Float(0); + }); + }); + }); +} + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +template ::type = false> +struct ThreadwiseGemm_km_kn_mn_v3 +{ + template + __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + // constexpr auto H = BDesc{}.GetLength(I2); + // constexpr auto W = BDesc{}.GetLength(I3); + constexpr auto H = 2; + constexpr auto W = 2; + + constexpr auto E = ADesc{}.GetLength(I0); + constexpr auto K = ADesc{}.GetLength(I1); + + static_for<0, E, 1>{}([&](auto e) { + static_for<0, K, 1>{}([&](auto k) { + constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k)); + + if constexpr(H == 2 && W == 2) + { + + constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); + constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1)); + constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); + constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1)); + + constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); + constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1)); + constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0)); + constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1)); + + amd_assembly_outer_product_1x4(p_a[a_offset], + p_b[b_offset_0], + p_b[b_offset_1], + p_b[b_offset_2], + p_b[b_offset_3], + p_c[c_offset_0], + p_c[c_offset_1], + p_c[c_offset_2], + p_c[c_offset_3]); + } + else if constexpr(H == 4 && W == 1) + { + + constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0)); + constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0)); + constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0)); + constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0)); + + constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0)); + constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0)); + constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0)); + constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0)); + + amd_assembly_outer_product_1x4(p_a[a_offset], + p_b[b_offset_0], + p_b[b_offset_1], + p_b[b_offset_2], + p_b[b_offset_3], + p_c[c_offset_0], + p_c[c_offset_1], + p_c[c_offset_2], + p_c[c_offset_3]); + } + else + { + static_for<0, H, 1>{}([&](auto h) { + static_for<0, W, 1>{}([&](auto w) { + constexpr auto b_offset = + BDesc{}.CalculateOffset(make_tuple(e, 0, h, w)); + constexpr auto c_offset = + CDesc{}.CalculateOffset(make_tuple(k, 0, h, w)); + + p_c[c_offset] += inner_product_with_conversion{}(p_a[a_offset], + p_b[b_offset]); + }); + }); + } + }); + }); + } + + template + __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c) + { + Run_source(p_a, p_b, p_c); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index db6660a3cb..f9f48a18b7 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -54,8 +54,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 } __device__ constexpr ThreadwiseGenericTensorSliceCopy_v4r2() - : ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_array(), - make_zero_array()) + : ThreadwiseGenericTensorSliceCopy_v4r2(make_zero_multi_index(), + 
make_zero_multi_index()) { } @@ -82,113 +82,104 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2 constexpr auto long_vector_access_lengths = SliceLengths::Modify( vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size); - ford{}([&]( - auto long_vector_access_id) { + ford{}( + [&](auto long_vector_access_id) { - // data id w.r.t slicing-window - auto long_vector_data_begin_id = long_vector_access_id; - long_vector_data_begin_id(vector_access_dim) = - long_vector_size * long_vector_access_id[vector_access_dim]; + // data id w.r.t slicing-window + auto long_vector_data_begin_id = long_vector_access_id; + long_vector_data_begin_id(vector_access_dim) = + long_vector_size * long_vector_access_id[vector_access_dim]; - // buffer to hold a src long-vector - SrcData p_src_long_vector[long_vector_size]; + // buffer to hold a src long-vector + SrcData p_src_long_vector[long_vector_size]; -#if 1 - // zero out buffer - for(index_t i = 0; i < long_vector_size; ++i) - { - p_src_long_vector[i] = 0; - } -#endif + // load data from src to the long-vector buffer + static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) { + auto scalar_id = make_zero_multi_index(); + scalar_id(vector_access_dim) = i * src_data_per_access; - // load data from src to the long-vector buffer - for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i) - { - auto scalar_id = make_zero_array(); - scalar_id(vector_access_dim) = i * src_data_per_access; + const index_t buffer_offset = i * src_data_per_access; - const index_t buffer_offset = i * src_data_per_access; + const auto src_coord = + mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); - const auto src_coord = mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id); + // Check src data's valid mapping situation, only check the first data in this + // src + // vector. It's user's responsiblity to make sure all data in the src vector + // has the valid/invalid mapping situation + transfer_data(p_src, + src_coord.GetOffset(), + src_coord.IsOffsetValidAssumingUpperIndexIsValid(), + SrcDesc::GetElementSpace(), + p_src_long_vector, + buffer_offset, + true, + long_vector_size); + }); - // Check src data's valid mapping situation, only check the first data in this src - // vector. 
It's user's responsiblity to make sure all data in the src vector - // has the valid/invalid mapping situation - transfer_data(p_src, - src_coord.GetOffset(), - src_coord.IsOffsetValidAssumingUpperIndexIsValid(), - SrcDesc::GetElementSpace(), - p_src_long_vector, - buffer_offset, - true, - long_vector_size); - } + // SrcData to DstData conversion + DstData p_dst_long_vector[long_vector_size]; - // SrcData to DstData conversion - DstData p_dst_long_vector[long_vector_size]; + static_for<0, long_vector_size, 1>{}([&](auto i) { + p_dst_long_vector[i] = type_convert{}(p_src_long_vector[i]); + }); - for(index_t i = 0; i < long_vector_size; ++i) - { - p_dst_long_vector[i] = type_convert{}(p_src_long_vector[i]); - } + // store data from the long-vector buffer to dst + static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) { + auto scalar_id = make_zero_multi_index(); + scalar_id(vector_access_dim) = i * dst_data_per_access; - // store data from the long-vector buffer to dst - for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i) - { - auto scalar_id = make_zero_array(); - scalar_id(vector_access_dim) = i * dst_data_per_access; + const index_t buffer_offset = i * dst_data_per_access; - const index_t buffer_offset = i * dst_data_per_access; + const auto dst_coord = + mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); - const auto dst_coord = mDstSliceOrigin + (long_vector_data_begin_id + scalar_id); - - // Check dst data's valid mapping situation, only check the first data in this dst - // vector. It's user's responsiblity to make sure all data in the dst vector - // has the valid/invalid mapping situation - transfer_data(p_dst_long_vector, - buffer_offset, - true, - long_vector_size, - p_dst, - dst_coord.GetOffset(), - dst_coord.IsOffsetValidAssumingUpperIndexIsValid(), - DstDesc::GetElementSpace()); - } - }); + // Check dst data's valid mapping situation, only check the first data in this + // dst + // vector. 
It's user's responsiblity to make sure all data in the dst vector + // has the valid/invalid mapping situation + transfer_data(p_dst_long_vector, + buffer_offset, + true, + long_vector_size, + p_dst, + dst_coord.GetOffset(), + dst_coord.IsOffsetValidAssumingUpperIndexIsValid(), + DstDesc::GetElementSpace()); + }); + }); } template __device__ void MoveSrcSliceWindow(const T& step_sizes_, integral_constant) { - const auto step_sizes = to_array(step_sizes_); + const auto step_sizes = to_multi_index(step_sizes_); - static_if{}([&](auto) { - mSrcSliceOrigin += to_array(step_sizes); - }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; }); + static_if{}([&](auto) { mSrcSliceOrigin += to_multi_index(step_sizes); }) + .Else([&](auto) { mSrcSliceOrigin -= step_sizes; }); } template __device__ void MoveDstSliceWindow(const T& step_sizes_, integral_constant) { - const auto step_sizes = to_array(step_sizes_); + const auto step_sizes = to_multi_index(step_sizes_); - static_if{}([&](auto) { - mDstSliceOrigin += step_sizes; - }).Else([&](auto) { mDstSliceOrigin -= step_sizes; }); + static_if{}([&](auto) { mDstSliceOrigin += step_sizes; }) + .Else([&](auto) { mDstSliceOrigin -= step_sizes; }); } private: diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index 9176241bfc..b8630c464f 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -2,20 +2,10 @@ #define CK_AMD_BUFFER_ADDRESSING_HPP #include "float_type.hpp" +#include "amd_buffer_addressing_v2.hpp" namespace ck { -// For 128 bit SGPRs to supply resource constant in buffer instructions -// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions -template -union BufferResourceConstant -{ - int32x4_t data; - T* address[2]; - int32_t range[4]; - int32_t config[4]; -}; - __device__ float __llvm_amdgcn_buffer_load_f32(int32x4_t srsrc, index_t vindex, index_t offset, @@ -35,44 +25,17 @@ __llvm_amdgcn_buffer_load_f32x4(int32x4_t srsrc, index_t offset, bool glc, bool slc) __asm("llvm.amdgcn.buffer.load.v4f32"); +__device__ half_t +__llvm_amdgcn_raw_buffer_load_f16(int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); -__device__ half_t __llvm_amdgcn_buffer_load_f16(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.f16"); - -__device__ half2_t __llvm_amdgcn_buffer_load_f16x2(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v2f16"); - -__device__ half4_t __llvm_amdgcn_buffer_load_f16x4(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v4f16"); - -__device__ ushort __llvm_amdgcn_buffer_load_bf16(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.bf16"); - -__device__ ushort2_t -__llvm_amdgcn_buffer_load_bf16x2(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v2bf16"); - -__device__ ushort4_t -__llvm_amdgcn_buffer_load_bf16x4(int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.load.v4bf16"); +__device__ ushort +__llvm_amdgcn_raw_buffer_load_bf16(int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) 
__asm("llvm.amdgcn.raw.buffer.load.bf16"); __device__ void __llvm_amdgcn_buffer_store_f32(float vdata, int32x4_t srsrc, @@ -95,67 +58,43 @@ __device__ void __llvm_amdgcn_buffer_store_f32x4(float4_t vdata, bool glc, bool slc) __asm("llvm.amdgcn.buffer.store.v4f32"); -__device__ void __llvm_amdgcn_buffer_store_f16(half_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.f16"); - -__device__ void __llvm_amdgcn_buffer_store_f16x2(half2_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v2f16"); - -__device__ void __llvm_amdgcn_buffer_store_f16x4(half4_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v4f16"); - -__device__ void __llvm_amdgcn_buffer_store_bf16(ushort vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.bf16"); +__device__ void +__llvm_amdgcn_raw_buffer_store_f16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); __device__ void -__llvm_amdgcn_buffer_store_bf16x2(ushort2_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v2bf16"); - -__device__ void -__llvm_amdgcn_buffer_store_bf16x4(ushort4_t vdata, - int32x4_t srsrc, - index_t vindex, - index_t offset, - bool glc, - bool slc) __asm("llvm.amdgcn.buffer.store.v4bf16"); +__llvm_amdgcn_raw_buffer_store_bf16(ushort vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.bf16"); +#if CK_USE_AMD_BUFFER_ATOMIC_FADD +#if CK_HIP_VERSION_FLAT >= 3010020405 +// starting ROCm-3.10, the return type becomes float +__device__ float +#else __device__ void +#endif __llvm_amdgcn_buffer_atomic_add_f32(float vdata, - int32x4_t srsrc, + int32x4_t rsrc, index_t vindex, index_t offset, bool slc) __asm("llvm.amdgcn.buffer.atomic.fadd.f32"); +#endif // buffer_load requires: -// 1) p_src_thread must be in global memory space, p_dst_thread must be vgpr -// 2) p_src_thread to be a wavewise pointer. +// 1) p_src_wave must be in global memory space +// 2) p_src_wave to be a wavewise pointer. // It is user's responsibility to make sure that is true. 
template -__device__ typename vector_type::MemoryType -amd_buffer_load(const T* p_src_wave, - index_t src_thread_data_offset, - bool src_thread_data_valid, - index_t src_elemenst_space); +__device__ typename vector_type::type amd_buffer_load(const T* p_src_wave, + index_t src_thread_data_offset, + bool src_thread_data_valid, + index_t src_elemenst_space); // buffer_store requires: // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory @@ -185,36 +124,27 @@ __device__ float amd_buffer_load(const float* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); -#if 1 // debug -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - return __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset - : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; return __llvm_amdgcn_buffer_load_f32( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif #else - return src_thread_data_valid - ? __llvm_amdgcn_buffer_load_f32( - src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false) - : 0; + float tmp = __llvm_amdgcn_buffer_load_f32( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? tmp : float(0); #endif } @@ -224,29 +154,27 @@ __device__ float2_t amd_buffer_load(const float* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - return __llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset - : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; return __llvm_amdgcn_buffer_load_f32x2( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); +#else + float2_t tmp = __llvm_amdgcn_buffer_load_f32x2( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? 
tmp : float2_t(0); #endif } @@ -256,29 +184,27 @@ __device__ float4_t amd_buffer_load(const float* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(float); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - return __llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset - : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; return __llvm_amdgcn_buffer_load_f32x4( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); +#else + float4_t tmp = __llvm_amdgcn_buffer_load_f32x4( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? tmp : float4_t(0); #endif } @@ -288,33 +214,32 @@ __device__ half_t amd_buffer_load(const half_t* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; -#if !CK_WORKAROUND_SWDEV_231101 index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - return __llvm_amdgcn_buffer_load_f16(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset - : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - return __llvm_amdgcn_buffer_load_f16( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + return __llvm_amdgcn_raw_buffer_load_f16( + src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); #else - return src_thread_data_valid ? p_src_wave[src_thread_data_offset] : 0; + half_t zero(0); + + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + return src_thread_data_valid ? 
__llvm_amdgcn_raw_buffer_load_f16( + src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0) + : zero; #endif } @@ -324,32 +249,32 @@ __device__ half2_t amd_buffer_load(const half_t* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float dst_out_tmp = - __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + half2_t zeros(0); + + float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -358,32 +283,32 @@ __device__ half4_t amd_buffer_load(const half_t* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float2_t dst_out_tmp = - __llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + half4_t zeros(0); + + float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? 
*reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -392,32 +317,32 @@ __device__ half8_t amd_buffer_load(const half_t* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(half_t); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(half_t); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float4_t dst_out_tmp = - __llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + half8_t zeros(0); + + float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -426,34 +351,32 @@ __device__ ushort amd_buffer_load(const ushort* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; -#if !CK_WORKAROUND_SWDEV_231101 index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - return __llvm_amdgcn_buffer_load_bf16(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset - : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; - return __llvm_amdgcn_buffer_load_bf16( - src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif - + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + return __llvm_amdgcn_raw_buffer_load_bf16( + src_wave_buffer_resource.data, src_addr_shift + src_thread_addr_offset, 0, 0); #else - return src_thread_data_valid ? p_src_wave[src_thread_data_offset] : 0; + ushort zero(0); + + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + return src_thread_data_valid ? 
__llvm_amdgcn_raw_buffer_load_bf16( + src_wave_buffer_resource.data, src_thread_addr_offset, 0, 0) + : zero; #endif } @@ -463,32 +386,32 @@ __device__ ushort2_t amd_buffer_load(const ushort* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float dst_out_tmp = - __llvm_amdgcn_buffer_load_f32(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + ushort2_t zeros(0); + + float dst_out_tmp = __llvm_amdgcn_buffer_load_f32( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -497,32 +420,32 @@ __device__ ushort4_t amd_buffer_load(const ushort* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float2_t dst_out_tmp = - __llvm_amdgcn_buffer_load_f32x2(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + ushort4_t zeros(0); + + float2_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x2( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? 
*reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -531,32 +454,32 @@ __device__ ushort8_t amd_buffer_load(const ushort* p_src_wave, bool src_thread_data_valid, index_t src_data_range) { - BufferResourceConstant src_wave_buffer_resource; + BufferResource src_wave_buffer_resource; // wavewise base address (64 bit) src_wave_buffer_resource.address[0] = const_cast(p_src_wave); // wavewise range (32 bit) src_wave_buffer_resource.range[2] = src_data_range * sizeof(ushort); // wavewise setting (32 bit) - src_wave_buffer_resource.config[3] = 0x00027000; + src_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t src_thread_addr_offset = src_thread_data_offset * sizeof(ushort); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - float4_t dst_out_tmp = - __llvm_amdgcn_buffer_load_f32x4(src_wave_buffer_resource.data, - 0, - src_thread_data_valid ? src_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( src_wave_buffer_resource.data, 0, src_addr_shift + src_thread_addr_offset, false, false); -#endif return *reinterpret_cast(&dst_out_tmp); +#else + ushort8_t zeros(0); + + float4_t dst_out_tmp = __llvm_amdgcn_buffer_load_f32x4( + src_wave_buffer_resource.data, 0, src_thread_addr_offset, false, false); + + return src_thread_data_valid ? *reinterpret_cast(&dst_out_tmp) : zeros; +#endif } template <> @@ -566,26 +489,18 @@ __device__ void amd_buffer_store(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if 1 // debug -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32(*p_src_thread, @@ -594,7 +509,6 @@ __device__ void amd_buffer_store(const float* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); -#endif #else if(dst_thread_data_valid) { @@ -611,25 +525,18 @@ __device__ void amd_buffer_store(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src_thread), @@ -638,6 +545,16 @@ __device__ void amd_buffer_store(const float* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x2(*reinterpret_cast(p_src_thread), + dst_wave_buffer_resource.data, + 0, + dst_thread_addr_offset, + false, + false); + } #endif } @@ -648,25 +565,18 @@ __device__ void amd_buffer_store(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src_thread), - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
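// A minimal sketch (illustration only) of why the offset trick used here is
// safe: the buffer resource's range word is programmed to
// dst_data_range * sizeof(T), and buffer instructions drop any access whose
// byte offset lands at or beyond that range (out-of-range buffer loads return
// 0, out-of-range stores are discarded). An invalid lane therefore only needs
// its offset pushed out of range; the hypothetical helper below captures the
// pattern repeated throughout this file.
__device__ index_t oob_guarded_byte_offset(bool valid, index_t byte_offset)
{
    // 0x7fffffff is far larger than any realistic range word, so invalid
    // lanes always end up out of range and the hardware suppresses the access
    uint32_t shift = valid ? 0 : 0x7fffffff;
    return shift + byte_offset;
}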
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src_thread), @@ -675,6 +585,16 @@ __device__ void amd_buffer_store(const float* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x4(*reinterpret_cast(p_src_thread), + dst_wave_buffer_resource.data, + 0, + dst_thread_addr_offset, + false, + false); + } #endif } @@ -685,40 +605,34 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; -#if !CK_WORKAROUND_SWDEV_231101 index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f16(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - __llvm_amdgcn_buffer_store_f16(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#endif - + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + __llvm_amdgcn_raw_buffer_store_f16(*p_src_thread, + dst_wave_buffer_resource.data, + dst_addr_shift + dst_thread_addr_offset, + 0, + 0); #else if(dst_thread_data_valid) { - p_dst_wave[dst_thread_data_offset] = *p_src_thread; + // current code cannot isolate Soffset and Voffset, so Soffset is hard-coded to 0, and + // everything is passed to Voffset + __llvm_amdgcn_raw_buffer_store_f16( + *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); } #endif } @@ -730,27 +644,20 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); const float* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32(*p_src_tmp, @@ -759,6 +666,12 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } @@ -769,27 +682,20 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); const float2_t* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, @@ -798,6 +704,12 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x2( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } @@ -808,27 +720,20 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(half_t); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(half_t); const float4_t* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, @@ -837,6 +742,12 @@ __device__ void amd_buffer_store(const half_t* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x4( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } @@ -847,40 +758,30 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; -#if !CK_WORKAROUND_SWDEV_231101 index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_bf16(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; - __llvm_amdgcn_buffer_store_bf16(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_addr_shift + dst_thread_addr_offset, - false, - false); -#endif - + __llvm_amdgcn_raw_buffer_store_bf16(*p_src_thread, + dst_wave_buffer_resource.data, + dst_addr_shift + dst_thread_addr_offset, + 0, + 0); #else if(dst_thread_data_valid) { - p_dst_wave[dst_thread_data_offset] = *p_src_thread; + __llvm_amdgcn_raw_buffer_store_bf16( + *p_src_thread, dst_wave_buffer_resource.data, dst_thread_addr_offset, 0, 0); } #endif } @@ -892,27 +793,20 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); const float* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32(*p_src_tmp, @@ -921,6 +815,12 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } @@ -931,27 +831,20 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); const float2_t* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x2(*p_src_tmp, @@ -960,6 +853,12 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x2( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } @@ -970,27 +869,20 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(ushort); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(ushort); const float4_t* p_src_tmp = reinterpret_cast(p_src_thread); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; __llvm_amdgcn_buffer_store_f32x4(*p_src_tmp, @@ -999,9 +891,16 @@ __device__ void amd_buffer_store(const ushort* p_src_thread, dst_addr_shift + dst_thread_addr_offset, false, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_store_f32x4( + *p_src_tmp, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false, false); + } #endif } +#if CK_USE_AMD_BUFFER_ATOMIC_FADD template <> __device__ void amd_buffer_atomic_add(const float* p_src_thread, float* p_dst_wave, @@ -1009,24 +908,18 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - __llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread, - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? dst_thread_addr_offset : 0xffffffff, - false); -#else +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; __llvm_amdgcn_buffer_atomic_add_f32(*p_src_thread, @@ -1034,6 +927,12 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, 0, dst_addr_shift + dst_thread_addr_offset, false); +#else + if(dst_thread_data_valid) + { + __llvm_amdgcn_buffer_atomic_add_f32( + *p_src_thread, dst_wave_buffer_resource.data, 0, dst_thread_addr_offset, false); + } #endif } @@ -1044,28 +943,18 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range; // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - for(index_t i = 0; i < 2; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32( - p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? (dst_thread_addr_offset + i * sizeof(float)) : 0xffffffff, - false); - } -#else +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; for(index_t i = 0; i < 2; ++i) @@ -1077,6 +966,18 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, i * sizeof(float), false); } +#else + if(dst_thread_data_valid) + { + for(index_t i = 0; i < 2; ++i) + { + __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], + dst_wave_buffer_resource.data, + 0, + dst_thread_addr_offset + i * sizeof(float), + false); + } + } #endif } @@ -1087,28 +988,18 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, bool dst_thread_data_valid, index_t dst_data_range) { - BufferResourceConstant dst_wave_buffer_resource; + BufferResource dst_wave_buffer_resource; // wavewise base address (64 bit) dst_wave_buffer_resource.address[0] = p_dst_wave; // wavewise range (32 bit) dst_wave_buffer_resource.range[2] = dst_data_range * sizeof(float); // wavewise setting (32 bit) - dst_wave_buffer_resource.config[3] = 0x00027000; + dst_wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float); -#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK - for(index_t i = 0; i < 4; ++i) - { - __llvm_amdgcn_buffer_atomic_add_f32( - p_src_thread[i], - dst_wave_buffer_resource.data, - 0, - dst_thread_data_valid ? (dst_thread_addr_offset + i * sizeof(float)) : 0xffffffff, - false); - } -#else +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK uint32_t dst_addr_shift = dst_thread_data_valid ? 0 : 0x7fffffff; for(index_t i = 0; i < 4; ++i) @@ -1120,8 +1011,21 @@ __device__ void amd_buffer_atomic_add(const float* p_src_thread, i * sizeof(float), false); } +#else + if(dst_thread_data_valid) + { + for(index_t i = 0; i < 4; ++i) + { + __llvm_amdgcn_buffer_atomic_add_f32(p_src_thread[i], + dst_wave_buffer_resource.data, + 0, + dst_thread_addr_offset + i * sizeof(float), + false); + } + } #endif } +#endif // CK_USE_AMD_BUFFER_ATOMIC_FADD } // namespace ck #endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp new file mode 100644 index 0000000000..5c4e153869 --- /dev/null +++ b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp @@ -0,0 +1,365 @@ +#ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP +#define CK_AMD_BUFFER_ADDRESSING_V2_HPP + +#include "float_type.hpp" + +namespace ck { + +template +union BufferResource +{ + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t data; + T* address[2]; + int32_t range[4]; + int32_t config[4]; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address[0] = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range[2] = data_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config[3] = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.data; +} + +// load +__device__ int8_t +__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); +__device__ int16_t +__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); +__device__ int32_t +__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t 
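// A minimal sketch (illustration only) of how the pieces above fit together:
// make_wave_buffer_resource packs a wave-uniform base pointer, a byte range
// and the per-GPU config dword into the 128-bit descriptor, and the raw
// buffer intrinsics declared in this header then take each lane's byte offset
// through voffset while soffset stays 0. Function and parameter names below
// are hypothetical.
__device__ float load_one_float(float* p_wave, index_t element_space, index_t lane_byte_offset)
{
    const int32x4_t srd = make_wave_buffer_resource(p_wave, element_space);

    // voffset carries the per-lane byte offset; soffset = 0, glc/slc flags off,
    // matching the wrappers further below
    return __llvm_amdgcn_raw_buffer_load_fp32(srd, lane_byte_offset, 0, 0);
}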
voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +__device__ int32x2_t +__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +__device__ int32x4_t +__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +__device__ float +__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ float2_t +__llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +__device__ float4_t +__llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// store +__device__ void +__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +__device__ void +__llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +template +__device__ typename vector_type::type +amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_fp32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return __llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return __llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0); + + return tmp.Vector(); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + return __llvm_amdgcn_raw_buffer_load_i32( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 2) + { + return __llvm_amdgcn_raw_buffer_load_i32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 4) + { + return __llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + } + else if constexpr(N == 8) + { + vector_type tmp; + + tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4( + src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0); + + return tmp.Vector(); + } + } +} + +template +__device__ void amd_buffer_store_impl_v2(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + __llvm_amdgcn_raw_buffer_store_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + __llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + __llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + __llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + __llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + __llvm_amdgcn_raw_buffer_store_i16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + __llvm_amdgcn_raw_buffer_store_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must be in global memory space +// 2) p_src_wave to be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type::type amd_buffer_load_v2(const T* p_src_wave, + index_t src_thread_data_offset, + bool src_thread_data_valid, + index_t src_element_space) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space); + + index_t src_thread_addr_offset = src_thread_data_offset * sizeof(T); + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_data_valid ? 0 : 0x7fffffff; + + return amd_buffer_load_impl_v2( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); +#else + using vector_t = typename vector_type::type; + + vector_t tmp = + amd_buffer_load_impl_v2(src_wave_buffer_resource, src_thread_addr_offset, 0); + + return src_thread_data_valid ? tmp : vector_t(0); +#endif +} + +// buffer_store requires: +// 1) p_dst_wave must be global memory +// 2) p_dst_wave to be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void amd_buffer_store_v2(const typename vector_type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_data_offset, + const bool dst_thread_data_valid, + const index_t dst_element_space) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space); + + index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(T); + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_data_valid ? 
0 : 0x7fffffff; + + amd_buffer_store_impl_v2( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_data_valid) + { + amd_buffer_store_impl_v2( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index 51ebfb9065..6260fdc5bb 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -5,21 +5,44 @@ namespace ck { -// outer-product: c[i,j] += inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) { +#if CK_USE_AMD_V_FMAC_F32 + asm volatile("\n \ + v_fmac_f32 %0, %2, %3 \n \ + v_fmac_f32 %1, %2, %4 \n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +#else asm volatile("\n \ v_mac_f32 %0, %2, %3 \n \ v_mac_f32 %1, %2, %4 \n \ " : "=v"(c0), "=v"(c1) : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +#endif } -// outer-product: c[i,j] += inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) __device__ void amd_assembly_outer_product_1x4( float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) { +#if CK_USE_AMD_V_FMAC_F32 + asm volatile("\n \ + v_fmac_f32 %0, %4, %5 \n \ + v_fmac_f32 %1, %4, %6 \n \ + v_fmac_f32 %2, %4, %7 \n \ + v_fmac_f32 %3, %4, %8 \n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +#else asm volatile("\n \ v_mac_f32 %0, %4, %5 \n \ v_mac_f32 %1, %4, %6 \n \ @@ -28,9 +51,11 @@ __device__ void amd_assembly_outer_product_1x4( " : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +#endif } -// outer-product: c[i,j] += inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1) { @@ -38,15 +63,12 @@ amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, flo v_dot2_f32_f16 %0, %2, %3, %0\n \ v_dot2_f32_f16 %1, %2, %4, %1\n \ " - : "=v"(c0), "=v"(c1) // Dest registers - : "v"(a), // 1st Src register for 1 half2 registers - "v"(b0), // 2nd Src register - "v"(b1), - "0"(c0), // 3rd Src register - "1"(c1)); + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); } -// outer-product: c[i,j] += inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) { @@ -61,18 +83,21 @@ amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, flo v_dot2_f32_f16 %0, %3, %5, %0\n \ v_dot2_f32_f16 %1, %3, %7, %1\n \ " - : "=v"(c0), "=v"(c1) // Dest registers + : "=v"(c0), "=v"(c1) : "v"(p_a_half2[0]), - "v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers + "v"(p_a_half2[1]), "v"(p_b0_half2[0]), "v"(p_b0_half2[1]), "v"(p_b1_half2[0]), - "v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers + "v"(p_b1_half2[1]), "0"(c0), - "1"(c1)); // 3rd Src Acc registers for 2 half2 registers + "1"(c1)); } -// outer-product: c[i,j] += 
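// A reference formulation (illustration only) of what the outer-product
// helpers in this file compute, without inline asm: every accumulator gets
// one multiply-accumulate against the shared operand a. For the fp32 flavour,
// inner_product(a, b) is a single a * b; for the int8x4 flavours further
// below it is a 4-wide signed-byte dot product accumulated into int32.
__device__ void outer_product_1x4_reference(
    float a, float b0, float b1, float b2, float b3,
    float& c0, float& c1, float& c2, float& c3)
{
    c0 += a * b0;
    c1 += a * b1;
    c2 += a * b2;
    c3 += a * b3;
}

// reference for the int8x4 building block, assuming int8x4_t holds 4 packed
// signed bytes
__device__ int32_t dot4_i8_reference(int8x4_t a, int8x4_t b, int32_t c)
{
    const int8_t* pa = reinterpret_cast<const int8_t*>(&a);
    const int8_t* pb = reinterpret_cast<const int8_t*>(&b);

    for(index_t k = 0; k < 4; ++k)
    {
        c += static_cast<int32_t>(pa[k]) * static_cast<int32_t>(pb[k]);
    }

    return c;
}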
inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) __device__ void amd_assembly_outer_product_1x4(half2_t a, half2_t b0, half2_t b1, @@ -89,19 +114,14 @@ __device__ void amd_assembly_outer_product_1x4(half2_t a, v_dot2_f32_f16 %2, %4, %7, %2\n \ v_dot2_f32_f16 %3, %4, %8, %3\n \ " - : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers - : "v"(a), // 1st Src register for 1 half2 registers - "v"(b0), // 2nd Src register - "v"(b1), - "v"(b2), - "v"(b3), - "0"(c0), // 3rd Src register - "1"(c1), - "2"(c2), - "3"(c3)); + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); } -// outer-product: c[i,j] += inner_product(a[i], b[j]) +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) __device__ void amd_assembly_outer_product_1x4(half4_t a, half4_t b0, half4_t b1, @@ -129,21 +149,70 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a, v_dot2_f32_f16 %2, %5, %11, %2\n \ v_dot2_f32_f16 %3, %5, %13, %3\n \ " - : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) : "v"(p_a_half2[0]), - "v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers + "v"(p_a_half2[1]), "v"(p_b0_half2[0]), "v"(p_b0_half2[1]), "v"(p_b1_half2[0]), - "v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers + "v"(p_b1_half2[1]), "v"(p_b2_half2[0]), "v"(p_b2_half2[1]), "v"(p_b3_half2[0]), - "v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers + "v"(p_b3_half2[1]), "0"(c0), "1"(c1), "2"(c2), - "3"(c3)); // 3rd Src Acc registers for 2 half2 registers + "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %2, %3, %0\n \ + v_dot4_i32_i8 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +#else + c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); + c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); +#endif +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(int8x4_t a, + int8x4_t b0, + int8x4_t b1, + int8x4_t b2, + int8x4_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %4, %5, %0\n \ + v_dot4_i32_i8 %1, %4, %6, %1\n \ + v_dot4_i32_i8 %2, %4, %7, %2\n \ + v_dot4_i32_i8 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +#else + c0 = __builtin_amdgcn_sdot4(a, b0, c0, false); + c1 = __builtin_amdgcn_sdot4(a, b1, c1, false); + c2 = __builtin_amdgcn_sdot4(a, b2, c2, false); + c3 = __builtin_amdgcn_sdot4(a, b3, c3, false); +#endif } } // namespace ck diff --git a/composable_kernel/include/utility/amd_llvm_intrinsic.hpp b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp new file mode 100644 index 0000000000..8981db7a7b --- /dev/null +++ b/composable_kernel/include/utility/amd_llvm_intrinsic.hpp @@ -0,0 +1,11 @@ +#ifndef CK_AMD_LLVM_INTRINSIC_HPP +#define CK_AMD_LLVM_INTRINSIC_HPP + +#include "float_type.hpp" + +namespace ck { + +__device__ int32_t __llvm_amdgcn_readfirstlane_i32(int32_t i) 
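// A minimal sketch (illustration only): readfirstlane broadcasts the value
// held by the first active lane into a scalar register, so a value the
// programmer already knows to be thread-invariant can be forced into SGPRs
// instead of being carried in VGPRs. This is what the
// CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
// switch in config.amd.hpp.in toggles; the helper name below is hypothetical.
__device__ index_t force_to_sgpr(index_t wave_uniform_value)
{
    // only valid when every active lane already holds the same value,
    // otherwise lanes other than the first receive the wrong result
    return __llvm_amdgcn_readfirstlane_i32(wave_uniform_value);
}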
__asm("llvm.amdgcn.readfirstlane"); + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp index 0f68ec7d58..7271094d39 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -1,419 +1,62 @@ #ifndef CK_ARRAY_HPP #define CK_ARRAY_HPP -#include "sequence.hpp" #include "functional2.hpp" +#include "sequence.hpp" namespace ck { template struct Array { - using type = Array; + using type = Array; using data_type = TData; - // TODO: implement empty Array - index_t mData[NSize]; - - __host__ __device__ explicit constexpr Array() {} - - template - __host__ __device__ constexpr Array(X x, Xs... xs) - : mData{static_cast(x), static_cast(xs)...} - { - static_assert(sizeof...(Xs) + 1 == NSize, "wrong! size"); - } + TData mData[NSize]; __host__ __device__ static constexpr index_t Size() { return NSize; } - // TODO: remove - __host__ __device__ static constexpr index_t GetSize() { return Size(); } - - template - __host__ __device__ constexpr const TData& At(Number) const - { - static_assert(I < NSize, "wrong!"); - - return mData[I]; - } - - template - __host__ __device__ constexpr TData& At(Number) - { - static_assert(I < NSize, "wrong!"); - - return mData[I]; - } - __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } - template - __host__ __device__ constexpr const TData& operator[](I i) const - { - return At(i); - } + __host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); } - template - __host__ __device__ constexpr TData& operator()(I i) - { - return At(i); - } + __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } template - __host__ __device__ constexpr type& operator=(const T& x) + __host__ __device__ constexpr auto operator=(const T& a) { - static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = x[i]; }); + static_assert(T::Size() == Size(), "wrong! size not the same"); - return *this; - } - - struct lambda_PushBack // emulate constexpr lambda - { - const Array& old_array; - Array& new_array; - - __host__ __device__ constexpr lambda_PushBack(const Array& old_array_, - Array& new_array_) - : old_array(old_array_), new_array(new_array_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - new_array(Number{}) = old_array[I]; - } - }; - - __host__ __device__ constexpr auto PushBack(TData x) const - { - Array new_array; - - static_for<0, NSize, 1>{}(lambda_PushBack(*this, new_array)); - - new_array(Number{}) = x; - - return new_array; - } -}; - -// Arr: Array -// Picks: Sequence<...> -template -struct ArrayElementPicker -{ - using type = ArrayElementPicker; - using data_type = typename Arr::data_type; - - __host__ __device__ constexpr ArrayElementPicker() = delete; - - __host__ __device__ explicit constexpr ArrayElementPicker(Arr& array) : mArray{array} - { - constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer{}, Number<0>{}); - - static_assert(imax < Arr::Size(), "wrong! 
exceeding # array element"); - } - - __host__ __device__ static constexpr auto Size() { return Picks::Size(); } - - template - __host__ __device__ constexpr const data_type& At(Number) const - { - static_assert(I < Size(), "wrong!"); - - constexpr auto IP = Picks{}[I]; - return mArray[IP]; - } - - template - __host__ __device__ constexpr data_type& At(Number) - { - static_assert(I < Size(), "wrong!"); - - constexpr auto IP = Picks{}[I]; - return mArray(IP); - } - - template - __host__ __device__ constexpr const data_type& operator[](I i) const - { - return At(i); - } - - template - __host__ __device__ constexpr data_type& operator()(I i) - { - return At(i); - } - - template - __host__ __device__ constexpr type& operator=(const T& a) - { static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); return *this; } - - Arr& mArray; }; -template -__host__ __device__ constexpr auto pick_array_element(Arr& a, Picks) +// empty Array +template +struct Array { - return ArrayElementPicker(a); -} + using type = Array; + using data_type = TData; -template -__host__ __device__ constexpr auto to_array(const T& x) -{ - Array y; - - static_for<0, T::Size(), 1>{}([&](auto i) { y.At(i) = x.At(i); }); - - return y; -} - -// TODO: remove this -template -__host__ __device__ constexpr auto sequence2array(Sequence) -{ - return Array{Is...}; -} - -template -__host__ __device__ constexpr auto make_zero_array() -{ - constexpr auto zero_sequence = typename uniform_sequence_gen::type{}; - constexpr auto zero_array = sequence2array(zero_sequence); - return zero_array; -} - -template -__host__ __device__ constexpr auto reorder_array_given_new2old(const Array& old_array, - Sequence /*new2old*/) -{ - static_assert(NSize == sizeof...(IRs), "NSize not consistent"); - - static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); - - return Array{old_array[IRs]...}; -} - -template -struct lambda_reorder_array_given_old2new -{ - const Array& old_array; - Array& new_array; - - __host__ __device__ constexpr lambda_reorder_array_given_old2new( - const Array& old_array_, Array& new_array_) - : old_array(old_array_), new_array(new_array_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - TData old_data = old_array[IOldDim]; - - constexpr index_t INewDim = MapOld2New::At(Number{}); - - new_array(Number{}) = old_data; - } + __host__ __device__ static constexpr index_t Size() { return 0; } }; -template -__host__ __device__ constexpr auto reorder_array_given_old2new(const Array& old_array, - Sequence /*old2new*/) +template +__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { - Array new_array; - - static_assert(NSize == sizeof...(IRs), "NSize not consistent"); - - static_assert(is_valid_sequence_map>::value, "wrong! invalid reorder map"); - - static_for<0, NSize, 1>{}( - lambda_reorder_array_given_old2new>(old_array, new_array)); - - return new_array; + using data_type = remove_cv_t>; + return Array{{std::forward(x), std::forward(xs)...}}; } -template -__host__ __device__ constexpr auto extract_array(const Array& old_array, ExtractSeq) +// make empty array +template +__host__ __device__ constexpr auto make_array() { - Array new_array; - - constexpr index_t new_size = ExtractSeq::GetSize(); - - static_assert(new_size <= NSize, "wrong! 
too many extract"); - - static_for<0, new_size, 1>{}([&](auto I) { new_array(I) = old_array[ExtractSeq::At(I)]; }); - - return new_array; -} - -// emulate constepxr lambda for array -template -struct lambda_array_math -{ - const F& f; - const X& x; - const Y& y; - Z& z; - - __host__ __device__ constexpr lambda_array_math(const F& f_, const X& x_, const Y& y_, Z& z_) - : f(f_), x(x_), y(y_), z(z_) - { - } - - template - __host__ __device__ constexpr void operator()(Number) const - { - constexpr auto IDim = Number{}; - z(IDim) = f(x[IDim], y[IDim]); - } -}; - -// Array = Array + Array -template -__host__ __device__ constexpr auto operator+(Array a, Array b) -{ - Array result; - - auto f = math::plus{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array = Array - Array -template -__host__ __device__ constexpr auto operator-(Array a, Array b) -{ - Array result; - - auto f = math::minus{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array += Array -template -__host__ __device__ constexpr auto operator+=(Array& a, Array b) -{ - a = a + b; - return a; -} - -// Array -= Array -template -__host__ __device__ constexpr auto operator-=(Array& a, Array b) -{ - a = a - b; - return a; -} -// Array = Array + Sequence -template -__host__ __device__ constexpr auto operator+(Array a, Sequence b) -{ - static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); - - Array result; - - auto f = math::plus{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array = Array - Sequence -template -__host__ __device__ constexpr auto operator-(Array a, Sequence b) -{ - static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); - - Array result; - - auto f = math::minus{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array = Array * Sequence -template -__host__ __device__ constexpr auto operator*(Array a, Sequence b) -{ - static_assert(sizeof...(Is) == NSize, "wrong! size not the same"); - - Array result; - - auto f = math::multiplies{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array = Sequence - Array -template -__host__ __device__ constexpr auto operator-(Sequence a, Array b) -{ - static_assert(sizeof...(Is) == NSize, "wrong! 
size not the same"); - - Array result; - - auto f = math::minus{}; - - static_for<0, NSize, 1>{}( - lambda_array_math( - f, a, b, result)); - - return result; -} - -// Array = Array * TData -template -__host__ __device__ constexpr auto operator*(TData v, Array a) -{ - Array result; - - for(index_t i = 0; i < NSize; ++i) - { - result(i) = a[i] * v; - } - - return result; -} - -template -__host__ __device__ constexpr TData -accumulate_on_array(const Array& a, Reduce f, TData init) -{ - TData result = init; - - static_assert(NSize > 0, "wrong"); - - static_for<0, NSize, 1>{}([&](auto I) { result = f(result, a[I]); }); - - return result; + return Array{}; } } // namespace ck diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index dcf0be1674..63f94bd3c2 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -2,21 +2,26 @@ #define CK_COMMON_HEADER_HPP #include "config.hpp" -#include "utility.hpp" -#include "integral_constant.hpp" -#include "number.hpp" -#include "float_type.hpp" -#include "type.hpp" -#include "tuple.hpp" -#include "math.hpp" -#include "sequence.hpp" #include "array.hpp" +#include "container_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" +#include "float_type.hpp" #include "functional.hpp" #include "functional2.hpp" #include "functional3.hpp" #include "functional4.hpp" #include "in_memory_operation.hpp" +#include "integral_constant.hpp" +#include "math.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "sequence_helper.hpp" #include "synchronization.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "type.hpp" +#include "utility.hpp" #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index acb33271a5..8ad6f35dde 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -1,16 +1,43 @@ #ifndef CK_CONFIG_AMD_HPP #define CK_CONFIG_AMD_HPP +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" +#endif #include "bfloat16_dev.hpp" -// index type: unsigned or signed -#define CK_UNSIGNED_INDEX_TYPE 0 - // device backend #define CK_DEVICE_BACKEND_AMD 1 +// GPU ID +#define CK_AMD_GPU_GFX906 1 +#define CK_AMD_GPU_GFX908 0 +#define CK_AMD_GPU_GFX1030 0 + +// HIP version +#ifndef CK_HIP_VERSION_FLAT +#define CK_HIP_VERSION_FLAT 0 +#endif + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 0 + +#ifdef CK_USE_LAUNCH_BOUNDS +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 1 +#endif + +// buffer resourse +#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(CK_AMD_GPU_GFX1030) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#endif + +// multi index +#define CK_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + // AMD inline asm #ifndef CK_USE_AMD_INLINE_ASM #define CK_USE_AMD_INLINE_ASM 1 @@ -20,14 +47,18 @@ #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1 #endif +#ifndef CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_FMAC_F32 1 +#endif + // AMD buffer addressing #ifndef CK_USE_AMD_BUFFER_ADDRESSING #define CK_USE_AMD_BUFFER_ADDRESSING 1 #endif // only gfx908 support native floating point atomic add -#ifndef CK_USE_AMD_BUFFER_ATOMIC_ADD -#define CK_USE_AMD_BUFFER_ATOMIC_ADD 0 
+#ifndef CK_USE_AMD_BUFFER_ATOMIC_FADD +#define CK_USE_AMD_BUFFER_ATOMIC_FADD 0 #endif // AMD XDLOPS @@ -49,8 +80,16 @@ #endif // experimental implementation -#ifndef CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK -#define CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK 1 +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#endif + +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 #endif #ifndef CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE @@ -65,14 +104,33 @@ #define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0 #endif +// pass tensor descriptor by value, pointer or void* +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_POINTER 0 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 + +// hack: have underlying assumption that need to be satsified, otherwise it's a bug +// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#ifndef CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE +#define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 +#endif + // workaround: put all workaround here -// workaround for unnecessary VGPA <--> AGRP data movement when using mfma LLVM intrinsic +// workaround for unnecessary VGPR <--> AGPR data movement when using mfma LLVM intrinsic #ifndef CK_WORKAROUND_SWDEV_229564 #define CK_WORKAROUND_SWDEV_229564 1 #endif -// workaround for buffer load/store fp16/bfp16 intrinsic bug -#ifndef CK_WORKAROUND_SWDEV_231101 -#define CK_WORKAROUND_SWDEV_231101 1 + +// workaround for accvgpr over-allocation +#ifndef CK_WORKAROUND_SWDEV_241664 +#define CK_WORKAROUND_SWDEV_241664 1 +#endif + +// workaround for compiler crash when compiling recursive lambda +#ifndef CK_WORKAROUND_SWDEV_275126 +#define CK_WORKAROUND_SWDEV_275126 1 #endif namespace ck { @@ -91,14 +149,8 @@ enum InMemoryDataOperation AtomicAdd }; -#if CK_UNSIGNED_INDEX_TYPE -using index_t = uint32_t; -#else +// index type using index_t = int32_t; -#endif - -// int32x4_t use by buffer_load and buffer_store llvm intrinsic -typedef int32_t int32x4_t __attribute__((ext_vector_type(4))); } // namespace ck #endif diff --git a/composable_kernel/include/utility/container_element_picker.hpp b/composable_kernel/include/utility/container_element_picker.hpp new file mode 100644 index 0000000000..f71086f6cb --- /dev/null +++ b/composable_kernel/include/utility/container_element_picker.hpp @@ -0,0 +1,153 @@ +#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP +#define CK_CONTAINER_ELEMENT_PICKER_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ContainerElementPicker +{ + using type = ContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ContainerElementPicker() = delete; + + __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} + { + constexpr index_t imax = reduce_on_sequence(Picks{}, 
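// A minimal usage sketch (illustration only) of the picker defined in this
// header: it exposes a chosen subset of a container's elements, addressed by
// their position in the Picks sequence, as a lightweight mutable view; writes
// through the view land in the underlying container. buf is a hypothetical
// Array.
__device__ void double_picked_elements(Array<index_t, 4>& buf)
{
    // view over buf[0] and buf[2]
    auto view = pick_container_element(buf, Sequence<0, 2>{});

    view(Number<0>{}) = 2 * view[Number<0>{}]; // updates buf[0]
    view(Number<1>{}) = 2 * view[Number<1>{}]; // updates buf[2]
}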
math::maxer{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr auto& At(Number i) + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray(IP); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + private: + Arr& mArray; +}; + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ConstantContainerElementPicker +{ + using type = ConstantContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ConstantContainerElementPicker() = delete; + + __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} + { + constexpr index_t imax = reduce_on_sequence(Picks{}, math::maxer{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + private: + const Arr& mArray; +}; + +template +__host__ __device__ constexpr auto operator+=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto operator-=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! 
size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks) +{ + return ContainerElementPicker(a); +} + +template +__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks) +{ + return ConstantContainerElementPicker(a); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp new file mode 100644 index 0000000000..f47a29d058 --- /dev/null +++ b/composable_kernel/include/utility/container_helper.hpp @@ -0,0 +1,375 @@ +#ifndef CK_CONTAINER_HELPER_HPP +#define CK_CONTAINER_HELPER_HPP + +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "array.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto container_push_back(const Array& a, const TData& x) +{ + Array r; + + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + + r(Number{}) = x; + + return r; +} + +template +__host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) +{ + return container_cat(make_tuple(x), a); +} + +template +__host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) +{ + return container_cat(a, make_tuple(x)); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_array(old_array[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const Array& old_array, Sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[Number{}]...); +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return Sequence::At(Number{})...>{}; +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, + Sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_give_new2old(old_seq, new2old); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + Number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, Number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + Number{}, Number{}, Number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); +} +#endif + +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(Number<0>{}) = r; + + return y; +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm4.1 compiler would crash with recursive lambda +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, Number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( + const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) +{ + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +__host__ __device__ 
constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, Number{}, make_tuple(init), init); +} +#endif + +// TODO: update to be like container_reverse_exclusive_scan to deal with Tuple of Number<> +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto container_cat(const X& x, const Ys&... ys) +{ + return container_cat(x, container_cat(ys...)); +} + +template +__host__ __device__ constexpr auto container_cat(const Array& ax, const Array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(std::forward(zs)...); }, ax, ay); +} + +template +__host__ __device__ constexpr auto container_cat(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(std::forward(zs)...); }, tx, ty); +} + +template +__host__ __device__ constexpr auto container_cat(const Container& x) +{ + return x; +} + +template +__host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return make_array(arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + return make_tuple(tup[Number{}]...); +} + +template +__host__ __device__ constexpr void +set_container_subset(Array& y, Sequence picks, const Array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr void +set_container_subset(Tuple& y, Sequence picks, const Tuple& x) +{ + static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! 
size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) +{ + using Seq = Sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::At(i); + return Number{}; + }, + Seq::Size()); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in index 058bfcca02..6c0aedb6d4 100644 --- a/composable_kernel/include/utility/float_type.amd.hpp.in +++ b/composable_kernel/include/utility/float_type.amd.hpp.in @@ -3,263 +3,278 @@ namespace ck { -// For some reason, HIP compiler need this definition to generate optimal ISA -// float -typedef float float2_t __attribute__((ext_vector_type(2))); -typedef float float4_t __attribute__((ext_vector_type(4))); -typedef float float16_t __attribute__((ext_vector_type(16))); -typedef float float32_t __attribute__((ext_vector_type(32))); +template +struct vector_type; -// float16 -typedef _Float16 half_t; -typedef _Float16 half2_t __attribute__((ext_vector_type(2))); -typedef _Float16 half4_t __attribute__((ext_vector_type(4))); -typedef _Float16 half8_t __attribute__((ext_vector_type(8))); - -// bfloat16 -typedef ushort ushort2_t __attribute__((ext_vector_type(2))); -typedef ushort ushort4_t __attribute__((ext_vector_type(4))); -typedef ushort ushort8_t __attribute__((ext_vector_type(8))); - -template -struct vector_type +template +struct vector_type { - typedef struct + using type = T; + + union { - T scalar[N]; - } MemoryType; + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 1; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d1_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d1_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x1_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x1_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x1_; } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 2; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d2_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d2_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t 
__attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 4; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d4_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d4_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } +}; + +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 8; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d8_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d8_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; } }; template <> -struct vector_type +struct vector_type { - using MemoryType = float; + using d1_t = int8_t; + typedef int16_t d2_t; - template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) + using type = d2_t; + + union { - static_assert(I < 1, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 2; 
} + + __host__ __device__ constexpr const auto& Vector() const { return data_.d2_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d2_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x2_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x1_; } }; template <> -struct vector_type +struct vector_type { - using MemoryType = float2_t; + using d1_t = int8_t; + typedef int16_t d2_t; + typedef int32_t d4_t; - union DataType - { - MemoryType vector; - float scalar[2]; - }; + using type = d4_t; - template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) + union { - static_assert(I < 2, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; - __host__ __device__ static MemoryType Pack(float s0, float s1) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - return data.vector; - } + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + __host__ __device__ static constexpr index_t Size() { return 4; } + + __host__ __device__ constexpr const auto& Vector() const { return data_.d4_; } + + __host__ __device__ constexpr auto& Vector() { return data_.d4_; } + + __host__ __device__ constexpr const auto& Scalars() const { return data_.d1x4_; } + + __host__ __device__ constexpr auto& Scalars() { return data_.d1x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x4_; } + + __host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x2_; } + + __host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x1_; } + + __host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x4_; } + + __host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x2_; } + + __host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; } }; -template <> -struct vector_type -{ - using MemoryType = float4_t; +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; - __host__ __device__ static constexpr index_t GetSize() { return 4; } +// fp16 +using half_t = _Float16; +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; - template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) - { - static_assert(I < 4, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } -}; +// bfp16 +using ushort2_t = typename vector_type::type; +using ushort4_t = typename vector_type::type; +using ushort8_t = typename vector_type::type; -template <> -struct vector_type -{ - using MemoryType = half_t; +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; - template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) - { - static_assert(I < 1, "wrong"); - 
*(reinterpret_cast(&v) + I) = s; - } -}; - -template <> -struct vector_type -{ - using MemoryType = half2_t; - - union DataType - { - MemoryType vector; - half_t scalar[2]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) - { - static_assert(I < 2, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } - - __host__ __device__ static MemoryType Pack(half_t s0, half_t s1) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - return data.vector; - } -}; - -template <> -struct vector_type -{ - using MemoryType = half4_t; - - union DataType - { - MemoryType vector; - half_t scalar[4]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) - { - static_assert(I < 4, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } - - __host__ __device__ static MemoryType Pack(half_t s0, half_t s1, half_t s2, half_t s3) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - data.scalar[2] = s2; - data.scalar[3] = s3; - return data.vector; - } -}; - -template <> -struct vector_type -{ - using MemoryType = half8_t; - - union DataType - { - MemoryType vector; - half_t scalar[8]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) - { - static_assert(I < 8, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } -}; - -template <> -struct vector_type -{ - using MemoryType = ushort; - - template - __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number) - { - static_assert(I < 1, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } -}; - -template <> -struct vector_type -{ - using MemoryType = ushort2_t; - - union DataType - { - MemoryType vector; - ushort scalar[2]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number) - { - static_assert(I < 2, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } - - __host__ __device__ static MemoryType Pack(ushort s0, ushort s1) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - return data.vector; - } -}; - -template <> -struct vector_type -{ - using MemoryType = ushort4_t; - - union DataType - { - MemoryType vector; - ushort scalar[4]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number) - { - static_assert(I < 4, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } - - __host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3) - { - DataType data; - data.scalar[0] = s0; - data.scalar[1] = s1; - data.scalar[2] = s2; - data.scalar[3] = s3; - return data.vector; - } -}; - -template <> -struct vector_type -{ - using MemoryType = ushort8_t; - - union DataType - { - MemoryType vector; - ushort scalar[8]; - }; - - template - __host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number) - { - static_assert(I < 8, "wrong"); - *(reinterpret_cast(&v) + I) = s; - } -}; +// i8 +// hack for int8x4_t, because compiler does not have native support for int8x4_t +// int8x4_t is defined as int32_t +using int8x4_t = typename vector_type::type; // data type conversion template @@ -291,113 +306,37 @@ struct inner_product_with_conversion { static constexpr auto convert = type_convert(); - __device__ T operator()(float4_t a, float4_t b) const + template + __device__ T operator()(typename vector_type::type a, + typename vector_type::type b) const { - const float* p_a_float = reinterpret_cast(&a); - const float* p_b_float = reinterpret_cast(&b); + const vector_type a_vector{a}; + const vector_type b_vector{b}; T 
acc = 0; - for(index_t v = 0; v < 4; ++v) - { - acc += convert(p_a_float[v]) * convert(p_b_float[v]); - } + + static_for<0, N, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); return acc; } - __device__ T operator()(float2_t a, float2_t b) const + __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); } + + // hack for int8x4_t, because compiler does not have native support for int8x4_t + // int8x4_t is defined as int32_t + __device__ T operator()(int8x4_t a, int8x4_t b) const { - const float* p_a_float = reinterpret_cast(&a); - const float* p_b_float = reinterpret_cast(&b); + const vector_type a_vector{a}; + const vector_type b_vector{b}; T acc = 0; - for(index_t v = 0; v < 2; ++v) - { - acc += convert(p_a_float[v]) * convert(p_b_float[v]); - } - return acc; - } + static_for<0, 4, 1>{}([&](auto i) { + acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]); + }); - __device__ T operator()(float a, float b) const { return convert(a) * convert(b); } - - __device__ T operator()(half2_t a, half2_t b) const - { - const half_t* p_a_half = reinterpret_cast(&a); - const half_t* p_b_half = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 2; ++v) - { - acc += convert(p_a_half[v]) * convert(p_b_half[v]); - } - - return acc; - } - - __device__ T operator()(half4_t a, half4_t b) const - { - const half_t* p_a_half = reinterpret_cast(&a); - const half_t* p_b_half = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 4; ++v) - { - acc += convert(p_a_half[v]) * convert(p_b_half[v]); - } - return acc; - } - - __device__ T operator()(half8_t a, half8_t b) const - { - const half_t* p_a_half = reinterpret_cast(&a); - const half_t* p_b_half = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 8; ++v) - { - acc += convert(p_a_half[v]) * convert(p_b_half[v]); - } - return acc; - } - - __device__ T operator()(ushort2_t a, ushort2_t b) const - { - const ushort* p_a_bfloat16 = reinterpret_cast(&a); - const ushort* p_b_bfloat16 = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 2; ++v) - { - acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]); - } - - return acc; - } - - __device__ T operator()(ushort4_t a, ushort4_t b) const - { - const ushort* p_a_bfloat16 = reinterpret_cast(&a); - const ushort* p_b_bfloat16 = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 4; ++v) - { - acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]); - } - return acc; - } - - __device__ T operator()(ushort8_t a, ushort8_t b) const - { - const ushort* p_a_bfloat16 = reinterpret_cast(&a); - const ushort* p_b_bfloat16 = reinterpret_cast(&b); - - T acc = 0; - for(index_t v = 0; v < 8; ++v) - { - acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]); - } return acc; } }; diff --git a/composable_kernel/include/utility/float_type.nvidia.hpp.in b/composable_kernel/include/utility/float_type.nvidia.hpp.in index f4a0a47c67..82b147483a 100644 --- a/composable_kernel/include/utility/float_type.nvidia.hpp.in +++ b/composable_kernel/include/utility/float_type.nvidia.hpp.in @@ -32,16 +32,16 @@ struct vector_type typedef struct { T scalar[N]; - } MemoryType; + } type; }; template <> struct vector_type { - using MemoryType = float; + using type = float; template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) + __host__ __device__ static void SetScalar(type& v, float s, Number) { static_assert(I < 1, "wrong"); *(reinterpret_cast(&v) + I) = s; @@ -51,22 
+51,22 @@ struct vector_type template <> struct vector_type { - using MemoryType = float2_t; + using type = float2_t; union DataType { - MemoryType vector; + type vector; float scalar[2]; }; template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) + __host__ __device__ static void SetScalar(type& v, float s, Number) { static_assert(I < 2, "wrong"); *(reinterpret_cast(&v) + I) = s; } - __host__ __device__ static MemoryType Pack(float s0, float s1) + __host__ __device__ static type Pack(float s0, float s1) { DataType data; data.scalar[0] = s0; @@ -78,12 +78,12 @@ struct vector_type template <> struct vector_type { - using MemoryType = float4_t; + using type = float4_t; __host__ __device__ static constexpr index_t GetSize() { return 4; } template - __host__ __device__ static void SetScalar(MemoryType& v, float s, Number) + __host__ __device__ static void SetScalar(type& v, float s, Number) { static_assert(I < 4, "wrong"); *(reinterpret_cast(&v) + I) = s; @@ -93,10 +93,10 @@ struct vector_type template <> struct vector_type { - using MemoryType = half_t; + using type = half_t; template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) + __host__ __device__ static void SetScalar(type& v, half_t s, Number) { static_assert(I < 1, "wrong"); *(reinterpret_cast(&v) + I) = s; @@ -106,22 +106,22 @@ struct vector_type template <> struct vector_type { - using MemoryType = half2_t; + using type = half2_t; union DataType { - MemoryType vector; + type vector; half_t scalar[2]; }; template - __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number) + __host__ __device__ static void SetScalar(type& v, half_t s, Number) { static_assert(I < 2, "wrong"); *(reinterpret_cast(&v) + I) = s; } - __host__ __device__ static MemoryType Pack(half_t s0, half_t s1) + __host__ __device__ static type Pack(half_t s0, half_t s1) { DataType data; data.scalar[0] = s0; diff --git a/composable_kernel/include/utility/functional.hpp b/composable_kernel/include/utility/functional.hpp index 479f41a775..b84b617f44 100644 --- a/composable_kernel/include/utility/functional.hpp +++ b/composable_kernel/include/utility/functional.hpp @@ -2,7 +2,6 @@ #define CK_FUNCTIONAL_HPP #include "integral_constant.hpp" -#include "sequence.hpp" #include "type.hpp" namespace ck { @@ -56,8 +55,10 @@ struct static_if __host__ __device__ constexpr auto operator()(F f) const { // This is a trick for compiler: - // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will use it, - // this will make "f" a generic lambda, so that "f" won't be compiled until being + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being // instantiated here f(forwarder{}); return Type{}; @@ -84,8 +85,10 @@ struct static_if __host__ __device__ static void Else(F f) { // This is a trick for compiler: - // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will use it, - // this will make "f" a generic lambda, so that "f" won't be compiled until being + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being // instantiated here f(forwarder{}); } diff --git a/composable_kernel/include/utility/functional2.hpp b/composable_kernel/include/utility/functional2.hpp index ed0ce1ce0e..371182a05e 100644 --- 
a/composable_kernel/include/utility/functional2.hpp +++ b/composable_kernel/include/utility/functional2.hpp @@ -32,7 +32,8 @@ struct static_for static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), - "wrongs! should have NBegin <= NEnd"); + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); } template diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp index 48a0933793..6a400f3ca6 100644 --- a/composable_kernel/include/utility/functional3.hpp +++ b/composable_kernel/include/utility/functional3.hpp @@ -4,7 +4,7 @@ #include "functional.hpp" #include "functional2.hpp" #include "sequence.hpp" -#include "array.hpp" +#include "multi_index.hpp" namespace ck { @@ -63,7 +63,7 @@ struct ford_impl for(index_t i = 0; i < RemainLengths::Front(); ++i) { ford_impl{}( - f, current_ordered_id.PushBack(i)); + f, container_push_back(current_ordered_id, i)); } } }; @@ -77,14 +77,16 @@ struct ford_impl, Orders> __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const { // retrive unordered Id - f(reorder_array_given_old2new(current_ordered_id, Orders{})); + f(container_reorder_given_old2new(current_ordered_id, Orders{})); } }; } // namespace detail -// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop -// Orders is Sequence<...>, it is the order of dimension in which static_ford will loop over each +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford +// will loop over each // dimension template ::type> @@ -106,8 +108,10 @@ struct static_ford } }; -// Lengths is Sequence<...>, it is the length of each dimension for N-dimensional loop -// Orders is Sequence<...>, it is the order of dimension in which ford will loop over each +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which ford will loop +// over each // dimension template ::type> @@ -129,7 +133,7 @@ struct ford for(index_t i = 0; i < ordered_lengths.Front(); ++i) { detail::ford_impl{}(f, - Array{i}); + make_multi_index(i)); } } }; diff --git a/composable_kernel/include/utility/functional4.hpp b/composable_kernel/include/utility/functional4.hpp index 70475ced4a..b039644380 100644 --- a/composable_kernel/include/utility/functional4.hpp +++ b/composable_kernel/include/utility/functional4.hpp @@ -16,18 +16,46 @@ template struct unpack_impl> { template - __host__ __device__ constexpr auto operator()(F f, const X& x) const + __host__ __device__ constexpr auto operator()(F&& f, X&& x) const { - return f(x.At(Number{})...); + return std::forward(f)(std::forward(x).At(Number{})...); + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, Sequence> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const + { + return std::forward(f)(std::forward(x).At(Number{})..., + std::forward(y).At(Number{})...); } }; } // namespace detail template -__host__ __device__ constexpr auto unpack(F f, const X& x) +__host__ __device__ constexpr auto unpack(F&& f, X&& x) { - return 
detail::unpack_impl::type>{}(f, x); + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + std::forward(f), std::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( + std::forward(f), std::forward(x), std::forward(y)); } } // namespace ck diff --git a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in index 83ecae161c..97ea488a63 100644 --- a/composable_kernel/include/utility/in_memory_operation.amd.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.amd.hpp.in @@ -5,6 +5,7 @@ #if CK_USE_AMD_BUFFER_ADDRESSING #include "amd_buffer_addressing.hpp" +#include "amd_buffer_addressing_v2.hpp" #endif namespace ck { @@ -43,7 +44,7 @@ __device__ void atomic_add_impl(float4_t* p_dst, float4_t src) template struct SetData { - using vector_t = typename vector_type::MemoryType; + using vector_t = typename vector_type::type; // This version is only for compatibility, don't use this version if possible template @@ -60,8 +61,13 @@ struct SetData { if(src_valid) { +#if 0 *reinterpret_cast(&p_dst[dst_offset]) = *reinterpret_cast(&p_src[src_offset]); +#else + *reinterpret_cast(&p_dst[dst_offset]) = + *reinterpret_cast(&p_src[0x3fffffff & src_offset]); +#endif } else { @@ -88,7 +94,7 @@ struct SetData if(dst_valid) { *reinterpret_cast(&p_dst[dst_offset]) = - amd_buffer_load(p_src, src_offset, src_valid, src_range); + amd_buffer_load_v2(p_src, src_offset, src_valid, src_range); } } @@ -108,12 +114,12 @@ struct SetData { const auto zeros = vector_t(0); - amd_buffer_store(src_valid ? &(p_src[src_offset]) - : reinterpret_cast(&zeros), - p_dst, - dst_offset, - dst_valid, - dst_range); + amd_buffer_store_v2( + src_valid ? *reinterpret_cast(&(p_src[src_offset])) : zeros, + p_dst, + dst_offset, + dst_valid, + dst_range); } #endif }; @@ -121,7 +127,7 @@ struct SetData template struct AtomicAddData { - using vector_t = typename vector_type::MemoryType; + using vector_t = typename vector_type::type; // This version is only for compatibility, don't use this version if possible template @@ -141,7 +147,7 @@ struct AtomicAddData } } -#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_ADD +#if CK_USE_AMD_BUFFER_ADDRESSING && CK_USE_AMD_BUFFER_ATOMIC_FADD // buffer_atomic requires: // 1) p_src_thread must be in vgpr space, p_dst_thread must be global memory // 2) p_dst_thread to be a wavewise pointer. @@ -185,25 +191,26 @@ __device__ void transfer_data(const T* p_src, "wrong! 
InMemoryDataOperation not supported!"); // keep it simple, don't use static_if here, otherwise compiler will do weird things - if(SrcDataStride == 1 && DstDataStride == 1) + if constexpr(SrcDataStride == 1 && DstDataStride == 1) { - // TODO: use static_if::ElseIf - static_if{}([&](auto) { + if constexpr(DstInMemOp == InMemoryDataOperation::Set) + { SetData{}.template Run( p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range); - }); - - static_if{}([&](auto) { + } + else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd) + { AtomicAddData{}.template Run( p_src, src_offset, src_valid, src_range, p_dst, dst_offset, dst_valid, dst_range); - }); + } } else { +#pragma unroll for(index_t i = 0; i < DataPerAccess; ++i) { - // TODO: use static_if::ElseIf - static_if{}([&](auto) { + if constexpr(DstInMemOp == InMemoryDataOperation::Set) + { SetData{}.template Run( p_src, src_offset + i * SrcDataStride, @@ -213,9 +220,9 @@ __device__ void transfer_data(const T* p_src, dst_offset + i * DstDataStride, dst_valid, dst_range); - }); - - static_if{}([&](auto) { + } + else if constexpr(DstInMemOp == InMemoryDataOperation::AtomicAdd) + { AtomicAddData{}.template Run( p_src, src_offset + i * SrcDataStride, @@ -225,7 +232,7 @@ __device__ void transfer_data(const T* p_src, dst_offset + i * DstDataStride, dst_valid, dst_range); - }); + } } } } diff --git a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in index 0e2c7e9603..2778321035 100644 --- a/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in +++ b/composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in @@ -37,7 +37,7 @@ __device__ void atomic_add_impl(float4_t* p_dst, float4_t src) template struct SetData { - using vector_t = typename vector_type::MemoryType; + using vector_t = typename vector_type::type; template __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const @@ -50,7 +50,7 @@ struct SetData template struct AtomicAddData { - using vector_t = typename vector_type::MemoryType; + using vector_t = typename vector_type::type; template __device__ void Run(const T* p_src, index_t src_offset, T* p_dst, index_t dst_offset) const diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 7ce60d74cf..5738030732 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -33,6 +33,15 @@ struct multiplies __host__ __device__ constexpr T operator()(T a, T b) const { return a * b; } }; +struct multiplies_v2 +{ + template + __host__ __device__ constexpr auto operator()(const A& a, const B& b) const + { + return a * b; + } +}; + template struct maxer { @@ -105,8 +114,7 @@ __host__ __device__ constexpr T min(T x, Ts... xs) } // greatest common divisor, aka highest common factor -template -__host__ __device__ constexpr T gcd(T x, T y) +__host__ __device__ constexpr index_t gcd(index_t x, index_t y) { if(x == y || x == 0) { @@ -129,24 +137,29 @@ __host__ __device__ constexpr T gcd(T x, T y) template __host__ __device__ constexpr auto gcd(Number, Number) { - constexpr auto result = gcd(X, Y); - return Number{}; + constexpr auto r = gcd(X, Y); + + return Number{}; } -template +template = 2, bool>::type = false> __host__ __device__ constexpr auto gcd(X x, Ys... 
ys) { return gcd(x, ys...); } // least common multiple -template -__host__ __device__ constexpr T lcm(T x, T y) +template +__host__ __device__ constexpr auto lcm(X x, Y y) { return (x * y) / gcd(x, y); } -template +template = 2, bool>::type = false> __host__ __device__ constexpr auto lcm(X x, Ys... ys) { return lcm(x, lcm(ys...)); @@ -165,6 +178,6 @@ struct less }; } // namespace math -} // namspace ck +} // namespace ck #endif diff --git a/composable_kernel/include/utility/print.hpp b/composable_kernel/include/utility/print.hpp new file mode 100644 index 0000000000..0dd646153a --- /dev/null +++ b/composable_kernel/include/utility/print.hpp @@ -0,0 +1,70 @@ +#ifndef CK_PRINT_HPP +#define CK_PRINT_HPP + +#include "array.hpp" +#include "statically_indexed_array.hpp" +#include "container_helper.hpp" +#include "sequence.hpp" + +namespace ck { + +template +__host__ __device__ void print_array(const char* s, T a) +{ + using data_type = decltype(a.At(Number<0>{})); + constexpr index_t nsize = a.Size(); + +#if 0 + if constexpr(is_same{}) + { + printf("%s size %u, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%u, ", uint32_t{a[i]}); }); + printf("}\n"); + } + else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); + } + else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", bool{a[i]}); }); + printf("}\n"); + } +#else + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("%d, ", int32_t{a[i]}); }); + printf("}\n"); +#endif +} + +template +__host__ __device__ void print_array_v2(const char* s, T a) +{ + using data_type = decltype(a.At(Number<0>{})); + constexpr index_t nsize = a.Size(); + +#if 0 + if constexpr(is_same{}) + { + printf("%s size %u, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%u] %u, ", i.value, a[i]); }); + printf("}\n"); + } + else if constexpr(is_same{}) + { + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); + printf("}\n"); + } +#else + printf("%s size %d, {", s, nsize); + static_for<0, nsize, 1>{}([&a](auto i) constexpr { printf("[%d] %d, ", i.value, a[i]); }); + printf("}\n"); +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/print_array.hpp b/composable_kernel/include/utility/print_array.hpp deleted file mode 100644 index b53bbb90f3..0000000000 --- a/composable_kernel/include/utility/print_array.hpp +++ /dev/null @@ -1,177 +0,0 @@ -#ifndef CK_PRINT_ARRAY_HPP -#define CK_PRINT_ARRAY_HPP - -#include "array.hpp" - -namespace ck { - -template -__host__ __device__ void print_array(const char* s, Array a) -{ - constexpr index_t nsize = a.GetSize(); - - static_assert(nsize > 0 && nsize <= 10, "wrong!"); - - static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, a[0]); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, a[0], a[1]); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, a[0], a[1], a[2]); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3]); }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, 
a[0], a[1], a[2], a[3], a[4], a[5]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8]); - }); - - static_if{}([&](auto) { - printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8], - a[9]); - }); -} - -template -__host__ __device__ void print_array(const char* s, Array a) -{ - constexpr index_t nsize = a.GetSize(); - - static_assert(nsize > 0 && nsize <= 10, "wrong!"); - - static_if{}([&](auto) { printf("%s size %d, {%d}\n", s, nsize, a[0]); }); - - static_if{}([&](auto) { printf("%s size %d, {%d %d}\n", s, nsize, a[0], a[1]); }); - - static_if{}( - [&](auto) { printf("%s size %d, {%d %d %d}\n", s, nsize, a[0], a[1], a[2]); }); - - static_if{}( - [&](auto) { printf("%s size %d, {%d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3]); }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4]); - }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d %d}\n", s, nsize, a[0], a[1], a[2], a[3], a[4], a[5]); - }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d %d %d}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6]); - }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d %d %d %d}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7]); - }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d %d %d %d %d}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8]); - }); - - static_if{}([&](auto) { - printf("%s size %d, {%d %d %d %d %d %d %d %d %d %d}\n", - s, - nsize, - a[0], - a[1], - a[2], - a[3], - a[4], - a[5], - a[6], - a[7], - a[8], - a[9]); - }); -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/utility/print_sequence.hpp b/composable_kernel/include/utility/print_sequence.hpp deleted file mode 100644 index 463f9d097d..0000000000 --- a/composable_kernel/include/utility/print_sequence.hpp +++ /dev/null @@ -1,46 +0,0 @@ -#ifndef CK_PRINT_SEQUENCE_HPP -#define CK_PRINT_SEQUENCE_HPP - -#include "sequence.hpp" - -namespace ck { - -template -__host__ __device__ void print_sequence(const char* s, Sequence) -{ - constexpr index_t nsize = Sequence::Size(); - - static_assert(nsize <= 10, "wrong!"); - - static_if{}([&](auto) { printf("%s size %u, {}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u %u}\n", s, nsize, Xs...); }); - - static_if{}([&](auto) { printf("%s size %u, {%u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u 
%u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); - - static_if{}( - [&](auto) { printf("%s size %u, {%u %u %u %u %u %u %u %u %u %u}\n", s, nsize, Xs...); }); -} - -} // namespace ck - -#endif diff --git a/composable_kernel/include/utility/sequence.hpp b/composable_kernel/include/utility/sequence.hpp index ac4cf5eb5e..81eb488715 100644 --- a/composable_kernel/include/utility/sequence.hpp +++ b/composable_kernel/include/utility/sequence.hpp @@ -168,6 +168,14 @@ struct Sequence { return Sequence{}; } + + __host__ __device__ static void Print() + { + printf("{"); + printf("size %d, ", index_t{Size()}); + static_for<0, Size(), 1>{}([&](auto i) { printf("%d ", At(i).value); }); + printf("}"); + } }; // merge sequence @@ -750,6 +758,13 @@ __host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce, return typename sequence_reverse_inclusive_scan::type{}; } +template +__host__ __device__ constexpr auto reverse_exclusive_scan_sequence(Seq, Reduce, Number) +{ + return reverse_inclusive_scan_sequence(Seq::PopFront(), Reduce{}, Number{}) + .PushBack(Number{}); +} + template __host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce, Number) { diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp new file mode 100644 index 0000000000..d0829c8c35 --- /dev/null +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -0,0 +1,15 @@ +#ifndef CK_SEQUENCE_HELPER_HPP +#define CK_SEQUENCE_HELPER_HPP + +#include "sequence_helper.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto generate_sequence(F, Number) +{ + return typename sequence_gen::type{}; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/statically_indexed_array.hpp b/composable_kernel/include/utility/statically_indexed_array.hpp new file mode 100644 index 0000000000..f30a3a9ee6 --- /dev/null +++ b/composable_kernel/include/utility/statically_indexed_array.hpp @@ -0,0 +1,40 @@ +#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP +#define CK_STATICALLY_INDEXED_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" +#include "tuple.hpp" + +namespace ck { + +namespace detail { + +template +__host__ __device__ constexpr auto generate_same_type_tuple() +{ + return generate_tuple([](auto) -> T { return T{}; }, Number{}); +} + +template +using same_type_tuple = decltype(generate_same_type_tuple()); + +} // namespace detail + +template +using StaticallyIndexedArray = detail::same_type_tuple; + +template +__host__ __device__ constexpr auto make_statically_indexed_array(const X& x, const Xs&... 
xs) +{ + return StaticallyIndexedArray(x, static_cast(xs)...); +} + +// make empty StaticallyIndexedArray +template +__host__ __device__ constexpr auto make_statically_indexed_array() +{ + return StaticallyIndexedArray(); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index 665db3ff31..f8b8bb62d4 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -2,8 +2,8 @@ #define CK_TUPLE_HPP #include "integral_constant.hpp" -#include "type.hpp" #include "sequence.hpp" +#include "type.hpp" namespace ck { @@ -12,15 +12,19 @@ namespace detail { template struct TupleElementKey { + __host__ __device__ constexpr TupleElementKey() = default; }; template struct TupleElement { - __host__ __device__ explicit constexpr TupleElement() : mData() {} + __host__ __device__ constexpr TupleElement() = default; - template - __host__ __device__ explicit constexpr TupleElement(T&& v) : mData(static_cast(v)) + template < + typename T, + typename std::enable_if>, TupleElement>::value, + bool>::type = false> + __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward(v)) { } @@ -30,7 +34,7 @@ struct TupleElement template __host__ __device__ constexpr const Data& get_tuple_element(const TupleElement& x) { - return x.mData; + return static_cast(x.mData); } template @@ -39,6 +43,7 @@ __host__ __device__ constexpr Data& get_tuple_element(TupleElement& x return x.mData; } +// TODO: not sure the use of reference is correct template __host__ __device__ constexpr Data&& get_tuple_element(TupleElement&& x) { @@ -51,14 +56,24 @@ struct TupleImpl; template struct TupleImpl, Xs...> : TupleElement, Xs>... { - __host__ __device__ explicit constexpr TupleImpl() : TupleElement, Xs>()... + __host__ __device__ constexpr TupleImpl() = default; + + template < + typename Y, + typename std::enable_if>, TupleImpl>::value, + bool>::type = false> + __host__ __device__ constexpr TupleImpl(Y&& y) + : TupleElement, Xs>(std::forward(y))... { } - template - __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys) - : TupleElement, Xs>(static_cast(ys))... + template = 2, bool>::type = false> + __host__ __device__ constexpr TupleImpl(Ys&&... ys) + : TupleElement, Xs>(std::forward(ys))... { + static_assert(sizeof...(Is) == sizeof...(Xs) && sizeof...(Is) == sizeof...(Ys), + "wrong! inconsistent size"); } __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } @@ -84,11 +99,25 @@ struct Tuple : detail::TupleImpl::type, Xs...>; - template - __host__ __device__ explicit constexpr Tuple(Ys&&... ys) : base(static_cast(ys)...) + __host__ __device__ constexpr Tuple() = default; + + template >, Tuple>::value, + bool>::type = false> + __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward(y)) { } + template = 2, + bool>::type = false> + __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward(ys)...) + { + } + + __host__ __device__ static constexpr index_t Size() { return sizeof...(Xs); } + template __host__ __device__ constexpr const auto& At(Number) const { @@ -102,6 +131,28 @@ struct Tuple : detail::TupleImpl{}); } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! 
size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } }; template @@ -110,50 +161,5 @@ __host__ __device__ constexpr auto make_tuple(Xs&&... xs) return Tuple>...>(std::forward(xs)...); } -namespace detail { - -template -__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) -{ - return make_tuple(f(x.At(Number{}))...); -} - -template -__host__ __device__ constexpr auto -transform_tuples_impl(F f, const X& x, const Y& y, Sequence) -{ - return make_tuple(f(x.At(Number{}), y.At(Number{}))...); -} - -template -__host__ __device__ constexpr auto -transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) -{ - return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); -} - -} // namespace detail - -template -__host__ __device__ constexpr auto transform_tuples(F f, const X& x) -{ - return detail::transform_tuples_impl( - f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); -} - -template -__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) -{ - return detail::transform_tuples_impl( - f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); -} - -template -__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) -{ - return detail::transform_tuples_impl( - f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); -} - } // namespace ck #endif diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp new file mode 100644 index 0000000000..9499a3596c --- /dev/null +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -0,0 +1,80 @@ +#ifndef CK_TUPLE_HELPER_HPP +#define CK_TUPLE_HELPER_HPP + +#include "functional4.hpp" +#include "tuple.hpp" + +namespace ck { + +template +struct is_known_at_compile_time> +{ + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return container_reduce( + Tuple{}, + [](auto x, bool r) { + return is_known_at_compile_time< + remove_cv_t>>::value & + r; + }, + true); + } + + static constexpr bool value = IsKnownAtCompileTime(); +}; + +template +__host__ __device__ constexpr auto generate_tuple(F&& f, Number) +{ + return unpack([&f](auto&&... 
xs) { return make_tuple(f(xs)...); }, + typename arithmetic_sequence_gen<0, N, 1>::type{}); +} + +namespace detail { + +template +__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence) +{ + return make_tuple(f(x.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}))...); +} + +template +__host__ __device__ constexpr auto +transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence) +{ + return make_tuple(f(x.At(Number{}), y.At(Number{}), z.At(Number{}))...); +} + +} // namespace detail + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x) +{ + return detail::transform_tuples_impl( + f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y) +{ + return detail::transform_tuples_impl( + f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +template +__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z) +{ + return detail::transform_tuples_impl( + f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{}); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index ac6d306d7f..b137168a1f 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -5,9 +5,6 @@ namespace ck { -template -struct Sequence; - template struct is_same : public integral_constant { @@ -18,26 +15,32 @@ struct is_same : public integral_constant { }; -template -struct is_static : integral_constant -{ -}; - -template -struct is_static> : integral_constant -{ -}; - -template -struct is_static> : integral_constant -{ -}; - template using remove_reference_t = typename std::remove_reference::type; template using remove_cv_t = typename std::remove_cv::type; +template +constexpr std::remove_reference_t&& move(T&& t) noexcept +{ + return static_cast::type&&>(t); +} + +template +struct is_known_at_compile_time; + +template <> +struct is_known_at_compile_time +{ + static constexpr bool value = false; +}; + +template +struct is_known_at_compile_time> +{ + static constexpr bool value = true; +}; + } // namespace ck #endif diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp index 4cb2daa7c8..9f34e044b7 100644 --- a/composable_kernel/include/utility/utility.hpp +++ b/composable_kernel/include/utility/utility.hpp @@ -9,6 +9,6 @@ __device__ index_t get_thread_local_1d_id() { return threadIdx.x; } __device__ index_t get_block_1d_id() { return blockIdx.x; } -} // namspace ck +} // namespace ck #endif diff --git a/driver/include/conv_common.hpp b/driver/include/conv_common.hpp index 2c09622e5e..c4020928f3 100644 --- a/driver/include/conv_common.hpp +++ b/driver/include/conv_common.hpp @@ -51,26 +51,24 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor( } template -constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc) +constexpr std::size_t +calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc) { using namespace ck; - constexpr auto wei_desc = WeiDesc{}; - constexpr auto out_desc = OutDesc{}; - constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr 
auto I3 = Number<3>{}; - constexpr index_t N = out_desc.GetLength(I0); - constexpr index_t K = out_desc.GetLength(I1); - constexpr index_t Ho = out_desc.GetLength(I2); - constexpr index_t Wo = out_desc.GetLength(I3); + const index_t N = out_desc.GetLength(I0); + const index_t K = out_desc.GetLength(I1); + const index_t Ho = out_desc.GetLength(I2); + const index_t Wo = out_desc.GetLength(I3); - constexpr index_t C = wei_desc.GetLength(I1); - constexpr index_t Y = wei_desc.GetLength(I2); - constexpr index_t X = wei_desc.GetLength(I3); + const index_t C = wei_desc.GetLength(I1); + const index_t Y = wei_desc.GetLength(I2); + const index_t X = wei_desc.GetLength(I3); return std::size_t(2) * N * K * Ho * Wo * C * Y * X; } diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp index 7357563eb5..1b8e70878a 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp @@ -183,7 +183,7 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i GemmBBlockCopyDstDataPerWrite_GemmN, GemmCThreadCopyDstDataPerWrite_GemmN1>; - for(index_t i = 0; i < 5; ++i) + for(index_t i = 0; i < 1; ++i) { std::cout << "Start running " << nrepeat << " times..." << std::endl; diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 032fd375b6..b4f421131c 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -57,10 +57,41 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data()); out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); -#if 0 +#if 1 // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmThreadGemmDataPerReadM = 4; + constexpr index_t GemmThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x8 + // GemmABlockCopySrcDataPerRead_GemmM = 4 + constexpr index_t BlockSize = 256; + constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 8; @@ -74,11 +105,11 @@ 
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>; - constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; @@ -104,11 +135,11 @@ void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc i constexpr index_t GemmThreadGemmDataPerReadM = 4; constexpr index_t GemmThreadGemmDataPerReadN = 4; - using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>; - using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>; constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1; - constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; diff --git a/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp index 789ebc4b9d..b534215637 100644 --- a/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp @@ -222,7 +222,7 @@ void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc i static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) { constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id); - constexpr index_t gemm_k2 = gemm_sizes.At(4); + constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}]; constexpr bool is_gemm_not_empty = gemm_k2 > 0; // only compile and run if GEMM is no empty diff --git a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp similarity index 93% rename from driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp rename to driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index 080aa2006f..04eec6b9da 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -3,7 +3,7 @@ #include "device.hpp" #include "host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" +#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp" template -void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - 
ConvStrides, - ConvDilations, - LeftPads, - RightPads, - ck::index_t nrepeat) +void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + ConvStrides, + ConvDilations, + LeftPads, + RightPads, + ck::index_t nrepeat) { + std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl; + using namespace ck; using TDevice = typename conditional::value, half_t, T>::type; @@ -133,7 +135,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, constexpr index_t WeiBlockCopySrcDataPerRead_E = 2; constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1; -#elif 0 +#elif 1 // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; @@ -770,45 +772,46 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< - GridSize, - BlockSize, - T, - T, - decltype(in_nchw_desc), - decltype(wei_kcyx_desc), - decltype(out_nkhw_desc), - ConvStrides, - ConvDilations, - LeftPads, - RightPads, - BPerBlock, - KPerBlock, - EPerBlock, - GemmNRepeat, - GemmMPerThread, - GemmNPerThread, - GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmDataPerReadA, - GemmDataPerReadB, - InBlockCopySubLengths_E_N1_B_N2, - InBlockCopyClusterLengths_E_N1_B_N2, - InBlockCopyThreadClusterArrangeOrder, - InBlockCopySrcAccessOrder, - InBlockCopyDstAccessOrder, - InBlockCopySrcDataPerRead_B, - InBlockCopyDstDataPerWrite_N2, - WeiBlockCopySubLengths_E_K, - WeiBlockCopyClusterLengths_E_K, - WeiBlockCopyThreadClusterArrangeOrder, - WeiBlockCopySrcAccessOrder, - WeiBlockCopyDstAccessOrder, - WeiBlockCopySrcDataPerRead_E, - WeiBlockCopyDstDataPerWrite_K>; + using gridwise_conv = + GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer< + GridSize, + BlockSize, + T, + T, + decltype(in_nchw_desc), + decltype(wei_kcyx_desc), + decltype(out_nkhw_desc), + ConvStrides, + ConvDilations, + LeftPads, + RightPads, + BPerBlock, + KPerBlock, + EPerBlock, + GemmNRepeat, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmDataPerReadA, + GemmDataPerReadB, + InBlockCopySubLengths_E_N1_B_N2, + InBlockCopyClusterLengths_E_N1_B_N2, + InBlockCopyThreadClusterArrangeOrder, + InBlockCopySrcAccessOrder, + InBlockCopyDstAccessOrder, + InBlockCopySrcDataPerRead_B, + InBlockCopyDstDataPerWrite_N2, + WeiBlockCopySubLengths_E_K, + WeiBlockCopyClusterLengths_E_K, + WeiBlockCopyThreadClusterArrangeOrder, + WeiBlockCopySrcAccessOrder, + WeiBlockCopyDstAccessOrder, + WeiBlockCopySrcDataPerRead_E, + WeiBlockCopyDstDataPerWrite_K>; for(index_t i = 0; i < 5; ++i) { diff --git a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp similarity index 80% rename from driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp rename to driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 4c887e9322..f1c0eebde7 100644 --- a/driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -2,7 +2,7 @@ #include "device.hpp" #include 
"host_tensor.hpp" #include "gridwise_operation_wrapper.hpp" -#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "gridwise_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" template -void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, - const Tensor& in_nchw, - WeiDesc, - const Tensor& wei_kcyx, - OutDesc, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, - ck::index_t nrepeat) +void device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + ck::index_t nrepeat) { + std::cout << "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" << std::endl; + using namespace ck; using TDevice = typename conditional::value, half_t, T>::type; @@ -55,6 +57,109 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, out_nkhw_device_buf.ToDevice(out_nkhw.mData.data()); #if 0 + // cdata = 16, BlockSize = 64, 16x64x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2; +#elif 0 + // cdata = 16, BlockSize = 64, 16x64x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; +#elif 0 + // cdata = 32, BlockSize = 64, 16x128x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#elif 0 // cdata = 64, BlockSize = 256, 64x256x8 constexpr index_t BlockSize = 256; @@ -62,14 +167,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 16; constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -86,6 +191,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x2 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = 
Sequence<1, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 // cdata = 64, BlockSize = 256, 128x128x4 @@ -99,10 +237,10 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmNPerThread = 4; constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; constexpr index_t ThreadGemmDataPerReadM = 4; constexpr index_t ThreadGemmDataPerReadN = 4; @@ -122,6 +260,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 1 // cdata = 64, BlockSize = 256, 128x128x8 + // b threadwise copy 4x1 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -152,6 +291,40 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x8 + // b threadwise copy 2x2 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; #elif 0 // cdata = 64, BlockSize = 256, 128x128x8 @@ -255,7 +428,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; -#elif 0 +#elif 1 // cdata = 64, BlockSize = 256, 128x128x16 // GemmBBlockCopySrcDataPerRead_GemmN = 4 // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 @@ -289,6 +462,41 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x16 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 
4 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4; + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4; #elif 0 // cdata = 64, BlockSize = 128, 128x64x4 @@ -826,7 +1034,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2; constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1; -#elif 1 +#elif 0 // cdata = 64, BlockSize = 64, 64x64x3 constexpr index_t BlockSize = 64; @@ -968,7 +1176,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc, printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); - using gridwise_conv = GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw< + using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw< GridSize, BlockSize, TDevice, diff --git a/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..238eebf2ee --- /dev/null +++ b/driver/include/device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,207 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "gridwise_operation_wrapper.hpp" +#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" + +template +void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc, + const Tensor& in_nchw, + WeiDesc, + const Tensor& wei_kcyx, + OutDesc, + Tensor& out_nkhw, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + ck::index_t nrepeat) +{ + std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl; + + using namespace ck; + + using TDevice = typename conditional::value, half_t, T>::type; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto N = OutDesc::GetLengths()[I0]; + constexpr auto K = OutDesc::GetLengths()[I1]; + constexpr auto C = WeiDesc::GetLengths()[I1]; + + constexpr auto Hi = InDesc::GetLengths()[I2]; + constexpr auto Wi = InDesc::GetLengths()[I3]; + + constexpr auto Ho = OutDesc::GetLengths()[I2]; + constexpr auto Wo = OutDesc::GetLengths()[I3]; + + constexpr auto Y = WeiDesc::GetLengths()[I2]; + constexpr auto X = WeiDesc::GetLengths()[I3]; + + // compile-time variables + constexpr auto in_n_hi_wi_c_desc = + make_native_tensor_descriptor_packed(Sequence{}); + 
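// Illustrative sketch only (the concrete sizes below are assumptions, not part of the
// tuned configs): these packed NHWC descriptors expose the forward convolution as an
// implicit GEMM, with
//   GemmM = K                  (output channels)
//   GemmN = N * Ho * Wo        (output pixels)
//   GemmK = Y * X * C          (filter window x input channels, for NHWC)
// and the launch configuration further down computes
//   GridSize = ceil(GemmM / GemmMPerBlock) * ceil(GemmN / GemmNPerBlock).
// E.g. assuming N = 128, K = 256, Ho = Wo = 14 with the 16x64x4 block config,
// GridSize = ceil(256 / 16) * ceil(25088 / 64) = 16 * 392 = 6272 workgroups.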
constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto out_n_ho_wo_k_desc = + make_native_tensor_descriptor_packed(Sequence{}); + + Tensor in_nhwc( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + Tensor wei_kyxc( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + Tensor out_nhwk( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + + auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { + in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi); + }; + + auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { + wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x); + }; + + auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { + out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo); + }; + + make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency()); + + std::size_t data_sz = sizeof(T); + + DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace()); + DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace()); + DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace()); + + in_nhwc_device_buf.ToDevice(in_nhwc.mData.data()); + wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data()); + out_nhwk_device_buf.ToDevice(out_nhwk.mData.data()); + +#if 1 + // cdata = 16, BlockSize = 64, 16x64x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1; + constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1; + + using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4; + constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1; + + constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2; +#endif + + constexpr index_t GemmM = K; + constexpr index_t GemmN = N * Ho * Wo; + + constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) * + math::integer_divide_ceil(GemmN, GemmNPerBlock); + + printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize); + + using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk< + GridSize, + BlockSize, + TDevice, + TDevice, + decltype(in_n_hi_wi_c_desc), + decltype(wei_k_y_x_c_desc), + decltype(out_n_ho_wo_k_desc), + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + ThreadGemmDataPerReadM, + 
ThreadGemmDataPerReadN, + GemmABlockCopyThreadSliceLengths_GemmK_GemmM, + GemmABlockCopyThreadClusterLengths_GemmK_GemmM, + GemmABlockCopySrcDataPerRead_GemmK, + GemmABlockCopyDstDataPerWrite_GemmM, + GemmBBlockCopyThreadSliceLengths_GemmK_GemmN, + GemmBBlockCopyThreadClusterLengths_GemmK_GemmN, + GemmBBlockCopySrcDataPerRead_GemmK, + GemmBBlockCopyDstDataPerWrite_GemmN, + GemmCThreadCopyDstDataPerWrite_GemmM1>; + + for(index_t i = 0; i < 5; ++i) + { + std::cout << "Start running " << nrepeat << " times..." << std::endl; + + KernelTimer timer; + timer.Start(); + + for(index_t j = 0; j < nrepeat; ++j) + { + launch_kernel(run_gridwise_operation, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + static_cast(in_nhwc_device_buf.GetDeviceBuffer()), + static_cast(wei_kyxc_device_buf.GetDeviceBuffer()), + static_cast(out_nhwk_device_buf.GetDeviceBuffer())); + } + + timer.End(); + + float ave_time = timer.GetElapsedTime() / nrepeat; + + float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + out_nhwk_device_buf.FromDevice(out_nhwk.mData.data()); + + auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { + out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k); + }; + + make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..53a8e7ac4b --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -0,0 +1,508 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( + InDesc, + const Tensor& in_n_c_hi_wi, + WeiDesc, + const Tensor& wei_k_c_y_x, + OutDesc, + Tensor& out_n_k_ho_wo, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw" + << std::endl; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + +#if 0 + // run-time variables + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths())); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths())); + + const auto conv_strides = to_multi_index(ConvStrides{}); + const auto conv_dilations = to_multi_index(ConvDilations{}); + const auto in_left_pads = to_multi_index(InLeftPads{}); + const auto in_right_pads = to_multi_index(InRightPads{}); +#else + // compile-time variables + const auto in_n_c_hi_wi_desc = 
make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(InDesc::GetLengths())); + const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(WeiDesc::GetLengths())); + const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(OutDesc::GetLengths())); + + const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); + const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); + const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); + const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); +#endif + +#if 0 + // cdata = 16, BlockSize = 64, 16x64x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; +#elif 0 + // cdata = 32, BlockSize 64, 16x128x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#elif 0 + // cdata = 64, BlockSize 64, 16x256x2 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t 
GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#elif 1 + // cdata = 64, BlockSize 64, 16x256x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#elif 0 + // cdata = 16, BlockSize = 64, 16x64x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; +#elif 0 + // cdata = 
32, BlockSize = 64, 16x128x4 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#elif 0 + // cdata = 64, BlockSize = 128, 32x256x8 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x2 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = 
Sequence<2, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x4 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x8 + // b thread copy 4x1 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x8 + // b thread copy 2x2 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; + using 
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x16 + // GemmBBlockCopySrcDataPerRead_GemmN = 4 + // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 4; + constexpr index_t GemmNLevel0Cluster = 4; + constexpr index_t GemmMLevel1Cluster = 4; + constexpr index_t GemmNLevel1Cluster = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; +#endif + + constexpr auto conv_driver = +#if 1 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad +#elif 0 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad +#elif 1 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1 +#endif + ::type, + TAcc, + TOut, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_GemmM, + GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, + GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmN, + GemmCThreadTransferDstScalarPerVector_GemmN1>{}; + + conv_driver.Run(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + static_cast::type*>( + wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast::type*>( + in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); + + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..fee97dddbc --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,427 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( + InDesc, + const Tensor& in_n_c_hi_wi, 
+ WeiDesc, + const Tensor& wei_k_c_y_x, + OutDesc, + Tensor& out_n_k_ho_wo, + ConvStrides, + ConvDilations, + InLeftPads, + InRightPads, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" + << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto N = OutDesc::GetLengths()[I0]; + constexpr auto K = OutDesc::GetLengths()[I1]; + constexpr auto C = WeiDesc::GetLengths()[I1]; + + constexpr auto Hi = InDesc::GetLengths()[I2]; + constexpr auto Wi = InDesc::GetLengths()[I3]; + + constexpr auto Ho = OutDesc::GetLengths()[I2]; + constexpr auto Wo = OutDesc::GetLengths()[I3]; + + constexpr auto Y = WeiDesc::GetLengths()[I2]; + constexpr auto X = WeiDesc::GetLengths()[I3]; + + constexpr auto C0 = C / Number{}; + constexpr auto C1 = Number{}; + +#if 0 + // run-time variables + constexpr auto in_n_hi_wi_c0_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); + constexpr auto wei_k_y_x_c0_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0)); + constexpr auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K)); + + const auto conv_strides = to_multi_index(ConvStrides{}); + const auto conv_dilations = to_multi_index(ConvDilations{}); + const auto in_left_pads = to_multi_index(InLeftPads{}); + const auto in_right_pads = to_multi_index(InRightPads{}); +#else + // compile-time variables + constexpr auto in_n_hi_wi_c0_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0)); + constexpr auto wei_k_y_x_c0_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0)); + constexpr auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K)); + + const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); + const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); + const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); + const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); +#endif + + Tensor in_n_hi_wi_c( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + Tensor wei_k_y_x_c( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + Tensor out_n_ho_wo_k( + make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); + + auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { + in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi); + }; + + auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { + wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x); + }; + + auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { + out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo); + }; + + make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(); + make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(); + make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(); + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + 
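// Note on the vectorized channel layout used above (a sketch; the concrete vector
// width is whatever this driver is instantiated with, referred to as InWeiVectorSize
// elsewhere in this patch): the C dimension is split as
//   C = C0 * C1,  where C1 is the vector width,
// so the device-side descriptors are built over (N, Hi, Wi, C0) and (K, Y, X, C0),
// and each descriptor element holds C1 input/weight values packed into one
// vector_type element. E.g. assuming C = 64 and a vector width of 4, C0 = 16.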
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + +#if 1 + // cdata = 16, BlockSize = 64, 16x64x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 2; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2; +#elif 0 + // cdata = 32, BlockSize = 64, 16x128x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 2; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 8; + + constexpr index_t ThreadGemmDataPerReadM = 2; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2; +#elif 0 + // cdata = 64, BlockSize = 64, 16x256x2 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 2; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 1; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = 
Sequence<2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#elif 1 + // cdata = 64, BlockSize = 64, 16x256x4 + constexpr index_t BlockSize = 64; + + constexpr index_t GemmMPerBlock = 16; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 1; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#elif 0 + // cdata = 64, BlockSize = 128, 32x256x4 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#elif 0 + // cdata = 64, BlockSize = 128, 32x256x8 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 32; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 2; + constexpr index_t GemmNLevel1Cluster = 16; + + constexpr index_t ThreadGemmDataPerReadM = 4; + constexpr index_t ThreadGemmDataPerReadN = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; + using 
GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#elif 0 + // cdata = 64, BlockSize = 256, 128x128x8 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#elif 1 + // cdata = 64, BlockSize = 256, 128x128x16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 16; + + constexpr index_t GemmMPerThread = 4; + constexpr index_t GemmNPerThread = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmMLevel0Cluster = 2; + constexpr index_t GemmNLevel0Cluster = 2; + constexpr index_t GemmMLevel1Cluster = 8; + constexpr index_t GemmNLevel1Cluster = 8; + + using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; + using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; +#endif + + constexpr auto conv_driver = +#if 1 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad +#elif 0 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad +#elif 1 + DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1 +#endif + ::type, + TAcc, + TOut, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerThread, + GemmNPerThread, + GemmKPerThread, + GemmMLevel0Cluster, + GemmNLevel0Cluster, + GemmMLevel1Cluster, + GemmNLevel1Cluster, + GemmABlockTransferThreadSliceLengths_GemmK_GemmM, + GemmABlockTransferThreadClusterLengths_GemmK_GemmM, + 
GemmABlockTransferSrcScalarPerVector_GemmK,
+            GemmABlockTransferDstScalarPerVector_GemmM,
+            GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
+            GemmBBlockTransferSrcScalarPerVector_GemmK,
+            GemmBBlockTransferDstScalarPerVector_GemmN,
+            GemmCThreadTransferDstScalarPerVector_GemmM1>{};
+
+    conv_driver.Run(wei_k_y_x_c0_desc,
+                    in_n_hi_wi_c0_desc,
+                    out_n_ho_wo_k_desc,
+                    conv_strides,
+                    conv_dilations,
+                    in_left_pads,
+                    in_right_pads,
+                    static_cast::type*>(
+                        wei_k_y_x_c_device_buf.GetDeviceBuffer()),
+                    static_cast::type*>(
+                        in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
+                    static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
+
+    out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
+
+    auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
+        out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k);
+    };
+
+    make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)();
+}
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
new file mode 100644
index 0000000000..3bce677665
--- /dev/null
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
@@ -0,0 +1,167 @@
+#include
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
+
+template
+void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
+    InDesc,
+    const Tensor<TInWei>& in_n_c_hi_wi,
+    WeiDesc,
+    const Tensor<TInWei>& wei_k_c_y_x,
+    OutDesc,
+    Tensor<TOut>& out_n_k_ho_wo,
+    ConvStrides,
+    ConvDilations,
+    InLeftPads,
+    InRightPads,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+
+    std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw"
+              << std::endl;
+
+    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    constexpr auto N = OutDesc::GetLengths()[I0];
+    constexpr auto K = OutDesc::GetLengths()[I1];
+    constexpr auto C = WeiDesc::GetLengths()[I1];
+
+    constexpr auto Hi = InDesc::GetLengths()[I2];
+    constexpr auto Wi = InDesc::GetLengths()[I3];
+
+    constexpr auto Ho = OutDesc::GetLengths()[I2];
+    constexpr auto Wo = OutDesc::GetLengths()[I3];
+
+    constexpr auto Y = WeiDesc::GetLengths()[I2];
+    constexpr auto X = WeiDesc::GetLengths()[I3];
+
+    constexpr auto C0 = C / Number<InWeiVectorSize>{};
+    constexpr auto C1 = Number<InWeiVectorSize>{};
+
+#if 0
+    // run-time variables
+    const auto in_n_c_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths()));
+    const auto wei_k_c_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths()));
+    const auto out_n_k_ho_wo_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths()));
+
+    const auto conv_strides = to_multi_index(ConvStrides{});
+    const auto conv_dilations = to_multi_index(ConvDilations{});
+    const auto in_left_pads = to_multi_index(InLeftPads{});
+    const auto in_right_pads = to_multi_index(InRightPads{});
+#else
+    // compile-time variables
+    const auto in_n_c0_hi_wi_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
+    const auto wei_k_c0_y_x_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
+    const auto out_n_k_ho_wo_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
+
+    const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
+    const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
+    const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
+    const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
+#endif
+
+    Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
+        make_native_tensor_descriptor_packed(Sequence{})));
+    Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
+        make_native_tensor_descriptor_packed(Sequence{})));
+
+    auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
+        in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
+            in_n_c_hi_wi(n, c, hi, wi);
+    };
+
+    auto f_kcyx2kc0yxc1 = [&](auto k, auto y, auto x, auto c) {
+        wei_k_c0_y_x_c1(k, c / InWeiVectorSize, y, x, c % InWeiVectorSize) =
+            wei_k_c_y_x(k, c, y, x);
+    };
+
+    make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)();
+    make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)();
+
+    in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
+    wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
+
+    // cdata = 64, BlockSize = 64, 16x8x32x4
+    constexpr index_t BlockSize = 64;
+
+    constexpr index_t KPerBlock  = 16;
+    constexpr index_t HoPerBlock = 8;
+    constexpr index_t WoPerBlock = 32;
+    constexpr index_t EPerBlock  = 4;
+
+    constexpr index_t KPerThread  = 16;
+    constexpr index_t HoPerThread = 2;
+    constexpr index_t WoPerThread = 2;
+    constexpr index_t EPerThread  = EPerBlock;
+
+    using ABlockTransferThreadSliceLengths_E_K = Sequence<9, 1>;
+    using ABlockTransferThreadClusterLengths_E_K = Sequence;
+
+    constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
+    constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
+
+    constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
+
+    constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
+
+    constexpr auto conv_driver =
+        DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
+            BlockSize,
+            typename vector_type<TInWei, InWeiVectorSize>::type,
+            TAcc,
+            TOut,
+            KPerBlock,
+            HoPerBlock,
+            WoPerBlock,
+            EPerBlock,
+            KPerThread,
+            HoPerThread,
+            WoPerThread,
+            EPerThread,
+            ABlockTransferThreadSliceLengths_E_K,
+            ABlockTransferThreadClusterLengths_E_K,
+            ABlockTransferSrcScalarPerVector_E,
+            ABlockTransferDstScalarPerVector_K,
+            BThreadTransferSrcScalarPerVector_W,
+            CThreadTransferDstScalarPerVector_W>{};
+
+    conv_driver.Run(wei_k_c0_y_x_desc,
+                    in_n_c0_hi_wi_desc,
+                    out_n_k_ho_wo_desc,
+                    conv_strides,
+                    conv_dilations,
+                    in_left_pads,
+                    in_right_pads,
+                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                        wei_k_c_y_x_device_buf.GetDeviceBuffer()),
+                    static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
+                        in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
+                    static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
+
+    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
+}
diff --git a/driver/include/host_conv.hpp b/driver/include/host_conv.hpp
index 5ce822e70a..d7cf3bbb00 100644
--- a/driver/include/host_conv.hpp
+++ b/driver/include/host_conv.hpp
@@ -273,7 +273,7 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw,
                 std::size_t ho = HoPerTile * htile + j;
                 for(int i = 0; i < WoPerTile; ++i)
                 {
-                    std::size_t wo = WoPerTile * wtile + i;
+                    std::size_t wo = WoPerTile * wtile + i;
                     out_nkhw(n, k, ho, wo) = out_hold(n,
k, htile, wtile, j, i); } } diff --git a/driver/include/host_tensor.hpp b/driver/include/host_tensor.hpp index c1cc46576f..ac6df6f931 100644 --- a/driver/include/host_tensor.hpp +++ b/driver/include/host_tensor.hpp @@ -158,7 +158,7 @@ struct ParallelTensorFunctor return indices; } - void operator()(std::size_t num_thread) const + void operator()(std::size_t num_thread = std::thread::hardware_concurrency()) const { std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; diff --git a/driver/src/conv_bwd_data_driver.cpp b/driver/src/conv_bwd_data_driver.cpp index a5248bfb1a..cdb2526c75 100644 --- a/driver/src/conv_bwd_data_driver.cpp +++ b/driver/src/conv_bwd_data_driver.cpp @@ -4,10 +4,7 @@ #include #include #include "config.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "print_array.hpp" -#include "print_sequence.hpp" +#include "print.hpp" #include "device.hpp" #include "host_tensor_generator.hpp" #include "device_tensor.hpp" @@ -54,10 +51,10 @@ int main(int argc, char* argv[]) #elif 0 // 3x3, 28x28 constexpr index_t N = 128; - constexpr index_t C = 256; + constexpr index_t C = 128; constexpr index_t HI = 28; constexpr index_t WI = 28; - constexpr index_t K = 1024; + constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -156,13 +153,13 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>; -#elif 0 +#elif 1 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; - constexpr index_t C = 256; + constexpr index_t C = 128; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 1024; + constexpr index_t K = 128; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -197,7 +194,7 @@ int main(int argc, char* argv[]) constexpr index_t X = 3; using ConvStrides = Sequence<2, 2>; - using ConvDilations = Sequence<1, 1>; + using ConvDilations = Sequence<2, 2>; using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; @@ -211,11 +208,11 @@ int main(int argc, char* argv[]) ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - print_sequence("LeftPads", LeftPads{}); - print_sequence("LeftPads", LeftPads{}); - print_sequence("RightPads", RightPads{}); - print_sequence("ConvStrides", ConvStrides{}); - print_sequence("ConvDilations", ConvDilations{}); + print_array("LeftPads", LeftPads{}); + print_array("LeftPads", LeftPads{}); + print_array("RightPads", RightPads{}); + print_array("ConvStrides", ConvStrides{}); + print_array("ConvDilations", ConvDilations{}); Tensor in_nchw_device(make_HostTensorDescriptor(in_nchw_desc)); Tensor in_nchw_host(make_HostTensorDescriptor(in_nchw_desc)); @@ -248,7 +245,7 @@ int main(int argc, char* argv[]) device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw -#elif 0 +#elif 1 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw #elif 1 device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk diff --git a/driver/src/conv_bwd_data_driver.cu b/driver/src/conv_bwd_data_driver.cu deleted file mode 120000 index bf6baa8d47..0000000000 --- a/driver/src/conv_bwd_data_driver.cu +++ /dev/null @@ -1 +0,0 @@ -conv_bwd_data_driver.cpp \ No newline at end of file diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 
7317bd6a1c..8d5bc24c8c 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -5,27 +5,29 @@ #include #include #include "config.hpp" -#include "print_array.hpp" -#include "print_sequence.hpp" +#include "print.hpp" #include "device.hpp" #include "host_tensor_generator.hpp" #include "conv_common.hpp" #include "host_conv.hpp" #include "device_tensor.hpp" -#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" -#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" int main(int argc, char* argv[]) { using namespace ck; #if 0 - // 1x1, 17x17 - constexpr index_t N = 128; - constexpr index_t C = 1024; - constexpr index_t HI = 17; - constexpr index_t WI = 17; - constexpr index_t K = 256; + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 1080; + constexpr index_t WI = 1920; + constexpr index_t K = 16; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -35,6 +37,135 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; #elif 0 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 540; + constexpr index_t WI = 960; + constexpr index_t K = 16; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 270; + constexpr index_t WI = 480; + constexpr index_t K = 16; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 1 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 1080; + constexpr index_t WI = 1920; + constexpr index_t K = 16; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + constexpr index_t N = 1; + constexpr index_t C = 1; + constexpr index_t HI = 1024; + constexpr index_t WI = 2048; + constexpr index_t K = 4; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 540; + constexpr index_t WI = 960; + constexpr index_t K = 16; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 270; + constexpr index_t WI = 480; + constexpr index_t K = 16; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using 
ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 0 + // 3x3, 36x36, stride 2 + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t HI = 37; + constexpr index_t WI = 37; + constexpr index_t K = 384; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 3x3, 35x35, stride 2 + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t HI = 35; + constexpr index_t WI = 35; + constexpr index_t K = 384; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 + // 3x3, 71x71 + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t HI = 71; + constexpr index_t WI = 71; + constexpr index_t K = 128; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<2, 2>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<1, 1>; + using RightPads = Sequence<1, 1>; +#elif 1 // 1x1, 8x8 constexpr index_t N = 128; constexpr index_t C = 1536; @@ -70,7 +201,7 @@ int main(int argc, char* argv[]) constexpr index_t C = 96; constexpr index_t HI = 35; constexpr index_t WI = 35; - constexpr index_t K = 96; + constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -94,7 +225,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 0 +#elif 1 // 7x1, 17x17 constexpr index_t N = 128; constexpr index_t C = 128; @@ -109,7 +240,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<3, 0>; using RightPads = Sequence<3, 0>; -#elif 1 +#elif 0 // 1x7, 17x17 constexpr index_t N = 128; constexpr index_t C = 128; @@ -141,12 +272,11 @@ int main(int argc, char* argv[]) using RightPads = Sequence<0, 0>; #elif 0 // 3x3, 147x147 - // v4r4@v100 xx.xx%, cudnn@v100 xx.xx% constexpr index_t N = 128; - constexpr index_t C = 32; + constexpr index_t C = 128; constexpr index_t HI = 147; constexpr index_t WI = 147; - constexpr index_t K = 64; + constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -157,7 +287,6 @@ int main(int argc, char* argv[]) using RightPads = Sequence<1, 1>; #elif 0 // 3x3, 149x149 - // v4r4@v100 xx.xx%, cudnn@v100 xx.xx% constexpr index_t N = 128; constexpr index_t C = 32; constexpr index_t HI = 149; @@ -201,7 +330,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 1 +#elif 0 // 3x3, 35x35, stride 2 constexpr index_t N = 128; constexpr index_t C = 288; @@ -244,21 +373,6 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 0>; - using RightPads = Sequence<1, 0>; -#elif 0 - // 3x1, 8x8 - constexpr index_t N = 128; - constexpr index_t C = 448; - constexpr index_t HI = 8; - constexpr index_t WI = 8; - constexpr index_t K = 512; - constexpr index_t Y = 3; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 0>; using RightPads = Sequence<1, 0>; #elif 0 @@ -278,7 +392,6 @@ int main(int argc, char* argv[]) using RightPads = Sequence<0, 0>; #elif 0 // 7x1, 73x73 
- // v44@v100 xx.xx%, cudnn@v100 xx.xx% constexpr index_t N = 128; constexpr index_t C = 64; constexpr index_t HI = 73; @@ -352,7 +465,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 0 +#elif 1 // 3x3, 28x28 constexpr index_t N = 128; constexpr index_t C = 128; @@ -382,7 +495,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 1 +#elif 0 // 1x1, 56x56, stride 2 constexpr index_t N = 128; constexpr index_t C = 256; @@ -442,7 +555,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 0 +#elif 1 // 1x1, 7x7 constexpr index_t N = 128; constexpr index_t C = 512; @@ -472,7 +585,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<1, 1>; using RightPads = Sequence<1, 1>; -#elif 1 +#elif 0 // 1x1, 56x56 constexpr index_t N = 128; constexpr index_t C = 64; @@ -487,7 +600,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 1 +#elif 0 // 3x3, 56x56 constexpr index_t N = 128; constexpr index_t C = 64; @@ -512,17 +625,26 @@ int main(int argc, char* argv[]) ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - print_sequence("LeftPads", LeftPads{}); - print_sequence("RightPads", RightPads{}); - print_sequence("ConvStrides", ConvStrides{}); - print_sequence("ConvDilations", ConvDilations{}); + print_array("LeftPads", to_multi_index(LeftPads{})); + print_array("RightPads", to_multi_index(RightPads{})); + print_array("ConvStrides", to_multi_index(ConvStrides{})); + print_array("ConvDilations", to_multi_index(ConvDilations{})); -#if 1 - using in_data_t = float; - using out_data_t = float; -#else - using in_data_t = half_float::half; - using out_data_t = half_float::half; +#if 0 + using in_data_t = float; + constexpr index_t in_vector_size = 1; + using acc_data_t = float; + using out_data_t = float; +#elif 0 + using in_data_t = float; + constexpr index_t in_vector_size = 1; + using acc_data_t = float; + using out_data_t = int8_t; +#elif 1 + using in_data_t = int8_t; + constexpr index_t in_vector_size = 4; + using acc_data_t = int32_t; + using out_data_t = int8_t; #endif Tensor in_nchw(make_HostTensorDescriptor(in_nchw_desc)); @@ -532,14 +654,15 @@ int main(int argc, char* argv[]) std::size_t num_thread = std::thread::hardware_concurrency(); - if(argc != 3) + if(argc != 4) { - printf("arg1: do_verification, arg2: nrepeat\n"); + printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n"); exit(1); } bool do_verification = atoi(argv[1]); - index_t nrepeat = atoi(argv[2]); + bool do_log = atoi(argv[2]); + index_t nrepeat = atoi(argv[3]); if(do_verification) { @@ -548,9 +671,9 @@ int main(int argc, char* argv[]) wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #elif 0 in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_3{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); #elif 0 - in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread); + in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #elif 1 in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); @@ -565,59 +688,112 @@ int main(int argc, char* 
argv[]) #endif } -#if 1 - device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); +#if 0 + device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 0 + device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 0 + device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 0 + device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw + + (in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); #elif 1 - device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); + device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk + + (in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); +#elif 1 + device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( + in_nchw_desc, + in_nchw, + wei_kcyx_desc, + wei_kcyx, + out_nkhw_desc, + out_nkhw_device, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}, + nrepeat); #endif if(do_verification) { -#if 0 - if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 && - ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1) - { - host_winograd_3x3_convolution( - in_nchw, wei_kcyx, out_nkhw_host, LeftPads{}, RightPads{}); - } - else -#endif - { - host_direct_convolution(in_nchw, - wei_kcyx, - out_nkhw_host, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}); - } + host_direct_convolution(in_nchw, + wei_kcyx, + out_nkhw_host, + ConvStrides{}, + ConvDilations{}, + LeftPads{}, + RightPads{}); + check_error(out_nkhw_host, out_nkhw_device); -#if 0 - LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; - LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; - LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; - LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; -#endif + if(do_log) + { + LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; + LogRange(std::cout << "wei_kcyx: ", wei_kcyx.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; + } } } diff --git a/driver/src/conv_driver.cu b/driver/src/conv_driver.cu deleted file mode 120000 index 6b7ee4e3e0..0000000000 --- a/driver/src/conv_driver.cu +++ /dev/null @@ -1 +0,0 @@ -conv_driver.cpp \ No newline at end of file diff --git a/external/half/include/half.hpp b/external/half/include/half.hpp index 1172a2c564..f15e8d00dd 100644 --- 
a/external/half/include/half.hpp +++ b/external/half/include/half.hpp @@ -508,8 +508,8 @@ template struct bool_type : std::integral_constant { }; -using std::true_type; using std::false_type; +using std::true_type; /// Type traits for floating-point types. template @@ -854,8 +854,8 @@ inline HALF_CONSTEXPR_NOERR unsigned int signal(unsigned int x, unsigned int y, ((x & 0x7FFF) > 0x7C00 && !(x & 0x200)) || ((y & 0x7FFF) > 0x7C00 && !(y & 0x200)) || ((z & 0x7FFF) > 0x7C00 && !(z & 0x200))); #endif - return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) - : (z | 0x200); + return ((x & 0x7FFF) > 0x7C00) ? (x | 0x200) + : ((y & 0x7FFF) > 0x7C00) ? (y | 0x200) : (z | 0x200); } /// Select value or signaling NaN. @@ -1756,9 +1756,9 @@ uint32 mulhi(uint32 x, uint32 y) uint32 xy = (x >> 16) * (y & 0xFFFF), yx = (x & 0xFFFF) * (y >> 16), c = (xy & 0xFFFF) + (yx & 0xFFFF) + (((x & 0xFFFF) * (y & 0xFFFF)) >> 16); return (x >> 16) * (y >> 16) + (xy >> 16) + (yx >> 16) + (c >> 16) + - ((R == std::round_to_nearest) ? ((c >> 15) & 1) : (R == std::round_toward_infinity) - ? ((c & 0xFFFF) != 0) - : 0); + ((R == std::round_to_nearest) + ? ((c >> 15) & 1) + : (R == std::round_toward_infinity) ? ((c & 0xFFFF) != 0) : 0); } /// 64-bit multiplication. @@ -2247,7 +2247,7 @@ unsigned int area(unsigned int arg) { if(expy < 0) { - r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | + r = 0x40000000 + ((expy > -30) ? ((r >> -expy) | ((r & ((static_cast(1) << -expy) - 1)) != 0)) : 1); expy = 0; @@ -2379,10 +2379,12 @@ unsigned int erf(unsigned int arg) t / ((x2.exp < 0) ? f31(exp2((x2.exp > -32) ? (x2.m >> -x2.exp) : 0, 30), 0) : f31(exp2((x2.m << x2.exp) & 0x7FFFFFFF, 22), x2.m >> (31 - x2.exp))); - return (!C || sign) ? fixed2half( - 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) - : (e.exp < -25) ? underflow() : fixed2half( - e.m >> 1, e.exp + 14, 0, e.m & 1); + return (!C || sign) + ? fixed2half( + 0x80000000 - (e.m >> (C - e.exp)), 14 + C, sign & (C - 1U)) + : (e.exp < -25) + ? underflow() + : fixed2half(e.m >> 1, e.exp + 14, 0, e.m & 1); } /// Gamma function and postprocessing. @@ -2402,8 +2404,7 @@ unsigned int gamma(unsigned int arg) for(unsigned int i=0; i<5; ++i) s += p[i+1] / (arg+i); return std::log(s) + (arg-0.5)*std::log(t) - t; -*/ static const f31 - pi(0xC90FDAA2, 1), +*/ static const f31 pi(0xC90FDAA2, 1), lbe(0xB8AA3B29, 0); unsigned int abs = arg & 0x7FFF, sign = arg & 0x8000; bool bsign = sign != 0; @@ -2490,7 +2491,7 @@ unsigned int gamma(unsigned int arg) { if(z.exp < 0) s = s * z; - s = pi / s; + s = pi / s; if(s.exp < -24) return underflow(sign); } @@ -2789,7 +2790,7 @@ inline half operator"" _h(long double value) { return half(detail::binary, detail::float2half(value)); } -} +} // namespace literal #endif namespace detail { @@ -2837,8 +2838,8 @@ struct half_caster { static half cast(half arg) { return arg; } }; -} -} +} // namespace detail +} // namespace half_float /// Extensions to the C++ standard library. namespace std { @@ -3003,7 +3004,7 @@ struct hash } }; #endif -} +} // namespace std namespace half_float { /// \anchor compop @@ -3122,13 +3123,14 @@ inline half operator+(half x, half y) return half(detail::binary, (absx > 0x7C00 || absy > 0x7C00) ? detail::signal(x.data_, y.data_) - : (absy != 0x7C00) ? x.data_ : (sub && absx == 0x7C00) ? detail::invalid() - : y.data_); + : (absy != 0x7C00) ? x.data_ + : (sub && absx == 0x7C00) ? detail::invalid() : y.data_); if(!absx) - return absy ? 
y : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) - ? (x.data_ | y.data_) - : (x.data_ & y.data_)); + return absy ? y + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) + ? (x.data_ | y.data_) + : (x.data_ & y.data_)); if(!absy) return x; unsigned int sign = ((sub && absy > absx) ? y.data_ : x.data_) & 0x8000; @@ -3449,10 +3451,11 @@ inline half fma(half x, half y, half z) : (sign | 0x7C00)) : z; if(!absx || !absy) - return absz ? z : half(detail::binary, - (half::round_style == std::round_toward_neg_infinity) - ? (z.data_ | sign) - : (z.data_ & sign)); + return absz + ? z + : half(detail::binary, + (half::round_style == std::round_toward_neg_infinity) ? (z.data_ | sign) + : (z.data_ & sign)); for(; absx < 0x400; absx <<= 1, --exp) ; for(; absy < 0x400; absy <<= 1, --exp) @@ -3516,9 +3519,8 @@ inline half fma(half x, half y, half z) inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) { return half(detail::binary, - (!isnan(y) && (isnan(x) || - (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) < + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); } @@ -3533,9 +3535,8 @@ inline HALF_CONSTEXPR_NOERR half fmax(half x, half y) inline HALF_CONSTEXPR_NOERR half fmin(half x, half y) { return half(detail::binary, - (!isnan(y) && (isnan(x) || - (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > - (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) + (!isnan(y) && (isnan(x) || (x.data_ ^ (0x8000 | (0x8000 - (x.data_ >> 15)))) > + (y.data_ ^ (0x8000 | (0x8000 - (y.data_ >> 15)))))) ? detail::select(y.data_, x.data_) : detail::select(x.data_, y.data_)); } @@ -3886,9 +3887,9 @@ inline half log1p(half arg) #else if(arg.data_ >= 0xBC00) return half(detail::binary, - (arg.data_ == 0xBC00) ? detail::pole(0x8000) : (arg.data_ <= 0xFC00) - ? detail::invalid() - : detail::signal(arg.data_)); + (arg.data_ == 0xBC00) + ? detail::pole(0x8000) + : (arg.data_ <= 0xFC00) ? detail::invalid() : detail::signal(arg.data_)); int abs = arg.data_ & 0x7FFF, exp = -15; if(!abs || abs >= 0x7C00) return (abs > 0x7C00) ? half(detail::binary, detail::signal(arg.data_)) : arg; @@ -4395,7 +4396,7 @@ inline half cos(half arg) if(half::round_style != std::round_to_nearest && abs == 0x598C) return half(detail::binary, detail::rounded(0x80FC, 1, 1)); std::pair sc = detail::sincos(detail::angle_arg(abs, k), 28); - detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); + detail::uint32 sign = -static_cast(((k >> 1) ^ k) & 1); return half(detail::binary, detail::fixed2half( (((k & 1) ? sc.first : sc.second) ^ sign) - sign)); @@ -4439,7 +4440,7 @@ inline half tan(half arg) } std::pair sc = detail::sincos(detail::angle_arg(abs, k), 30); if(k & 1) - sc = std::make_pair(-sc.second, sc.first); + sc = std::make_pair(-sc.second, sc.first); detail::uint32 signy = detail::sign_mask(sc.first), signx = detail::sign_mask(sc.second); detail::uint32 my = (sc.first ^ signy) - signy, mx = (sc.second ^ signx) - signx; for(; my < 0x80000000; my <<= 1, --exp) @@ -4517,7 +4518,7 @@ inline half acos(half arg) ? detail::invalid() : sign ? detail::rounded(0x4248, 0, 1) : 0); std::pair cs = detail::atan2_args(abs); - detail::uint32 m = detail::atan2(cs.second, cs.first, 28); + detail::uint32 m = detail::atan2(cs.second, cs.first, 28); return half(detail::binary, detail::fixed2half( sign ? 
(0xC90FDAA2 - m) : m, 15, 0, sign)); @@ -5354,13 +5355,13 @@ inline HALF_CONSTEXPR half copysign(half x, half y) /// \retval FP_NORMAL for all other (normal) values inline HALF_CONSTEXPR int fpclassify(half arg) { - return !(arg.data_ & 0x7FFF) ? FP_ZERO : ((arg.data_ & 0x7FFF) < 0x400) - ? FP_SUBNORMAL - : ((arg.data_ & 0x7FFF) < 0x7C00) - ? FP_NORMAL - : ((arg.data_ & 0x7FFF) == 0x7C00) - ? FP_INFINITE - : FP_NAN; + return !(arg.data_ & 0x7FFF) + ? FP_ZERO + : ((arg.data_ & 0x7FFF) < 0x400) + ? FP_SUBNORMAL + : ((arg.data_ & 0x7FFF) < 0x7C00) + ? FP_NORMAL + : ((arg.data_ & 0x7FFF) == 0x7C00) ? FP_INFINITE : FP_NAN; } /// Check if finite number. @@ -5652,7 +5653,7 @@ inline void fethrowexcept(int excepts, const char* msg = "") throw std::range_error(msg); } /// \} -} +} // namespace half_float #undef HALF_UNUSED_NOERR #undef HALF_CONSTEXPR diff --git a/script/cmake-rocm3.5.sh b/script/cmake-rocm3.7.sh similarity index 65% rename from script/cmake-rocm3.5.sh rename to script/cmake-rocm3.7.sh index d3a9b575ee..929a0b6e87 100755 --- a/script/cmake-rocm3.5.sh +++ b/script/cmake-rocm3.7.sh @@ -3,19 +3,23 @@ rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles -MY_PROJECT_SOURCE=../../../ +MY_PROJECT_SOURCE=../ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_BUILD_TYPE=Release \ -D DEVICE_BACKEND="AMD" \ --D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ +-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_PREFIX_PATH="/opt/rocm" \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -save-temps=$CWD" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ +#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps=$CWD" \ diff --git a/script/count_vgpr.sh b/script/count_vgpr.sh new file mode 100755 index 0000000000..4fbfec0278 --- /dev/null +++ b/script/count_vgpr.sh @@ -0,0 +1,259 @@ +#!/bin/bash +FILE=$1 + +echo v0 $( grep -w v0 $FILE | wc -l ) +echo v1 $( grep -w v1 $FILE | wc -l ) +echo v2 $( grep -w v2 $FILE | wc -l ) +echo v3 $( grep -w v3 $FILE | wc -l ) +echo v4 $( grep -w v4 $FILE | wc -l ) +echo v5 $( grep -w v5 $FILE | wc -l ) +echo v6 $( grep -w v6 $FILE | wc -l ) +echo v7 $( grep -w v7 $FILE | wc -l ) +echo v8 $( grep -w v8 $FILE | wc -l ) +echo v9 $( grep -w v9 $FILE | wc -l ) +echo v10 $( grep -w v10 $FILE | wc -l ) +echo v11 $( grep -w v11 $FILE | wc -l ) +echo v12 $( grep -w v12 $FILE | wc -l ) 
+echo v13 $( grep -w v13 $FILE | wc -l ) +echo v14 $( grep -w v14 $FILE | wc -l ) +echo v15 $( grep -w v15 $FILE | wc -l ) +echo v16 $( grep -w v16 $FILE | wc -l ) +echo v17 $( grep -w v17 $FILE | wc -l ) +echo v18 $( grep -w v18 $FILE | wc -l ) +echo v19 $( grep -w v19 $FILE | wc -l ) +echo v20 $( grep -w v20 $FILE | wc -l ) +echo v21 $( grep -w v21 $FILE | wc -l ) +echo v22 $( grep -w v22 $FILE | wc -l ) +echo v23 $( grep -w v23 $FILE | wc -l ) +echo v24 $( grep -w v24 $FILE | wc -l ) +echo v25 $( grep -w v25 $FILE | wc -l ) +echo v26 $( grep -w v26 $FILE | wc -l ) +echo v27 $( grep -w v27 $FILE | wc -l ) +echo v28 $( grep -w v28 $FILE | wc -l ) +echo v29 $( grep -w v29 $FILE | wc -l ) +echo v30 $( grep -w v30 $FILE | wc -l ) +echo v31 $( grep -w v31 $FILE | wc -l ) +echo v32 $( grep -w v32 $FILE | wc -l ) +echo v33 $( grep -w v33 $FILE | wc -l ) +echo v34 $( grep -w v34 $FILE | wc -l ) +echo v35 $( grep -w v35 $FILE | wc -l ) +echo v36 $( grep -w v36 $FILE | wc -l ) +echo v37 $( grep -w v37 $FILE | wc -l ) +echo v38 $( grep -w v38 $FILE | wc -l ) +echo v39 $( grep -w v39 $FILE | wc -l ) +echo v40 $( grep -w v40 $FILE | wc -l ) +echo v41 $( grep -w v41 $FILE | wc -l ) +echo v42 $( grep -w v42 $FILE | wc -l ) +echo v43 $( grep -w v43 $FILE | wc -l ) +echo v44 $( grep -w v44 $FILE | wc -l ) +echo v45 $( grep -w v45 $FILE | wc -l ) +echo v46 $( grep -w v46 $FILE | wc -l ) +echo v47 $( grep -w v47 $FILE | wc -l ) +echo v48 $( grep -w v48 $FILE | wc -l ) +echo v49 $( grep -w v49 $FILE | wc -l ) +echo v50 $( grep -w v50 $FILE | wc -l ) +echo v51 $( grep -w v51 $FILE | wc -l ) +echo v52 $( grep -w v52 $FILE | wc -l ) +echo v53 $( grep -w v53 $FILE | wc -l ) +echo v54 $( grep -w v54 $FILE | wc -l ) +echo v55 $( grep -w v55 $FILE | wc -l ) +echo v56 $( grep -w v56 $FILE | wc -l ) +echo v57 $( grep -w v57 $FILE | wc -l ) +echo v58 $( grep -w v58 $FILE | wc -l ) +echo v59 $( grep -w v59 $FILE | wc -l ) +echo v60 $( grep -w v60 $FILE | wc -l ) +echo v61 $( grep -w v61 $FILE | wc -l ) +echo v62 $( grep -w v62 $FILE | wc -l ) +echo v63 $( grep -w v63 $FILE | wc -l ) +echo v64 $( grep -w v64 $FILE | wc -l ) +echo v65 $( grep -w v65 $FILE | wc -l ) +echo v66 $( grep -w v66 $FILE | wc -l ) +echo v67 $( grep -w v67 $FILE | wc -l ) +echo v68 $( grep -w v68 $FILE | wc -l ) +echo v69 $( grep -w v69 $FILE | wc -l ) +echo v70 $( grep -w v70 $FILE | wc -l ) +echo v71 $( grep -w v71 $FILE | wc -l ) +echo v72 $( grep -w v72 $FILE | wc -l ) +echo v73 $( grep -w v73 $FILE | wc -l ) +echo v74 $( grep -w v74 $FILE | wc -l ) +echo v75 $( grep -w v75 $FILE | wc -l ) +echo v76 $( grep -w v76 $FILE | wc -l ) +echo v77 $( grep -w v77 $FILE | wc -l ) +echo v78 $( grep -w v78 $FILE | wc -l ) +echo v79 $( grep -w v79 $FILE | wc -l ) +echo v80 $( grep -w v80 $FILE | wc -l ) +echo v81 $( grep -w v81 $FILE | wc -l ) +echo v82 $( grep -w v82 $FILE | wc -l ) +echo v83 $( grep -w v83 $FILE | wc -l ) +echo v84 $( grep -w v84 $FILE | wc -l ) +echo v85 $( grep -w v85 $FILE | wc -l ) +echo v86 $( grep -w v86 $FILE | wc -l ) +echo v87 $( grep -w v87 $FILE | wc -l ) +echo v88 $( grep -w v88 $FILE | wc -l ) +echo v89 $( grep -w v89 $FILE | wc -l ) +echo v90 $( grep -w v90 $FILE | wc -l ) +echo v91 $( grep -w v91 $FILE | wc -l ) +echo v92 $( grep -w v92 $FILE | wc -l ) +echo v93 $( grep -w v93 $FILE | wc -l ) +echo v94 $( grep -w v94 $FILE | wc -l ) +echo v95 $( grep -w v95 $FILE | wc -l ) +echo v96 $( grep -w v96 $FILE | wc -l ) +echo v97 $( grep -w v97 $FILE | wc -l ) +echo v98 $( grep -w v98 $FILE | wc -l ) +echo v99 $( grep -w v99 
$FILE | wc -l ) +echo v100 $( grep -w v100 $FILE | wc -l ) +echo v101 $( grep -w v101 $FILE | wc -l ) +echo v102 $( grep -w v102 $FILE | wc -l ) +echo v103 $( grep -w v103 $FILE | wc -l ) +echo v104 $( grep -w v104 $FILE | wc -l ) +echo v105 $( grep -w v105 $FILE | wc -l ) +echo v106 $( grep -w v106 $FILE | wc -l ) +echo v107 $( grep -w v107 $FILE | wc -l ) +echo v108 $( grep -w v108 $FILE | wc -l ) +echo v109 $( grep -w v109 $FILE | wc -l ) +echo v110 $( grep -w v110 $FILE | wc -l ) +echo v111 $( grep -w v111 $FILE | wc -l ) +echo v112 $( grep -w v112 $FILE | wc -l ) +echo v113 $( grep -w v113 $FILE | wc -l ) +echo v114 $( grep -w v114 $FILE | wc -l ) +echo v115 $( grep -w v115 $FILE | wc -l ) +echo v116 $( grep -w v116 $FILE | wc -l ) +echo v117 $( grep -w v117 $FILE | wc -l ) +echo v118 $( grep -w v118 $FILE | wc -l ) +echo v119 $( grep -w v119 $FILE | wc -l ) +echo v120 $( grep -w v120 $FILE | wc -l ) +echo v121 $( grep -w v121 $FILE | wc -l ) +echo v122 $( grep -w v122 $FILE | wc -l ) +echo v123 $( grep -w v123 $FILE | wc -l ) +echo v124 $( grep -w v124 $FILE | wc -l ) +echo v125 $( grep -w v125 $FILE | wc -l ) +echo v126 $( grep -w v126 $FILE | wc -l ) +echo v127 $( grep -w v127 $FILE | wc -l ) +echo v128 $( grep -w v128 $FILE | wc -l ) +echo v129 $( grep -w v129 $FILE | wc -l ) +echo v130 $( grep -w v130 $FILE | wc -l ) +echo v131 $( grep -w v131 $FILE | wc -l ) +echo v132 $( grep -w v132 $FILE | wc -l ) +echo v133 $( grep -w v133 $FILE | wc -l ) +echo v134 $( grep -w v134 $FILE | wc -l ) +echo v135 $( grep -w v135 $FILE | wc -l ) +echo v136 $( grep -w v136 $FILE | wc -l ) +echo v137 $( grep -w v137 $FILE | wc -l ) +echo v138 $( grep -w v138 $FILE | wc -l ) +echo v139 $( grep -w v139 $FILE | wc -l ) +echo v140 $( grep -w v140 $FILE | wc -l ) +echo v141 $( grep -w v141 $FILE | wc -l ) +echo v142 $( grep -w v142 $FILE | wc -l ) +echo v143 $( grep -w v143 $FILE | wc -l ) +echo v144 $( grep -w v144 $FILE | wc -l ) +echo v145 $( grep -w v145 $FILE | wc -l ) +echo v146 $( grep -w v146 $FILE | wc -l ) +echo v147 $( grep -w v147 $FILE | wc -l ) +echo v148 $( grep -w v148 $FILE | wc -l ) +echo v149 $( grep -w v149 $FILE | wc -l ) +echo v150 $( grep -w v150 $FILE | wc -l ) +echo v151 $( grep -w v151 $FILE | wc -l ) +echo v152 $( grep -w v152 $FILE | wc -l ) +echo v153 $( grep -w v153 $FILE | wc -l ) +echo v154 $( grep -w v154 $FILE | wc -l ) +echo v155 $( grep -w v155 $FILE | wc -l ) +echo v156 $( grep -w v156 $FILE | wc -l ) +echo v157 $( grep -w v157 $FILE | wc -l ) +echo v158 $( grep -w v158 $FILE | wc -l ) +echo v159 $( grep -w v159 $FILE | wc -l ) +echo v160 $( grep -w v160 $FILE | wc -l ) +echo v161 $( grep -w v161 $FILE | wc -l ) +echo v162 $( grep -w v162 $FILE | wc -l ) +echo v163 $( grep -w v163 $FILE | wc -l ) +echo v164 $( grep -w v164 $FILE | wc -l ) +echo v165 $( grep -w v165 $FILE | wc -l ) +echo v166 $( grep -w v166 $FILE | wc -l ) +echo v167 $( grep -w v167 $FILE | wc -l ) +echo v168 $( grep -w v168 $FILE | wc -l ) +echo v169 $( grep -w v169 $FILE | wc -l ) +echo v170 $( grep -w v170 $FILE | wc -l ) +echo v171 $( grep -w v171 $FILE | wc -l ) +echo v172 $( grep -w v172 $FILE | wc -l ) +echo v173 $( grep -w v173 $FILE | wc -l ) +echo v174 $( grep -w v174 $FILE | wc -l ) +echo v175 $( grep -w v175 $FILE | wc -l ) +echo v176 $( grep -w v176 $FILE | wc -l ) +echo v177 $( grep -w v177 $FILE | wc -l ) +echo v178 $( grep -w v178 $FILE | wc -l ) +echo v179 $( grep -w v179 $FILE | wc -l ) +echo v180 $( grep -w v180 $FILE | wc -l ) +echo v181 $( grep -w v181 $FILE | wc -l ) +echo v182 
$( grep -w v182 $FILE | wc -l ) +echo v183 $( grep -w v183 $FILE | wc -l ) +echo v184 $( grep -w v184 $FILE | wc -l ) +echo v185 $( grep -w v185 $FILE | wc -l ) +echo v186 $( grep -w v186 $FILE | wc -l ) +echo v187 $( grep -w v187 $FILE | wc -l ) +echo v188 $( grep -w v188 $FILE | wc -l ) +echo v189 $( grep -w v189 $FILE | wc -l ) +echo v190 $( grep -w v190 $FILE | wc -l ) +echo v191 $( grep -w v191 $FILE | wc -l ) +echo v192 $( grep -w v192 $FILE | wc -l ) +echo v193 $( grep -w v193 $FILE | wc -l ) +echo v194 $( grep -w v194 $FILE | wc -l ) +echo v195 $( grep -w v195 $FILE | wc -l ) +echo v196 $( grep -w v196 $FILE | wc -l ) +echo v197 $( grep -w v197 $FILE | wc -l ) +echo v198 $( grep -w v198 $FILE | wc -l ) +echo v199 $( grep -w v199 $FILE | wc -l ) +echo v200 $( grep -w v200 $FILE | wc -l ) +echo v201 $( grep -w v201 $FILE | wc -l ) +echo v202 $( grep -w v202 $FILE | wc -l ) +echo v203 $( grep -w v203 $FILE | wc -l ) +echo v204 $( grep -w v204 $FILE | wc -l ) +echo v205 $( grep -w v205 $FILE | wc -l ) +echo v206 $( grep -w v206 $FILE | wc -l ) +echo v207 $( grep -w v207 $FILE | wc -l ) +echo v208 $( grep -w v208 $FILE | wc -l ) +echo v209 $( grep -w v209 $FILE | wc -l ) +echo v210 $( grep -w v210 $FILE | wc -l ) +echo v211 $( grep -w v211 $FILE | wc -l ) +echo v212 $( grep -w v212 $FILE | wc -l ) +echo v213 $( grep -w v213 $FILE | wc -l ) +echo v214 $( grep -w v214 $FILE | wc -l ) +echo v215 $( grep -w v215 $FILE | wc -l ) +echo v216 $( grep -w v216 $FILE | wc -l ) +echo v217 $( grep -w v217 $FILE | wc -l ) +echo v218 $( grep -w v218 $FILE | wc -l ) +echo v219 $( grep -w v219 $FILE | wc -l ) +echo v220 $( grep -w v220 $FILE | wc -l ) +echo v221 $( grep -w v221 $FILE | wc -l ) +echo v222 $( grep -w v222 $FILE | wc -l ) +echo v223 $( grep -w v223 $FILE | wc -l ) +echo v224 $( grep -w v224 $FILE | wc -l ) +echo v225 $( grep -w v225 $FILE | wc -l ) +echo v226 $( grep -w v226 $FILE | wc -l ) +echo v227 $( grep -w v227 $FILE | wc -l ) +echo v228 $( grep -w v228 $FILE | wc -l ) +echo v229 $( grep -w v229 $FILE | wc -l ) +echo v230 $( grep -w v230 $FILE | wc -l ) +echo v231 $( grep -w v231 $FILE | wc -l ) +echo v232 $( grep -w v232 $FILE | wc -l ) +echo v233 $( grep -w v233 $FILE | wc -l ) +echo v234 $( grep -w v234 $FILE | wc -l ) +echo v235 $( grep -w v235 $FILE | wc -l ) +echo v236 $( grep -w v236 $FILE | wc -l ) +echo v237 $( grep -w v237 $FILE | wc -l ) +echo v238 $( grep -w v238 $FILE | wc -l ) +echo v239 $( grep -w v239 $FILE | wc -l ) +echo v240 $( grep -w v240 $FILE | wc -l ) +echo v241 $( grep -w v241 $FILE | wc -l ) +echo v242 $( grep -w v242 $FILE | wc -l ) +echo v243 $( grep -w v243 $FILE | wc -l ) +echo v244 $( grep -w v244 $FILE | wc -l ) +echo v245 $( grep -w v245 $FILE | wc -l ) +echo v246 $( grep -w v246 $FILE | wc -l ) +echo v247 $( grep -w v247 $FILE | wc -l ) +echo v248 $( grep -w v248 $FILE | wc -l ) +echo v249 $( grep -w v249 $FILE | wc -l ) +echo v250 $( grep -w v250 $FILE | wc -l ) +echo v251 $( grep -w v251 $FILE | wc -l ) +echo v252 $( grep -w v252 $FILE | wc -l ) +echo v253 $( grep -w v253 $FILE | wc -l ) +echo v254 $( grep -w v254 $FILE | wc -l ) +echo v255 $( grep -w v255 $FILE | wc -l ) diff --git a/script/hipclang_opt.sh b/script/hipclang_opt.sh new file mode 100755 index 0000000000..c51bd51d97 --- /dev/null +++ b/script/hipclang_opt.sh @@ -0,0 +1,25 @@ +rm *.ll *.s + +BC_FILE=$1 + +/opt/rocm/llvm/bin/llvm-dis $BC_FILE -o original.ll +/opt/rocm/llvm/bin/opt -S -inline -inline-threshold=104857 original.ll > inline.ll +/opt/rocm/llvm/bin/opt -S -sroa inline.ll > 
sroa.ll +/opt/rocm/llvm/bin/opt -S -O3 sroa.ll > o3.ll + +/opt/rocm/llvm/bin/llc -mcpu=gfx906 original.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 inline.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 sroa.ll +/opt/rocm/llvm/bin/llc -mcpu=gfx906 o3.ll + +#/opt/rocm/llvm/bin/opt -S -O3 -sroa inline.ll > o3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3.ll > o3_2.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_2.ll > o3_3.ll +#/opt/rocm/llvm/bin/opt -S -O3 -sroa o3_3.ll > o3_4.ll + +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 opt.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 inline.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_2.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_3.ll +#/opt/rocm/llvm/bin/llc -mcpu=gfx908 o3_4.ll
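Editor's note (not part of the patch): the new script/count_vgpr.sh above writes out one hand-written "echo vN ..." line per vector register v0 through v255 to tally how many lines of a disassembly reference each VGPR. A minimal loop-based sketch that should print the same "vN <count>" pairs, assuming the same word-boundary line counting as grep -w ... | wc -l, would be:

    #!/bin/bash
    # Sketch only: count lines in the ISA dump that reference each VGPR v0..v255.
    # grep -cw counts matching lines, which matches the original grep -w | wc -l usage.
    FILE=$1
    for i in $(seq 0 255); do
        echo "v$i $(grep -cw "v$i" "$FILE")"
    done

Usage is the same as the script in the patch: pass the path to a .s/.isa dump as the first argument and read off which register numbers are actually used.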