diff --git a/CMakeLists.txt b/CMakeLists.txt index 882a1af308..962ae7f00d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -81,24 +81,27 @@ message("Compiling options for drivers: ${CMAKE_CXX_FLAGS}") if(DEVICE_BACKEND STREQUAL "AMD") set(CONV_SOURCE driver/conv_driver.cpp) - set(CONV_V2_SOURCE driver/conv_driver_v2.cpp) - set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp) set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cpp) + set(CONV_V2_SOURCE driver/conv_driver_v2.cpp) + set(CONV_BWD_DATA_V2_SOURCE driver/conv_bwd_data_driver_v2.cpp) + set(CONV_V2_OLC_SOURCE driver/conv_driver_v2_olc.cpp) elseif(DEVICE_BACKEND STREQUAL "NVIDIA") set(CONV_SOURCE driver/conv_driver.cu) set(CONV_BWD_DATA_SOURCE driver/conv_bwd_data_driver.cu) endif() -##add_executable(conv_driver ${CONV_SOURCE}) +add_executable(conv_driver ${CONV_SOURCE}) +add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) add_executable(conv_driver_v2 ${CONV_V2_SOURCE}) +add_executable(conv_bwd_data_driver_v2 ${CONV_BWD_DATA_V2_SOURCE}) add_executable(conv_driver_v2_olc ${CONV_V2_OLC_SOURCE}) -##add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) target_include_directories(conv_driver_v2_olc PRIVATE driver/olCompiling/include/) -##target_link_libraries(conv_driver PRIVATE modConv) +target_link_libraries(conv_driver PRIVATE modConv) +target_link_libraries(conv_bwd_data_driver PRIVATE modConv) target_link_libraries(conv_driver_v2 PRIVATE modConv) +target_link_libraries(conv_bwd_data_driver_v2 PRIVATE modConv) target_link_libraries(conv_driver_v2_olc PRIVATE modConv) -##target_link_libraries(conv_bwd_data_driver PRIVATE modConv) diff --git a/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..c2a67062c8 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,365 @@ +#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP +#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "driver_dynamic_gemm_xdlops_v1.hpp" +#include "driver_dynamic_gemm_xdlops_v2.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + 
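// Illustrative sizing for the implicit-GEMM mapping used below (assumed
+ // values, not taken from this change): with GemmM = K, GemmN = N * Ho * Wo
+ // and GemmK = C * Y * X, assumed sizes N = 128, C = 192, K = 256, Y = X = 3,
+ // Ho = Wo = 14 give GemmM = 256, GemmN = 128 * 14 * 14 = 25088 and
+ // GemmK = 192 * 3 * 3 = 1728; GemmK0 = GemmK / GemmKPack, so GemmKPack
+ // must evenly divide C * Y * X.
+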
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmKPack; + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_global_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_global_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_global_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0)); + assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1)); + assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0)); + assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0)); + + assert(GemmM % 
GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0); + + constexpr auto xdlops_gemm = XdlopsGemm{}; + + constexpr auto CLayout = xdlops_gemm.GetCLayout(); + + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); + + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); + + // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor + constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor + constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global + // tensor hack for NKHW format + constexpr auto out_m0_m1_m2_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc, + in_gemmk0_gemmn_gemmk1_global_desc, + out_m0_m1_m2_n_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk0_gemmm_gemmk1_global_iterator_hacks, + in_gemmk0_gemmn_gemmk1_global_iterator_hacks, + out_m0_m1_m2_n_global_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks); +} + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1( + const DynamicTensorDescriptor& wei_k_c_y_x_global_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_global_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_global_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_global_desc.GetLength(I0); + const auto C = 
in_n_c_hi_wi_global_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_global_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_global_desc.GetLength(I2); + const auto X = wei_k_c_y_x_global_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmKPack; + + assert(Y == 1 && X == 1 && ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 && + ConvDilationW == 1 && InLeftPadH == 0 && InLeftPadW == 0 && InRightPadH == 0 && + InRightPadW == 0); + + // weight tensor + const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_global_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_global_desc, + make_tuple(make_pass_through_transform(C), make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_global_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_global_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmKPack)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + assert(GemmM == out_gemmm_gemmn_global_desc.GetLength(I0)); + assert(GemmN == out_gemmm_gemmn_global_desc.GetLength(I1)); + assert(GemmK0 == in_gemmk0_gemmn_gemmk1_global_desc.GetLength(I0)); + assert(GemmK0 == wei_gemmk0_gemmm_gemmk1_global_desc.GetLength(I0)); + + assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK0 % GemmKPerBlock == 0); + + constexpr auto xdlops_gemm = XdlopsGemm{}; + + constexpr auto CLayout = xdlops_gemm.GetCLayout(); + + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + const auto out_m0_m1_m2_n_global_desc = transform_dynamic_tensor_descriptor( + out_gemmm_gemmn_global_desc, + 
make_tuple(make_unmerge_transform(make_tuple(GemmM / (M1 * M2), M1, M2)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{})); + + // out_gemm_block_cluster_desc + const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( + make_tuple(GemmM / Number{}, GemmN / Number{})); + + // hack to control index calculation when iterating over wei_gemmk0_gemmm_gemmk1_global tensor + constexpr auto wei_gemmk0_gemmm_gemmk1_global_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + // hack to control index calculation when iterating over in_gemmk0_gemmn_gemmk1_global tensor + constexpr auto in_gemmk0_gemmn_gemmk1_global_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks = + Sequence<0, 1, 2, 0, 0>{}; + + // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global + // tensor hack for NKHW format + constexpr auto out_m0_m1_m2_n_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_global_desc, + in_gemmk0_gemmn_gemmk1_global_desc, + out_m0_m1_m2_n_global_desc, + out_gemm_block_cluster_desc, + wei_gemmk0_gemmm_gemmk1_global_iterator_hacks, + in_gemmk0_gemmn_gemmk1_global_iterator_hacks, + out_m0_m1_m2_n_global_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_global_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_global_move_slice_window_iterator_hacks); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp new file mode 100644 index 0000000000..5cde70cba9 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp @@ -0,0 +1,384 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1 +#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V1 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +template +__host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global, + const FloatAB* p_b_global, + FloatC* p_c_global, + const AGlobalDesc& a_k_m_global_desc, + const BGlobalDesc& b_k_n_global_desc, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + const CBlockClusterDesc& c_block_cluster_desc, + AGlobalIteratorHacks, + BGlobalIteratorHacks, + CGlobalIteratorHacks, + AGlobalMoveSliceWindowIteratorHacks, + BGlobalMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto M = a_k_m_global_desc.GetLength(I1); + 
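// Sketch of the k-loop dispatch below, with assumed numbers (not taken from
+ // this change): the gridwise GEMM appears to be unrolled by two k-blocks,
+ // hence the 2 * KPerBlock in has_main_k_block_loop. For KPerBlock = 8:
+ // K = 32 (4 k-blocks) -> has_main_k_block_loop = true,
+ // has_double_tail_k_block_loop = true; K = 24 -> true / false;
+ // K = 16 -> false / true; K = 8 -> false / false. Each of the four
+ // combinations selects one of the kernel instantiations that follow.
+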
const auto N = b_k_n_global_desc.GetLength(I1); + const auto K = a_k_m_global_desc.GetLength(I0); + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // GEMM + using gridwise_gemm = + GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1; + + const auto GridSize = (M / MPerBlock) * (N / NPerBlock); + + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + + std::cerr << "has_main_k_block_loop = " << has_main_k_block_loop + << " has_double_tail_k_block_loop = " << has_double_tail_k_block_loop << std::endl; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); + DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); + DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); + DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); + + a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); + b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); + c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); + c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + 
nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp new file mode 100644 index 0000000000..e7462b919c --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2.hpp @@ -0,0 +1,202 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2 +#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2.hpp" +#include "gridwise_operation_wrapper.hpp" + +namespace ck { + +template +__host__ float launch_kernel_dynamic_gemm_xdlops_v2(const FloatAB* p_a_global, + const FloatAB* p_b_global, + FloatC* p_c_global, + const AGlobalDesc& a_k_m_global_desc, + const BGlobalDesc& b_k_n_global_desc, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + const CBlockClusterDesc& c_block_cluster_desc, + AGlobalIteratorHacks, + BGlobalIteratorHacks, + CGlobalIteratorHacks, + AGlobalMoveSliceWindowIteratorHacks, + BGlobalMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto 
I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto M = a_k_m_global_desc.GetLength(I1); + const auto N = b_k_n_global_desc.GetLength(I1); + const auto K = a_k_m_global_desc.GetLength(I0); + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + if(!(MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // GEMM + using gridwise_gemm = + GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2; + + const auto GridSize = (M / MPerBlock) * (N / NPerBlock); + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + const auto kernel = kernel_dynamic_gemm_xdlops_v2, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); + DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); + DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); + DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); + + a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); + b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); + c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); + c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); + + float ave_time = 0; + + const auto kernel = kernel_dynamic_gemm_xdlops_v1, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + + return ave_time; +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp new file mode 100644 index 0000000000..6bf53e06e2 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r2.hpp @@ -0,0 +1,167 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2 +#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R2 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r2.hpp" + +namespace ck { + +template +__host__ float driver_dynamic_gemm_xdlops_v2r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = 
Number<5>{}; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2; + + { + std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r2 has invalid setting"); + } + + const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + + const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + + const auto kernel = kernel_dynamic_gemm_xdlops_v2r2, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); + + return ave_time; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..ca65c2d073 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v2r3.hpp @@ -0,0 +1,169 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 +#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp" + +namespace ck { + +template +__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + using GridwiseGemm = + GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + { + std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, 
b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error( + "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); + } + + const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + + using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + + const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); + + const auto kernel = kernel_dynamic_gemm_xdlops_v2r3, + remove_reference_t, + remove_reference_t, + remove_reference_t>; + + float ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); + + return ave_time; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp index bc18872b38..e9266ca220 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp @@ -116,8 +116,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; // GemmK is different for each GEMM - index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); + index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); index_t GemmK = K * YDotSlice * XDotSlice; @@ -176,8 +176,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); + constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp index 1e8eb7cea1..e47f2fce01 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp @@ -118,8 +118,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk constexpr index_t GemmN = N * HTildaSlice * WTildaSlice; // GemmK is different for each GEMM - index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - index_t XDotSlice = (iXTilda + 1) * XDot <= X ? 
XDot : X % XDot; + index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); + index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); index_t GemmK0 = YDotSlice; index_t GemmK1 = XDotSlice; @@ -180,8 +180,8 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda); constexpr index_t XDot = math::integer_divide_ceil(X, XTilda); - constexpr index_t YDotSlice = (iYTilda + 1) * YDot <= Y ? YDot : Y % YDot; - constexpr index_t XDotSlice = (iXTilda + 1) * XDot <= X ? XDot : X % XDot; + constexpr index_t YDotSlice = math::integer_divide_ceil(Y - iYTilda, YTilda); + constexpr index_t XDotSlice = math::integer_divide_ceil(X - iXTilda, XTilda); constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH); diff --git a/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..5c582dea46 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,272 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Number of GEMMs = YTilda * XTilda +// GemmM = C +// GemmN = N * HTildaSlice * WTildaSlice +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = 
ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + 
make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_dynamic_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = 
transform_dynamic_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_pass_through_transform(C), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..377a1ac29b --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,275 @@ +#ifndef CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_DATA_CONVOLUTION_INTO_GEMM_V4R1R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: wei +// C: in +// Number of GEMMs = YTilda * XTilda +// GemmM = N * HTildaSlice * WTildaSlice +// GemmN = C +// GemmK = K * YDotSlice * XDotSlice +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + constexpr auto IYTilda = Number{}; + constexpr auto IXTilda = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + 
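// Worked example for the tilde-space split computed below (assumed values,
+ // not from this change): ConvStrideH = 2, ConvDilationH = 1 gives
+ // GcdStrideDilationH = 1 and YTilda = 2; with Y = 3, YDot = ceil(3 / 2) = 2
+ // and YDotSlice = ceil((Y - IYTilda) / YTilda), i.e. 2 filter taps for
+ // IYTilda = 0 and 1 tap for IYTilda = 1. The YTilda * XTilda separate GEMMs
+ // together cover every (y, x) filter tap exactly once.
+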
const auto InRightPadW = in_right_pads[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilda); + const auto XDot = math::integer_divide_ceil(X, XTilda); + + const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilda and WTilda that contribute to non-padding area of input tensor + const auto IHTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH); + const auto IWTildaSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW); + + const auto IHTildaSliceEnd = + math::min(HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildaSliceEnd = + math::min(WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin; + const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + // this add padding check + const auto out_n_hop_wop_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilda_xdot_wtilda_k_grid_desc = transform_dynamic_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilda), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilda), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc = + transform_dynamic_tensor_descriptor( + out_n_ydot_htilda_xdot_wtilda_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + +#if 1 + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), 
+ make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + out_n_ydotslice_htildaslice_xdotslice_wtildaslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<5, 1, 3>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // B: weight tensor + const auto wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilda), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilda), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_dynamic_tensor_descriptor(wei_k_ydot_ytilda_xdot_xtilda_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(IYTilda), + make_freeze_transform(IXTilda), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + +#if 1 + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#else + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(K0, YDotSlice, XDotSlice)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<0, 2, 3>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); +#endif + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilda, HTilda), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilda, WTilda), + make_tuple(ConvDilationW, ConvStrideW)), + 
make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(IYTilda), + make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), + make_freeze_transform(IXTilda), + make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + in_n_htildaslice_wtildaslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildaSlice, WTildaSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 987b3460c1..79051d9512 100644 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -18,9 +18,9 @@ template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad( - const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, @@ -31,18 +31,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); - const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = 
wei_k_y_x_c_grid_desc.GetLength(I2); const auto ConvStrideH = conv_strides[I0]; const auto ConvStrideW = conv_strides[I1]; @@ -57,15 +57,15 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ const auto InRightPadW = in_right_pads[I1]; // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_n_hip_wip_c_global_desc = transform_dynamic_tensor_descriptor( - in_n_hi_wi_c_global_desc, + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, make_tuple(make_pass_through_transform(N), make_pad_transform(Hi, InLeftPadH, InRightPadH), make_pad_transform(Wi, InLeftPadW, InRightPadW), @@ -73,8 +73,8 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - const auto in_n_y_ho_x_wo_c_global_desc = transform_dynamic_tensor_descriptor( - in_n_hip_wip_c_global_desc, + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, make_tuple(make_pass_through_transform(N), make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), @@ -82,22 +82,22 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - const auto in_gemmk_gemmn_global_desc = - transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_global_desc, + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, make_tuple(make_merge_transform(make_tuple(Y, X, C)), make_merge_transform(make_tuple(N, Ho, Wo))), make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); } template __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1( - const DynamicTensorDescriptor& wei_k_y_x_c_global_desc, - const DynamicTensorDescriptor& in_n_hi_wi_c_global_desc, - const DynamicTensorDescriptor& out_n_ho_wo_k_global_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, const ConvStrides& conv_strides, const ConvDilations& conv_dilations, const InLeftPads& 
in_left_pads, @@ -121,18 +121,18 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - const auto N = in_n_hi_wi_c_global_desc.GetLength(I0); - const auto C = in_n_hi_wi_c_global_desc.GetLength(I3); - const auto K = out_n_ho_wo_k_global_desc.GetLength(I3); + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); - const auto Hi = in_n_hi_wi_c_global_desc.GetLength(I1); - const auto Wi = in_n_hi_wi_c_global_desc.GetLength(I2); + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); - const auto Ho = out_n_ho_wo_k_global_desc.GetLength(I1); - const auto Wo = out_n_ho_wo_k_global_desc.GetLength(I2); + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); - const auto Y = wei_k_y_x_c_global_desc.GetLength(I1); - const auto X = wei_k_y_x_c_global_desc.GetLength(I2); + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); const auto ConvStrideH = conv_strides[I0]; const auto ConvStrideW = conv_strides[I1]; @@ -151,28 +151,28 @@ __host__ __device__ constexpr auto transform_forward_convolution_into_gemm_v4r4_ InRightPadW == 0); // weight tensor - const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor( + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C)), make_tuple(make_pass_through_transform(K), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // input tensor - const auto in_gemmk_gemmn_global_desc = transform_dynamic_tensor_descriptor( + const auto in_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, C)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); // output tensor - const auto out_gemmm_gemmn_global_desc = transform_dynamic_tensor_descriptor( + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); return make_tuple( - wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); + wei_gemmk_gemmm_grid_desc, in_gemmk_gemmn_grid_desc, out_gemmm_gemmn_grid_desc); } } // namespace ck diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..49ae26518e --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include 
"dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + 
make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..5814e66766 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R2_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = N * Ho * Wo; + const auto GemmK = C * Y * X; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmk_gemmm_grid_desc = transform_dynamic_tensor_descriptor( + 
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..0b0d8d961e --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto 
+transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const DynamicTensorDescriptor& in_n_hi_wi_c_grid_desc, + const DynamicTensorDescriptor& wei_k_y_x_c_grid_desc, + const DynamicTensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = N * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_dynamic_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_dynamic_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmk_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto 
wei_gemmk0_gemmn_gemmk1_grid_desc = transform_dynamic_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmm_gemmn_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp index 145099095f..c6c2699342 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp @@ -1417,6 +1417,7 @@ struct DynamicUnMerge printf("DynamicUnMerge, "); printf("up_lengths_"); print_multi_index(up_lengths_); + printf("up_lengths_scan_"); print_multi_index(up_lengths_scan_); printf("}"); } @@ -1439,12 +1440,12 @@ struct DynamicFreeze template __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, - const UpIdx& idx_up) const + const UpIdx& /* idx_up */) const { static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, "wrong! inconsistent # of dimension"); - idx_low = low_idx_; + idx_low(Number<0>{}) = low_idx_; } template __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, - const UpIdxDiff& idx_diff_up, - LowIdx& idx_low, - const UpIdx& idx_up_new, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */, Number) { idx_diff_low(Number<0>{}) = 0; @@ -1487,6 +1488,73 @@ struct DynamicFreeze } }; +// Insert a dangling upper dimension without lower dimension +template +struct DynamicInsert +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicInsert() = default; + + __host__ __device__ constexpr DynamicInsert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + __host__ __device__ static void + UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number) + { + static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("DynamicInsert"); + print_multi_index(up_lengths_); + } +}; + template struct DynamicVectorize { @@ -1572,5 +1640,99 @@ struct DynamicVectorize } }; +template +struct DynamicSlice +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + __host__ __device__ constexpr DynamicSlice() = default; + + __host__ __device__ constexpr DynamicSlice(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicSlice, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("slice_begin_ %d", index_t{slice_begin_}); + printf("slice_end %d", index_t{slice_end_}); + printf("}"); + } +}; + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp index b27f0507c8..b3e1c60485 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp @@ -85,6 +85,14 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i return DynamicFreeze{low_idx}; } +template +__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return DynamicSlice{low_length, slice_begin, slice_end}; +} + template __host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, const UpLength& up_length) diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp index 9b7db43664..2e36451a66 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp @@ -137,7 +137,7 @@ make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths math::multiplies_v2{}, Number{}, i + I1, - Number{}, + Number{}, I1); } }, diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index b612e1e52f..d4f23b8459 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -121,7 +121,7 @@ struct Slice SliceEnds::GetSize() == nDim, "wrong! 
# of dimensions not consistent"); -#if 0 +#if 0 // TODO: would not compile, error on constexpr static_for<0, nDim, 1>{}([&](auto idim) { static_assert(SliceBegins::At(idim) <= SliceEnds::At(idim) && diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 3623c92f21..8336fea2ae 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -184,6 +184,18 @@ struct TensorAdaptor return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); } + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= + remove_cv_t>::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value; + } + __host__ __device__ void Print() const { printf("{"); diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp new file mode 100644 index 0000000000..4b8133870e --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -0,0 +1,528 @@ +#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP +#define CK_BLOCKWISE_GEMM_XDLOPS_HPP + +#include "common_header.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "xdlops_gemm.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if 
constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() + : a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! KPack dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = + make_static_buffer(a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = + make_static_buffer(b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + vector_type a_thread_vec; + + vector_type b_thread_vec; + + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + // read A + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + using mfma_input_type = + typename vector_type::type; + + static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; + }); + + static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { + b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); + }); + }); + }); + } + + private: + // A[K, M] + static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, Number{}, I1, Number{})); + + // B[K, N] + static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + 
make_tuple(I1, Number{}, I1, Number{})); + + static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + KPack, + 1>; + + using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + KPack, + 1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +template +struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline +{ + + using CIndex = MultiIndex<2>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); + static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + + static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); + static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + + static constexpr index_t MWaves = M1 / MPerWave; + static constexpr index_t NWaves = N1 / NPerWave; + + static constexpr index_t MRepeat = M0; + static constexpr index_t NRepeat = N0; + + __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + + __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } + + __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, m_offset, 0); + } + else + { + const index_t m_offset = waveId_m * MPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, m_offset, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const index_t thread_id = get_thread_local_1d_id(); + const index_t waveId = thread_id / WaveSize; + const index_t laneId = thread_id % WaveSize; + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + if constexpr(xdlops_gemm.IsKReduction) + { + const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); + const index_t k_offset = xdlops_gemm.GetBlkId(laneId); + return make_tuple(k_offset, 0, n_offset, 0); + } + else + { + const index_t n_offset = waveId_n * NPerWave + laneId; + const index_t k_offset = 0; + return make_tuple(k_offset, 0, n_offset, 0); + } + } + + template + __device__ static CIndex + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + + const index_t waveId = get_thread_local_1d_id() / WaveSize; + + const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + const index_t waveId_m = waveId / NWaves; + const index_t waveId_n = waveId % NWaves; + + const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; + const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + + return CIndex{m_offset, n_offset}; + } + + __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() + : 
a_thread_copy_{CalculateAThreadOriginDataIndex()}, + b_thread_copy_{CalculateBThreadOriginDataIndex()} + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + "wrong! KPack dimension not consistent"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(KPack == BBlockDesc{}.GetLength(I3), "KPack is wrong!"); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); + + static_assert(KPack % xdlops_gemm.mfma_type.k_base == 0, "KPack is wrong!"); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = + make_static_buffer(a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = + make_static_buffer(b_thread_desc_.GetElementSpaceSize()); + + constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(BBlockDesc{}, + make_tuple(k, I1, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(ABlockDesc{}, + make_tuple(k, I1, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + xdlops_gemm.template Run(a_thread_buf, b_thread_buf, 
c_thread_buf);
+    }
+
+    private:
+    // A[K, M]
+    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(I1, Number{}, I1, Number{}));
+
+    // B[K, N]
+    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(I1, Number{}, I1, Number{}));
+
+    static constexpr auto c_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number{}, Number{}));
+
+    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4,
+                                                              Sequence<0, 1, 2, 3>,
+                                                              3,
+                                                              1, // KPack,
+                                                              1>;
+
+    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4,
+                                                              Sequence<0, 1, 2, 3>,
+                                                              3,
+                                                              1, // KPack,
+                                                              1>;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+};
+
+} // namespace ck
+#endif
diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp
index 05d070b94c..915a8e28d4 100644
--- a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp
+++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp
@@ -101,6 +101,7 @@ struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};
 
+    // GM0 and GN0 need to be known at compile-time
     static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
     static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
 
@@ -140,7 +141,7 @@ struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1
     {
         static_assert(is_known_at_compile_time>::value &&
                           is_known_at_compile_time>::value,
-                      "wrong!");
+                      "wrong! GM0 and GN0 need to be known at compile-time");
 
         const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2);
         const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2);
diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
new file mode 100644
index 0000000000..3fe9eb3b36
--- /dev/null
+++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
@@ -0,0 +1,585 @@
+#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
+#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_HPP
+
+#include "common_header.hpp"
+#include "dynamic_multi_index_transform_helper.hpp"
+#include "dynamic_tensor_descriptor.hpp"
+#include "dynamic_tensor_descriptor_helper.hpp"
+#include "blockwise_gemm_xdlops.hpp"
+#include "blockwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "threadwise_dynamic_tensor_slice_set.hpp"
+
+namespace ck {
+
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
+          typename CBlockClusterDesc,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                  const FloatB* __restrict__ p_b_global,
+                                  FloatC* __restrict__ p_c_global,
+                                  const AGlobalDesc a_k0_m_k1_global_desc,
+                                  const BGlobalDesc b_k0_n_k1_global_desc,
+                                  const CGlobalDesc c_m0_m1_m2_n_global_desc,
+                                  const CBlockClusterDesc c_block_cluster_desc)
+{
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
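HasMainKBlockLoop and HasDoubleTailKBlockLoop are compile-time inputs, so a host-side driver has to pick one of four kernel instantiations from the runtime K-extent. A minimal sketch of that selection, assuming K0 is the runtime length of the GemmK0 dimension and KPerBlock the per-iteration K0 slice (illustrative only, not part of this diff):

    // Sketch (assumption): the main loop below retires 2 * KPerBlock of K0 per
    // trip, so a "main" loop exists when more than one double-trip fits, and a
    // double tail is needed when K0 / KPerBlock is even.
    const bool has_main_k_block_loop        = (K0 + KPerBlock) / (2 * KPerBlock) > 1;
    const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        // instantiate kernel_dynamic_gemm_xdlops_v1<..., true, true>
    }
    else if(has_main_k_block_loop)
    {
        // instantiate <..., true, false>; the remaining two cases follow the
        // same pattern with has_main_k_block_loop == false.
    }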
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
+// point to the non-modifiable parameter address space, so the compiler can enable the
+// corresponding optimization
+template <typename GridwiseGemm,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename AGlobalDesc,
+          typename BGlobalDesc,
+          typename CGlobalDesc,
+          typename CBlockClusterDesc,
+          bool HasMainKBlockLoop,
+          bool HasDoubleTailKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+    kernel_dynamic_gemm_xdlops_v1(const FloatA* __restrict__ p_a_global,
+                                  const FloatB* __restrict__ p_b_global,
+                                  FloatC* __restrict__ p_c_global,
+                                  const void __CONSTANT__* p_a_k0_m_k1_global_desc,
+                                  const void __CONSTANT__* p_b_k0_n_k1_global_desc,
+                                  const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
+                                  const void __CONSTANT__* p_c_block_cluster_desc)
+{
+    // first cast the __CONSTANT__ void pointer to a generic void pointer,
+    // then cast the void pointer to the descriptor pointer type:
+    // the copy constructor of the tensor descriptor doesn't take address_space(4)
+    const auto a_k0_m_k1_global_desc =
+        *reinterpret_cast<const AGlobalDesc*>((const void*)p_a_k0_m_k1_global_desc);
+    const auto b_k0_n_k1_global_desc =
+        *reinterpret_cast<const BGlobalDesc*>((const void*)p_b_k0_n_k1_global_desc);
+    const auto c_m0_m1_m2_n_global_desc =
+        *reinterpret_cast<const CGlobalDesc*>((const void*)p_c_m0_m1_m2_n_global_desc);
+
+    const auto c_block_cluster_desc =
+        *reinterpret_cast<const CBlockClusterDesc*>((const void*)p_c_block_cluster_desc);
+
+    GridwiseGemm::Run(p_a_global,
+                      p_b_global,
+                      p_c_global,
+                      a_k0_m_k1_global_desc,
+                      b_k0_n_k1_global_desc,
+                      c_m0_m1_m2_n_global_desc,
+                      c_block_cluster_desc,
+                      integral_constant<bool, HasMainKBlockLoop>{},
+                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
+}
+#endif
+
+template
+struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
+{
+    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+    {
+        constexpr auto max_lds_align = Number{};
+
+        // A matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+        // B matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_space_size =
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+        constexpr auto b_block_space_size =
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+        return 2 * (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+    }
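GetSharedMemoryNumberOfByte doubles the A and B tile footprints because of the ping-pong buffers used below. A quick sanity check with illustrative tile sizes (KPerBlock = 8, MPerBlock = NPerBlock = 128, KPack = 4, FloatAB = half; these numbers are assumptions, not values fixed by this diff), ignoring alignment padding:

    // 8 * 128 * 4 = 4096 elements per A tile, the same for B; two copies of each:
    //   2 * (4096 + 4096) * sizeof(half) = 32768 bytes = 32 KiB,
    // which fits the 64 KiB of LDS available to a workgroup on gfx908 with room
    // left for a second resident workgroup.
    static_assert(2 * (8 * 128 * 4 + 8 * 128 * 4) * 2 == 32768, "example arithmetic");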
+
+    template <typename AGlobalDesc,
+              typename BGlobalDesc,
+              typename CGlobalDesc,
+              typename CBlockClusterDesc,
+              bool HasMainKBlockLoop,
+              bool HasDoubleTailKBlockLoop>
+    __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+                               const FloatAB* __restrict__ p_b_global,
+                               FloatC* __restrict__ p_c_global,
+                               const AGlobalDesc& a_k0_m_k1_global_desc,
+                               const BGlobalDesc& b_k0_n_k1_global_desc,
+                               const CGlobalDesc& c_m0_m1_m2_n_global_desc,
+                               const CBlockClusterDesc& c_block_cluster_desc,
+                               FloatAB* __restrict__ p_shared_block,
+                               integral_constant<bool, HasMainKBlockLoop>,
+                               integral_constant<bool, HasDoubleTailKBlockLoop>)
+    {
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        const auto a_global_buf = make_dynamic_buffer(
+            p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
+        const auto b_global_buf = make_dynamic_buffer(
+            p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
+        auto c_global_buf = make_dynamic_buffer(
+            p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
+
+        const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
+        const auto M = a_k0_m_k1_global_desc.GetLength(I1);
+        const auto N = b_k0_n_k1_global_desc.GetLength(I1);
+        const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
+
+        // divide block work by [M, N]
+        const auto block_work_idx =
+            c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+        // HACK: this forces m/n_block_data_idx_on_global into SGPR
+        const index_t m_block_data_idx_on_global =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+        const index_t n_block_data_idx_on_global =
+            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+        // lds max alignment
+        constexpr auto max_lds_align = Number{};
+
+        // A matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+        // B matrix in LDS memory, dst of blockwise copy
+        // be careful of LDS alignment
+        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+        // A matrix blockwise copy
+        auto a_blockwise_copy =
+            BlockwiseDynamicTensorSliceTransfer_v4,
+                                                  ABlockTransferThreadSliceLengths_K_M_KPack,
+                                                  ABlockTransferThreadClusterLengths_K_M_KPack,
+                                                  ABlockTransferThreadClusterArrangeOrder,
+                                                  FloatAB,
+                                                  FloatAB,
+                                                  decltype(a_k0_m_k1_global_desc),
+                                                  decltype(a_k0_m_k1_block_desc),
+                                                  ABlockTransferSrcAccessOrder,
+                                                  Sequence<1, 0, 2>,
+                                                  ABlockTransferSrcVectorDim,
+                                                  2,
+                                                  ABlockTransferSrcScalarPerVector,
+                                                  ABlockTransferDstScalarPerVector_KPack,
+                                                  1,
+                                                  1,
+                                                  AThreadTransferSrcResetCoordinateAfterRun,
+                                                  true>(
+                a_k0_m_k1_global_desc,
+                make_multi_index(0, m_block_data_idx_on_global, 0),
+                a_k0_m_k1_block_desc,
+                make_multi_index(0, 0, 0));
+
+        // B matrix blockwise copy
+        auto b_blockwise_copy =
+            BlockwiseDynamicTensorSliceTransfer_v4,
+                                                  BBlockTransferThreadSliceLengths_K_N_KPack,
+                                                  BBlockTransferThreadClusterLengths_K_N_KPack,
+                                                  BBlockTransferThreadClusterArrangeOrder,
+                                                  FloatAB,
+                                                  FloatAB,
+                                                  decltype(b_k0_n_k1_global_desc),
+                                                  decltype(b_k0_n_k1_block_desc),
+                                                  BBlockTransferSrcAccessOrder,
+                                                  Sequence<1, 0, 2>,
+                                                  BBlockTransferSrcVectorDim,
+                                                  2,
+                                                  BBlockTransferSrcScalarPerVector,
+                                                  BBlockTransferDstScalarPerVector_KPack,
+                                                  1,
+                                                  1,
+                                                  BThreadTransferSrcResetCoordinateAfterRun,
+                                                  true>(
+                b_k0_n_k1_global_desc,
+                make_multi_index(0, n_block_data_idx_on_global, 0),
+                b_k0_n_k1_block_desc,
+                make_multi_index(0, 0, 0));
+
+        // GEMM definition
+        // c_mtx += transpose(a_mtx) * b_mtx
+        // a_mtx[KPerBlock, MPerBlock] is in LDS
+        // b_mtx[KPerBlock, NPerBlock] is in LDS
+        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        // register
+        // sanity check
+
+        static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
+                          NPerBlock % (NPerWave * NRepeat) == 0,
+                      "wrong! MPerBlock/NPerBlock not divisible by the per-wave tiling");
+
+        constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
+            a_k0_m_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number{}),
+                       make_unmerge_transform(
+                           make_tuple(Number{}, Number{})),
+                       make_pass_through_transform(Number{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+
+        constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
+            b_k0_n_k1_block_desc,
+            make_tuple(make_pass_through_transform(Number{}),
+                       make_unmerge_transform(
+                           make_tuple(Number{}, Number{})),
+                       make_pass_through_transform(Number{})),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
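The two unmerge transforms above reshape each LDS tile from [K0, M, K1] to [K0, M0, M1, K1] (and [K0, N0, N1, K1] for B), the four-dimensional view the blockwise xdlops GEMM indexes with M0 = MRepeat and M1 = MWaves * MPerWave. The unmerge itself is plain index arithmetic; a toy version with assumed sizes MRepeat = 2 and M1 = 128 (hypothetical, for illustration only):

    // m in [0, 256) is split as m = m0 * 128 + m1, the same mapping that
    // make_unmerge_transform installs for the M dimension of the tile.
    constexpr int M1 = 128;  // stands in for MWaves * MPerWave
    int m  = 200;
    int m0 = m / M1;         // repeat index         -> 1
    int m1 = m % M1;         // index within waves   -> 72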
+
+        const auto blockwise_gemm =
+            BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{};
+
+        constexpr auto CLayout = blockwise_gemm.GetCLayout();
+
+        constexpr index_t BlkSize   = CLayout.GetBlkSize();
+        constexpr index_t NumBlks   = CLayout.GetNumBlks();
+        constexpr index_t NumXdlops = CLayout.GetNumXdlops();
+
+        constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number{}, Number{}, Number{}));
+
+        constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number{}, Number{}));
+
+        StaticBuffer,
+                     c_mr_nr_nx_desc.GetElementSpaceSize()>
+            c_thread_buf;
+
+        // LDS allocation for A and B: be careful of alignment
+        constexpr auto a_block_space_size =
+            math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+        constexpr auto b_block_space_size =
+            math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+        FloatAB* p_a_block_double = p_shared_block;
+        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;
+
+        // register allocation for output
+        // auto c_thread_buf = make_static_buffer(
+        //     c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());
+
+        // ThreadwiseDynamicTensorSliceSet_v1>{}
+        //     .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});
+
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0);
+
+        // hack to control index calculation when iterating over A and B matrix for threadwise copy
+        constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{};
+        constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{};
+
+        // hack to control index calculation when moving the slice window over A and B matrix for
+        // threadwise copy
+        constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack =
+            AGlobalMoveSliceWindowIteratorHacks{};
+        constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack =
+            BGlobalMoveSliceWindowIteratorHacks{};
+
+        auto a_block_even_buf = make_dynamic_buffer(
+            p_a_block_double, a_k0_m_k1_block_desc.GetElementSpaceSize());
+        auto b_block_even_buf = make_dynamic_buffer(
+            p_b_block_double, b_k0_n_k1_block_desc.GetElementSpaceSize());
+
+        auto a_block_odd_buf = make_dynamic_buffer(
+            p_a_block_double + a_block_space_size, a_k0_m_k1_block_desc.GetElementSpaceSize());
+        auto b_block_odd_buf = make_dynamic_buffer(
+            p_b_block_double + b_block_space_size, b_k0_n_k1_block_desc.GetElementSpaceSize());
+
+        // LDS double buffer: preload data into LDS
+        {
+            a_blockwise_copy.RunRead(
+                a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+            b_blockwise_copy.RunRead(
+                b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
+
+            a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+            b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
+        }
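From here on the kernel runs the classic LDS ping-pong: while the xdlops pipeline consumes the tile staged in one buffer, the blockwise copies fetch the next tile into the other, and each __syncthreads() fences exactly one hand-off. Each trip of the do-while below retires 2 * KPerBlock of the K0 extent; a tiny host-side model of the trip count, with assumed K0 = 64 and KPerBlock = 8 (illustration only):

    // The do-while runs at least once, consuming 2 * KPerBlock = 16 of K0 per
    // trip; the condition k < K0 - 2 * KPerBlock leaves the final two tiles
    // (or one, without the double tail) for the tail code.
    int trips = 0;
    int k     = 0;
    do { ++trips; k += 16; } while(k < 64 - 16); // -> 3 trips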
a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+ b_blockwise_copy.RunRead(
+ b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
+
+ asm volatile("s_nop 0");
+
+ // LDS double buffer: GEMM on current data
+ blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+
+ // LDS double buffer: store next data to LDS
+ a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+ b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
+
+ // odd iteration
+ a_blockwise_copy.MoveSrcSliceWindow(
+ a_k0_m_k1_global_desc,
+ a_block_slice_copy_step,
+ a_k0_m_k1_global_move_slice_window_iterator_hack);
+ b_blockwise_copy.MoveSrcSliceWindow(
+ b_k0_n_k1_global_desc,
+ b_block_slice_copy_step,
+ b_k0_n_k1_global_move_slice_window_iterator_hack);
+
+ __syncthreads();
+
+ // LDS double buffer: load next data from device mem
+ a_blockwise_copy.RunRead(
+ a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+ b_blockwise_copy.RunRead(
+ b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
+
+ asm volatile("s_nop 0");
+
+ // LDS double buffer: GEMM on current data
+ blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+
+ // LDS double buffer: store next data to LDS
+ a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_even_buf);
+ b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_even_buf);
+
+ k_block_data_begin += 2 * KPerBlock;
+ } while(k_block_data_begin < K0 - 2 * KPerBlock);
+ }
+
+ // LDS double buffer: tail
+ if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
+ {
+ a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc,
+ a_block_slice_copy_step,
+ a_k0_m_k1_global_move_slice_window_iterator_hack);
+ b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc,
+ b_block_slice_copy_step,
+ b_k0_n_k1_global_move_slice_window_iterator_hack);
+
+ __syncthreads();
+
+ // LDS double buffer: load last data from device mem
+ a_blockwise_copy.RunRead(
+ a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks);
+ b_blockwise_copy.RunRead(
+ b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks);
+
+ // LDS double buffer: GEMM on 2nd-last data
+ blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+
+ // LDS double buffer: store last data to LDS
+ a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_odd_buf);
+ b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_odd_buf);
+
+ __syncthreads();
+
+ // LDS double buffer: GEMM on last data
+ blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+ }
+ else // if there is 1 iteration left
+ {
+ __syncthreads();
+
+ // LDS double buffer: GEMM on last data
+ blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+ }
+
+ // output: register to global memory
+ {
+
+ constexpr index_t M0 = CLayout.M1();
+ constexpr index_t M1 = CLayout.N1();
+ constexpr index_t M2 = CLayout.M0();
+
+ constexpr auto c_m0_m1_m2_n_thread_desc =
+ make_dynamic_naive_tensor_descriptor_packed_v2(
+ make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{}));
+
+ StaticBuffer c_blk_buf_;
+
+ static_for<0, MRepeat, 1>{}([&](auto mr_i) {
+ static_for<0, NRepeat, 1>{}([&](auto nr_i) {
+ static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) {
+ static_for<0, NumBlks, 1>{}([&](auto blk_i) {
+ auto c_blk = c_thread_buf[Number{}];
+
+ static_for<0, BlkSize, 1>{}([&](auto j) {
+ c_blk_buf_(j) = c_blk.template AsType()[Number<
+ c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}];
+
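// ---------------------------------------------------------------------------
// [Editor's illustrative sketch, not part of this patch.] The double-buffered
// K loop above requires the caller to select the matching tail at compile
// time. A hypothetical host-side helper deriving the two flags from the
// problem size could look like:
//
//     const index_t num_k_loop = K0 / KPerBlock;          // total K iterations
//     const bool has_main_k_block_loop = num_k_loop > 2;  // run the do-while
//     const bool has_double_tail_k_block_loop = (num_k_loop % 2 == 0);
//
// With an even iteration count both the even and odd LDS buffers hold live
// data at the end, so the two-step tail runs; with an odd count a single
// trailing GEMM on the even buffer suffices.
// ---------------------------------------------------------------------------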
}); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex( + mr_i, nr_i, xdlops_i, blk_i); + + const index_t m_thread_data_on_global = + m_block_data_idx_on_global + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_global = + n_block_data_idx_on_global + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks = + CGlobalIteratorHacks{}; + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_global_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m1_m2_n_global_desc, + make_multi_index(m_thread_data_on_global / (M2 * M1), + m_thread_data_on_global % (M2 * M1) / M2, + m_thread_data_on_global % M2, + n_thread_data_on_global)} + .Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_global_desc, + c_global_buf, + c_m0_m1_m2_n_global_tensor_iterator_hacks); + }); + }); + }); + }); + } + } + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + FloatC* __restrict__ p_c_global, + const AGlobalDesc& a_k0_m_k1_global_desc, + const BGlobalDesc& b_k0_n_k1_global_desc, + const CGlobalDesc& c_m0_m1_m2_n_global_desc, + const CBlockClusterDesc& c_block_cluster_desc, + integral_constant, + integral_constant) + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(p_a_global, + p_b_global, + p_c_global, + a_k0_m_k1_global_desc, + b_k0_n_k1_global_desc, + c_m0_m1_m2_n_global_desc, + c_block_cluster_desc, + p_shared_block, + integral_constant{}, + integral_constant{}); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp new file mode 100644 index 0000000000..4e1549355d --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2.hpp @@ -0,0 +1,498 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global, + const FloatB* __restrict__ p_b_global, + FloatC* __restrict__ p_c_global, + const AGlobalDesc a_k0_m_k1_global_desc, + const BGlobalDesc b_k0_n_k1_global_desc, + const CGlobalDesc c_m0_m1_m2_n_global_desc, + const CBlockClusterDesc c_block_cluster_desc) +{ + GridwiseGemm::Run(p_a_global, + p_b_global, + p_c_global, + a_k0_m_k1_global_desc, + b_k0_n_k1_global_desc, + c_m0_m1_m2_n_global_desc, + c_block_cluster_desc); +} +#elif 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+// pass tensor descriptor by __CONSTANT__ void pointer
+// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
+// point to the non-modifiable parameter address space, so the compiler can enable the
+// corresponding optimization
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+ kernel_dynamic_gemm_xdlops_v2(const FloatA* __restrict__ p_a_global,
+ const FloatB* __restrict__ p_b_global,
+ FloatC* __restrict__ p_c_global,
+ const void __CONSTANT__* p_a_k0_m_k1_global_desc,
+ const void __CONSTANT__* p_b_k0_n_k1_global_desc,
+ const void __CONSTANT__* p_c_m0_m1_m2_n_global_desc,
+ const void __CONSTANT__* p_c_block_cluster_desc)
+{
+ // first cast the __CONSTANT__ void* to a plain void*;
+ // second cast the void* to Desc*;
+ // the copy constructor of the tensor descriptor doesn't accept address_space(4) pointers
+ const auto a_k0_m_k1_global_desc =
+ *reinterpret_cast((const void*)p_a_k0_m_k1_global_desc);
+ const auto b_k0_n_k1_global_desc =
+ *reinterpret_cast((const void*)p_b_k0_n_k1_global_desc);
+ const auto c_m0_m1_m2_n_global_desc =
+ *reinterpret_cast((const void*)p_c_m0_m1_m2_n_global_desc);
+
+ const auto c_block_cluster_desc =
+ *reinterpret_cast((const void*)p_c_block_cluster_desc);
+
+ GridwiseGemm::Run(p_a_global,
+ p_b_global,
+ p_c_global,
+ a_k0_m_k1_global_desc,
+ b_k0_n_k1_global_desc,
+ c_m0_m1_m2_n_global_desc,
+ c_block_cluster_desc,
+ integral_constant{},
+ integral_constant{});
+}
+#endif
+
+template
+struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2
+{
+ __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
+ {
+ constexpr auto max_lds_align = Number{};
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+ // LDS allocation for A and B: be careful of alignment
+ constexpr auto a_block_space_size =
+ math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+ constexpr auto b_block_space_size =
+ math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+ return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+ }
+
+ __device__ static void Run(const FloatAB* __restrict__ p_a_global,
+ const FloatAB* __restrict__ p_b_global,
+ FloatC* __restrict__ p_c_global,
+ const AGlobalDesc& a_k0_m_k1_global_desc,
+ const BGlobalDesc& b_k0_n_k1_global_desc,
+ const CGlobalDesc& c_m0_m1_m2_n_global_desc,
+ const CBlockClusterDesc& c_block_cluster_desc,
+ FloatAB* __restrict__ p_shared_block)
+ {
+ constexpr auto I0 = Number<0>{};
+ constexpr auto I1 = Number<1>{};
+ constexpr auto I2 = Number<2>{};
+ constexpr auto I3 = Number<3>{};
+
+ const auto a_global_buf = make_dynamic_buffer(
+ p_a_global, a_k0_m_k1_global_desc.GetElementSpaceSize());
+ const auto b_global_buf = make_dynamic_buffer(
+ p_b_global, b_k0_n_k1_global_desc.GetElementSpaceSize());
+ auto c_global_buf = make_dynamic_buffer(
+ p_c_global, c_m0_m1_m2_n_global_desc.GetElementSpaceSize());
+
+ const auto K0 = a_k0_m_k1_global_desc.GetLength(I0);
+ const auto M =
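// ---------------------------------------------------------------------------
// [Editor's illustrative sketch, not part of this patch.] For the
// void-pointer path above, the host has to place each (POD) descriptor in
// device-visible memory and pass its address; a hypothetical pairing using
// the driver's DeviceMem helper might read:
//
//     DeviceMem a_desc_buf(sizeof(a_k0_m_k1_global_desc));
//     a_desc_buf.ToDevice(&a_k0_m_k1_global_desc); // copy descriptor to device
//     // ... likewise for the B/C descriptors and the block cluster
//     // descriptor, then pass
//     //     (const void __CONSTANT__*)a_desc_buf.GetDeviceBuffer()
//     // as the corresponding kernel argument.
// ---------------------------------------------------------------------------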
a_k0_m_k1_global_desc.GetLength(I1);
+ const auto N = b_k0_n_k1_global_desc.GetLength(I1);
+ const auto K1 = b_k0_n_k1_global_desc.GetLength(I2);
+
+ // divide block work by [M, N]
+ const auto block_work_idx =
+ c_block_cluster_desc.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+ // HACK: this forces m/n_block_data_idx_on_global into SGPR
+ const index_t m_block_data_idx_on_global =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+ const index_t n_block_data_idx_on_global =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+ // lds max alignment
+ constexpr auto max_lds_align = Number{};
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, Number{}), max_lds_align);
+
+ // A matrix blockwise copy
+ auto a_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ ABlockTransferThreadSliceLengths_K_M_KPack,
+ ABlockTransferThreadClusterLengths_K_M_KPack,
+ ABlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(a_k0_m_k1_global_desc),
+ decltype(a_k0_m_k1_block_desc),
+ ABlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ ABlockTransferSrcVectorDim,
+ 2,
+ ABlockTransferSrcScalarPerVector,
+ ABlockTransferDstScalarPerVector_KPack,
+ 1,
+ 1,
+ AThreadTransferSrcResetCoordinateAfterRun,
+ true>(
+ a_k0_m_k1_global_desc,
+ make_multi_index(0, m_block_data_idx_on_global, 0),
+ a_k0_m_k1_block_desc,
+ make_multi_index(0, 0, 0));
+
+ // B matrix blockwise copy
+ auto b_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ BBlockTransferThreadSliceLengths_K_N_KPack,
+ BBlockTransferThreadClusterLengths_K_N_KPack,
+ BBlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(b_k0_n_k1_global_desc),
+ decltype(b_k0_n_k1_block_desc),
+ BBlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ BBlockTransferSrcVectorDim,
+ 2,
+ BBlockTransferSrcScalarPerVector,
+ BBlockTransferDstScalarPerVector_KPack,
+ 1,
+ 1,
+ BThreadTransferSrcResetCoordinateAfterRun,
+ true>(
+ b_k0_n_k1_global_desc,
+ make_multi_index(0, n_block_data_idx_on_global, 0),
+ b_k0_n_k1_block_desc,
+ make_multi_index(0, 0, 0));
+
+ // GEMM definition
+ // c_mtx += transpose(a_mtx) * b_mtx
+ // a_mtx[KPerBlock, MPerBlock] is in LDS
+ // b_mtx[KPerBlock, NPerBlock] is in LDS
+ // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+ // register
+ // sanity check
+
+ static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
+ NPerBlock % (NPerWave * NRepeat) == 0,
+ "wrong!");
+
+ constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
+ a_k0_m_k1_block_desc,
+ make_tuple(make_pass_through_transform(Number{}),
+ make_unmerge_transform(
+ make_tuple(Number{}, Number{})),
+ make_pass_through_transform(Number{})),
+ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+ make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+
+ constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
+ b_k0_n_k1_block_desc,
+ make_tuple(make_pass_through_transform(Number{}),
+ make_unmerge_transform(
+ make_tuple(Number{}, Number{})),
+ make_pass_through_transform(Number{})),
+ make_tuple(Sequence<0>{},
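// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] In the blockwise copies above, each
// thread moves a ThreadSliceLengths sub-box and ThreadClusterLengths threads
// tile the block slice, so elementwise
//     ThreadSliceLengths * ThreadClusterLengths == block tile lengths.
// E.g. with hypothetical numbers for an 8 x 128 x 8 A tile and 256 threads:
//     ThreadSliceLengths_K_M_KPack   = Sequence<4, 1, 8>   // per-thread box
//     ThreadClusterLengths_K_M_KPack = Sequence<2, 128, 1> // 2*128*1 == 256
// since 4*2 == 8, 1*128 == 128 and 8*1 == 8.
// ---------------------------------------------------------------------------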
Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto blockwise_gemm = + BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; + + constexpr auto CLayout = blockwise_gemm.GetCLayout(); + + constexpr index_t BlkSize = CLayout.GetBlkSize(); + constexpr index_t NumBlks = CLayout.GetNumBlks(); + constexpr index_t NumXdlops = CLayout.GetNumXdlops(); + + constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + StaticBuffer, + c_mr_nr_nx_desc.GetElementSpaceSize()> + c_thread_buf; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + // register allocation for output + // auto c_thread_buf = make_static_buffer( + // c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); + + // ThreadwiseDynamicTensorSliceSet_v1>{} + //.Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_global_iterator_hacks = AGlobalIteratorHacks{}; + constexpr auto b_k0_n_k1_global_iterator_hacks = BGlobalIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_global_move_slice_window_iterator_hack = + AGlobalMoveSliceWindowIteratorHacks{}; + constexpr auto b_k0_n_k1_global_move_slice_window_iterator_hack = + BGlobalMoveSliceWindowIteratorHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead( + a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_global_desc, + a_block_slice_copy_step, + a_k0_m_k1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_global_desc, + b_block_slice_copy_step, + b_k0_n_k1_global_move_slice_window_iterator_hack); + + a_blockwise_copy.RunRead( + a_k0_m_k1_global_desc, a_global_buf, a_k0_m_k1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k0_n_k1_global_desc, b_global_buf, b_k0_n_k1_global_iterator_hacks); + + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } 
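// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] Pipeline of the single-LDS-buffer loop
// above: the global RunRead for iteration i+1 is issued before the first
// barrier, so it overlaps with the MFMA work of iteration i:
//
//     RunRead(i+1) -> block_sync_lds() -> GEMM(i)
//                  -> block_sync_lds() -> RunWrite(i+1)
//
// The first barrier makes iteration i's LDS tile (written by RunWrite at the
// end of the previous trip) visible to all lanes before the GEMM reads it;
// the second ensures those reads have finished before the tile is
// overwritten with iteration i+1's data.
// ---------------------------------------------------------------------------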
while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{})); + + StaticBuffer c_blk_buf_; + + static_for<0, MRepeat, 1>{}([&](auto mr_i) { + static_for<0, NRepeat, 1>{}([&](auto nr_i) { + static_for<0, NumXdlops, 1>{}([&](auto xdlops_i) { + static_for<0, NumBlks, 1>{}([&](auto blk_i) { + auto c_blk = c_thread_buf[Number{}]; + + static_for<0, BlkSize, 1>{}([&](auto j) { + c_blk_buf_(j) = c_blk.template AsType()[Number< + c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; + }); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex( + mr_i, nr_i, xdlops_i, blk_i); + + const index_t m_thread_data_on_global = + m_block_data_idx_on_global + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_global = + n_block_data_idx_on_global + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_global_tensor_iterator_hacks = + CGlobalIteratorHacks{}; + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_global_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m1_m2_n_global_desc, + make_multi_index(m_thread_data_on_global / (M2 * M1), + m_thread_data_on_global % (M2 * M1) / M2, + m_thread_data_on_global % M2, + n_thread_data_on_global)} + .Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_global_desc, + c_global_buf, + c_m0_m1_m2_n_global_tensor_iterator_hacks); + }); + }); + }); + }); + } + } + + __device__ static void Run(const FloatAB* __restrict__ p_a_global, + const FloatAB* __restrict__ p_b_global, + FloatC* __restrict__ p_c_global, + const AGlobalDesc& a_k0_m_k1_global_desc, + const BGlobalDesc& b_k0_n_k1_global_desc, + const CGlobalDesc& c_m0_m1_m2_n_global_desc, + const CBlockClusterDesc& c_block_cluster_desc) + { + constexpr index_t shared_block_size = GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + Run(p_a_global, + p_b_global, + p_c_global, + a_k0_m_k1_global_desc, + b_k0_n_k1_global_desc, + c_m0_m1_m2_n_global_desc, + c_block_cluster_desc, + p_shared_block); + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp new file mode 100644 index 0000000000..4e7e59d10c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r2.hpp @@ -0,0 +1,509 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include 
"threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2r2(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0MK1GridDesc a_k0_m_k1_grid_desc, + const BK0NK1GridDesc b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} + +template +struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + // TODO: turn on this + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 needs to be known at compile-time");
+
+ const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
+ const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
+ const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+
+ // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+
+ return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
+ K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
+ K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
+ K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
+ (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
+ (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
+ }
+
+ __host__ __device__ static constexpr index_t
+ CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
+
+ return grid_size;
+ }
+
+ __host__ __device__ static constexpr auto
+ MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ constexpr auto xdlops_gemm = XdlopsGemm{};
+
+ constexpr auto CLayout = xdlops_gemm.GetCLayout();
+
+ constexpr auto M0 = Number{};
+ constexpr auto M1 = Number{};
+ constexpr auto M2 = Number{};
+
+ const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
+ c_m_n_grid_desc,
+ make_tuple(make_unmerge_transform(make_tuple(M / (M1 * M2), M1, M2)),
+ make_pass_through_transform(N)),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));
+
+ return c_m0_m1_m2_n_grid_desc;
+ }
+
+ __host__ __device__ static constexpr auto
+ MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ constexpr auto M1 = Number{};
+ constexpr auto N1 = Number{};
+
+ const auto M0 = M / M1;
+ const auto N0 = N / N1;
+
+ const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+ make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
+ make_tuple(Sequence<0, 1>{}),
+ make_tuple(Sequence<0>{}));
+
+ return c_blockid_to_m0_n0_block_cluster_adaptor;
+ }
+
+ using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
+ using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
+
+ __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+ const FloatAB* __restrict__ p_b_grid,
+ FloatC* __restrict__ p_c_grid,
+ FloatAB* __restrict__ p_shared_block,
+ const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
+ const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+ const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
+ const CBlockClusterAdaptor& c_block_cluster_adaptor)
+ {
+ constexpr auto I0 = Number<0>{};
+ constexpr auto I1 = Number<1>{};
+ constexpr auto I2 = Number<2>{};
+ constexpr auto I3 = Number<3>{};
+
+ const auto a_grid_buf = make_dynamic_buffer(
+ p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
+ const auto b_grid_buf = make_dynamic_buffer(
+ p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
+ auto c_grid_buf = make_dynamic_buffer(
+ p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
+
+ const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+ const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
+ const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
+
+ // divide block work by [M, N]
+ const auto block_work_idx =
+
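// ---------------------------------------------------------------------------
// [Editor's illustrative sketch, not part of this patch.] A hypothetical
// host-side pairing of the helpers above, prior to launching
// kernel_dynamic_gemm_xdlops_v2r2:
//
//     if(!GridwiseGemm::CheckValidity(a_desc, b_desc, c_desc))
//         throw std::runtime_error("wrong! unsupported GEMM shape");
//
//     const index_t grid_size = GridwiseGemm::CalculateGridSize(c_desc);
//     const auto c_m0_m1_m2_n_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_desc);
//     const auto c_block_cluster   = GridwiseGemm::MakeCBlockClusterAdaptor(c_desc);
//     // launch grid_size blocks of BlockSize threads with these descriptors
// ---------------------------------------------------------------------------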
c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+ // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+ const index_t m_block_data_idx_on_grid =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+ const index_t n_block_data_idx_on_grid =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+ // lds max alignment
+ constexpr auto max_lds_align = K1;
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // A matrix blockwise copy
+ auto a_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ ABlockTransferThreadSliceLengths_K0_M_K1,
+ ABlockTransferThreadClusterLengths_K0_M_K1,
+ ABlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(a_k0_m_k1_grid_desc),
+ decltype(a_k0_m_k1_block_desc),
+ ABlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ ABlockTransferSrcVectorDim,
+ 2,
+ ABlockTransferSrcScalarPerVector,
+ ABlockTransferDstScalarPerVector_K1,
+ 1,
+ 1,
+ AThreadTransferSrcResetCoordinateAfterRun,
+ true>(
+ a_k0_m_k1_grid_desc,
+ make_multi_index(0, m_block_data_idx_on_grid, 0),
+ a_k0_m_k1_block_desc,
+ make_multi_index(0, 0, 0));
+
+ // B matrix blockwise copy
+ auto b_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ BBlockTransferThreadSliceLengths_K0_N_K1,
+ BBlockTransferThreadClusterLengths_K0_N_K1,
+ BBlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(b_k0_n_k1_grid_desc),
+ decltype(b_k0_n_k1_block_desc),
+ BBlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ BBlockTransferSrcVectorDim,
+ 2,
+ BBlockTransferSrcScalarPerVector,
+ BBlockTransferDstScalarPerVector_K1,
+ 1,
+ 1,
+ BThreadTransferSrcResetCoordinateAfterRun,
+ true>(
+ b_k0_n_k1_grid_desc,
+ make_multi_index(0, n_block_data_idx_on_grid, 0),
+ b_k0_n_k1_block_desc,
+ make_multi_index(0, 0, 0));
+
+ // GEMM definition
+ // c_mtx += transpose(a_mtx) * b_mtx
+ // a_mtx[KPerBlock, MPerBlock] is in LDS
+ // b_mtx[KPerBlock, NPerBlock] is in LDS
+ // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+ // register
+ // sanity check
+
+ static_assert(MPerBlock % (MPerWave * MRepeat) == 0 &&
+ NPerBlock % (NPerWave * NRepeat) == 0,
+ "wrong!");
+
+ constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
+ a_k0_m_k1_block_desc,
+ make_tuple(make_pass_through_transform(Number{}),
+ make_unmerge_transform(
+ make_tuple(Number{}, Number{})),
+ make_pass_through_transform(K1)),
+ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+ make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+
+ constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
+ b_k0_n_k1_block_desc,
+ make_tuple(make_pass_through_transform(Number{}),
+ make_unmerge_transform(
+ make_tuple(Number{}, Number{})),
+ make_pass_through_transform(K1)),
+ make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+ make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
+
+ const auto blockwise_gemm =
+ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{};
+
+ constexpr auto CLayout = blockwise_gemm.GetCLayout();
+
+ constexpr index_t BlkSize = CLayout.GetBlkSize();
+
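// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] LDS budget arithmetic for this kernel,
// with hypothetical tile sizes KPerBlock = 8, MPerBlock = NPerBlock = 128,
// K1 = 4 and a 2-byte FloatAB:
//
//     A tile: 8 * 128 * 4 = 4096 elements, B tile likewise 4096 elements
//     GetSharedMemoryNumberOfByte() = (4096 + 4096) * 2 B = 16 KiB
//
// math::integer_least_multiple(n, a) only rounds n up to the next multiple
// of a (e.g. integer_least_multiple(4100, 8) == 4104), so the padding it
// adds keeps each tile aligned to K1 elements in LDS.
// ---------------------------------------------------------------------------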
constexpr index_t NumBlks = CLayout.GetNumBlks(); + constexpr index_t NumXdlops = CLayout.GetNumXdlops(); + + constexpr auto c_mr_nr_nx_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + constexpr auto c_blk_nb_bs_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + StaticBuffer, + c_mr_nr_nx_desc.GetElementSpaceSize()> + c_thread_buf; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_iterator_hack); + + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number<1>{}, Number{}, Number<1>{})); + + StaticBuffer c_blk_buf_; + + static_for<0, MRepeat, 1>{}([&](auto mr_i) { + static_for<0, NRepeat, 1>{}([&](auto nr_i) { + static_for<0, NumXdlops, 
1>{}([&](auto xdlops_i) { + static_for<0, NumBlks, 1>{}([&](auto blk_i) { + auto c_blk = c_thread_buf[Number{}]; + + static_for<0, BlkSize, 1>{}([&](auto j) { + c_blk_buf_(j) = c_blk.template AsType()[Number< + c_blk_nb_bs_desc.CalculateOffset(make_tuple(blk_i, j))>{}]; + }); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex( + mr_i, nr_i, xdlops_i, blk_i); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = + CGridIteratorHacks{}; + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m1_m2_n_grid_desc, + make_multi_index(m_thread_data_on_grid / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + m_thread_data_on_grid % M2, + n_thread_data_on_grid)} + .Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }); + }); + }); + }); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp new file mode 100644 index 0000000000..d15ec86800 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops_v2r3.hpp @@ -0,0 +1,777 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_XDLOPS_V2R3_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0MK1GridDesc a_k0_m_k1_grid_desc, + const BK0NK1GridDesc b_k0_n_k1_grid_desc, + const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_m1_m2_n_grid_desc, + c_block_cluster_adaptor); +} + +template +struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + 
{
+ constexpr auto max_lds_align = K1;
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // LDS allocation for A and B: be careful of alignment
+ constexpr auto a_block_space_size =
+ math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+ constexpr auto b_block_space_size =
+ math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);
+
+ return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+ }
+
+ __host__ __device__ static constexpr bool
+ CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
+ const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+ const CMNGridDesc& c_m_n_grid_desc)
+ {
+ // TODO: turn on this
+ static_assert(is_known_at_compile_time>::value,
+ "wrong! K1 needs to be known at compile-time");
+
+ const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
+ const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
+ const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+
+ // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+
+ return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
+ K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
+ K1 == a_k0_m_k1_grid_desc.GetLength(I2) &&
+ K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
+ (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) &&
+ (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0);
+ }
+
+ __host__ __device__ static constexpr index_t
+ CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
+
+ return grid_size;
+ }
+
+ __host__ __device__ static constexpr auto
+ MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ constexpr auto xdlops_gemm = XdlopsGemm{};
+
+ constexpr auto CLayout = xdlops_gemm.GetCLayout();
+
+ constexpr auto M0 = Number{};
+ constexpr auto M1 = Number{};
+ constexpr auto M2 = Number{};
+
+ constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat);
+ constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat);
+
+ constexpr auto N0 = Number{};
+ constexpr auto N1 = Number{};
+
+ const auto c_m0_m1_m2_n_grid_desc = transform_dynamic_tensor_descriptor(
+ c_m_n_grid_desc,
+ make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)),
+ make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))),
+ make_tuple(Sequence<0>{}, Sequence<1>{}),
+ make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{}));
+
+ return c_m0_m1_m2_n_grid_desc;
+ }
+
+ __host__ __device__ static constexpr auto
+ MakeCBlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
+ {
+ const auto M = c_m_n_grid_desc.GetLength(I0);
+ const auto N = c_m_n_grid_desc.GetLength(I1);
+
+ constexpr auto M1 = Number{};
+ constexpr auto N1 = Number{};
+
+ const auto M0 = M / M1;
+ const auto N0 = N / N1;
+
+#if 1
+ const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+
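// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] The merge transform built here folds
// the (M0, N0) tile grid into a single block id, so CalculateBottomIndex
// performs the row-major decode
//     m0 = block_id / N0;    n0 = block_id % N0;
// The disabled #elif alternative just below merges (N0, M0) instead, i.e.
//     n0 = block_id / M0;    m0 = block_id % M0;
// which changes which C tiles neighbouring workgroups own, and thereby the
// cache locality of the output writes.
// ---------------------------------------------------------------------------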
make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
+ make_tuple(Sequence<0, 1>{}),
+ make_tuple(Sequence<0>{}));
+#elif 1
+ const auto c_blockid_to_m0_n0_block_cluster_adaptor =
+ make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))),
+ make_tuple(Sequence<1, 0>{}),
+ make_tuple(Sequence<0>{}));
+#endif
+
+ return c_blockid_to_m0_n0_block_cluster_adaptor;
+ }
+
+ using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{}));
+ using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}));
+
+ __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
+ const FloatAB* __restrict__ p_b_grid,
+ FloatC* __restrict__ p_c_grid,
+ FloatAB* __restrict__ p_shared_block,
+ const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
+ const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
+ const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc,
+ const CBlockClusterAdaptor& c_block_cluster_adaptor)
+ {
+ constexpr auto I0 = Number<0>{};
+ constexpr auto I1 = Number<1>{};
+ constexpr auto I2 = Number<2>{};
+ constexpr auto I3 = Number<3>{};
+
+ const auto a_grid_buf = make_dynamic_buffer(
+ p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
+ const auto b_grid_buf = make_dynamic_buffer(
+ p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
+ auto c_grid_buf = make_dynamic_buffer(
+ p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());
+
+ const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
+ const auto M = a_k0_m_k1_grid_desc.GetLength(I1);
+ const auto N = b_k0_n_k1_grid_desc.GetLength(I1);
+
+ // divide block work by [M, N]
+ const auto block_work_idx =
+ c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+
+ // HACK: this forces m/n_block_data_idx_on_grid into SGPR
+ const index_t m_block_data_idx_on_grid =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+
+ const index_t n_block_data_idx_on_grid =
+ __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+
+ // lds max alignment
+ constexpr auto max_lds_align = K1;
+
+ // A matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // B matrix in LDS memory, dst of blockwise copy
+ // be careful of LDS alignment
+ constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+ make_tuple(Number{}, Number{}, K1), max_lds_align);
+
+ // A matrix blockwise copy
+ auto a_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ ABlockTransferThreadSliceLengths_K0_M_K1,
+ ABlockTransferThreadClusterLengths_K0_M_K1,
+ ABlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(a_k0_m_k1_grid_desc),
+ decltype(a_k0_m_k1_block_desc),
+ ABlockTransferSrcAccessOrder,
+ Sequence<1, 0, 2>,
+ ABlockTransferSrcVectorDim,
+ 2,
+ ABlockTransferSrcScalarPerVector,
+ ABlockTransferDstScalarPerVector_K1,
+ 1,
+ 1,
+ AThreadTransferSrcResetCoordinateAfterRun,
+ true>(
+ a_k0_m_k1_grid_desc,
+ make_multi_index(0, m_block_data_idx_on_grid, 0),
+ a_k0_m_k1_block_desc,
+ make_multi_index(0, 0, 0));
+
+ // B matrix blockwise copy
+ auto b_blockwise_copy =
+ BlockwiseDynamicTensorSliceTransfer_v4,
+ BBlockTransferThreadSliceLengths_K0_N_K1,
+ BBlockTransferThreadClusterLengths_K0_N_K1,
+ BBlockTransferThreadClusterArrangeOrder,
+ FloatAB,
+ FloatAB,
+ decltype(b_k0_n_k1_grid_desc),
+ decltype(b_k0_n_k1_block_desc),
+
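// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] SrcVectorDim together with
// SrcScalarPerVector select the dimension and width of each thread's global
// load in these copies; e.g. ScalarPerVector == 4 on 4-byte elements forms
// 16-byte accesses, which on AMD GCN hardware typically lower to
// global_load_dwordx4 instructions, provided the source tensor is contiguous
// and suitably aligned along the chosen dimension.
// ---------------------------------------------------------------------------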
BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k0_n_k1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && + NPerBlock % (NPerWave * NRepeat) == 0, + "wrong!"); + + constexpr auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( + a_k0_m_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + constexpr auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( + b_k0_n_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto blockwise_gemm = + BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; + + constexpr auto CLayout = blockwise_gemm.GetCLayout(); + + constexpr index_t BlkSize = CLayout.GetBlkSize(); + constexpr index_t NumBlks = CLayout.GetNumBlks(); + constexpr index_t NumXdlops = CLayout.GetNumXdlops(); + + static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); + + constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{})); + + StaticBuffer, + c_mr_nr_blk_desc.GetElementSpaceSize()> + c_thread_buf; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + b_blockwise_copy.RunRead( + 
b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_iterator_hack); + + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_iterator_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_iterator_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + +#if 0 + // output: register to global memory + { + constexpr index_t M0 = CLayout.M1(); + constexpr index_t M1 = CLayout.N1(); + constexpr index_t M2 = CLayout.M0(); + + constexpr index_t N0 = CLayout.N1(); + constexpr index_t N1 = CLayout.N0(); + + constexpr auto c_m0_m1_m2_n_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, + Number{}, + Number<1>{}, + Number<1>{}, + Number{}, + Number<1>{}, + Number{}, + Number<1>{})); + + StaticBuffer + c_blk_buf_; + + static_for<0, MRepeat, 1>{}([&](auto mr_i) { + static_for<0, NRepeat, 1>{}([&](auto nr_i) { + constexpr auto blk_off = + c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); + + static_for<0, BlkSize, 1>{}([&](auto j) { + c_blk_buf_(Number{}) = + c_thread_buf[Number{}] + .template AsType()[Number{}]; + }); + }); + }); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{}; + + constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); + constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatC, + FloatC, + decltype(c_m0_m1_m2_n_thread_desc), + decltype(c_m0_m1_m2_n_grid_desc), + Sequence, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_m0_m1_m2_n_grid_desc, + make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), + n_thread_data_on_grid / (N1 * NWaves), + m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), + n_thread_data_on_grid % (N1 * NWaves) / N1, + m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), + m_thread_data_on_grid % (M2 * M1) / M2, + m_thread_data_on_grid % M2, + n_thread_data_on_grid % N1)} + .Run(c_m0_m1_m2_n_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_blk_buf_, + c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + } +#else + { + constexpr index_t M0 = CLayout.M1(); + constexpr 
index_t M1 = CLayout.N1();
+ constexpr index_t M2 = CLayout.M0();
+
+ constexpr auto c_m0_m1_m2_n_thread_desc =
+ make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
+ I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{}));
+
+ StaticBuffer c_blk_buf_;
+
+ // calculate origin of thread output tensor on global memory
+ // blockwise GEMM c matrix starting index
+ const auto c_thread_mtx_on_block =
+ blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
+
+ const index_t m_thread_data_on_grid =
+ m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
+
+ const index_t n_thread_data_on_grid =
+ n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
+
+ constexpr auto c_m0_m1_m2_n_grid_tensor_iterator_hacks = CGridIteratorHacks{};
+
+ auto c_thread_copy =
+ ThreadwiseDynamicTensorSliceTransfer_v1r3,
+ CThreadTransferSrcDstAccessOrder,
+ CThreadTransferSrcDstVectorDim,
+ CThreadTransferDstScalarPerVector,
+ CGlobalMemoryDataOperation,
+ 1,
+ true>{
+ c_m0_m1_m2_n_grid_desc,
+ make_multi_index(0,
+ 0,
+ 0,
+ 0,
+ m_thread_data_on_grid / (M2 * M1),
+ m_thread_data_on_grid % (M2 * M1) / M2,
+ m_thread_data_on_grid % M2,
+ n_thread_data_on_grid)};
+
+ auto init_copy = [&](auto c_thread_idx_) {
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number{}].template AsType(),
+ c_m0_m1_m2_n_grid_desc,
+ c_grid_buf,
+ c_m0_m1_m2_n_grid_tensor_iterator_hacks);
+
+ return c_thread_idx_;
+ };
+
+ auto mrepeat_plus_copy = [&](auto c_thread_idx_) {
+ constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus);
+
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number{}].template AsType(),
+ c_m0_m1_m2_n_grid_desc,
+ c_grid_buf,
+ c_m0_m1_m2_n_grid_tensor_iterator_hacks);
+ };
+
+ auto nrepeat_plus_copy = [&](auto c_thread_idx_) {
+ constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus);
+
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number{}].template AsType(),
+ c_m0_m1_m2_n_grid_desc,
+ c_grid_buf,
+ c_m0_m1_m2_n_grid_tensor_iterator_hacks);
+ };
+
+ auto mrepeat_minus_copy = [&](auto c_thread_idx_) {
+ constexpr auto mrepeat_step_minus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_minus);
+
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number{}].template AsType(),
+ c_m0_m1_m2_n_grid_desc,
+ c_grid_buf,
+ c_m0_m1_m2_n_grid_tensor_iterator_hacks);
+ };
+
+ auto nrepeat_minus_copy = [&](auto c_thread_idx_) {
+ constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0);
+ c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus);
+
+ constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_);
+ c_thread_copy.Run(c_m0_m1_m2_n_thread_desc,
+ make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
+ c_thread_buf[Number{}].template AsType(),
+
c_m0_m1_m2_n_grid_desc, + c_grid_buf, + c_m0_m1_m2_n_grid_tensor_iterator_hacks); + }; + + static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or + (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or + (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or + (MRepeat == 1 && NRepeat == 1), + "wrong"); + + if constexpr(MRepeat == 4 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + nrepeat_plus_copy(make_tuple(I2, I2)); + nrepeat_plus_copy(make_tuple(I2, I3)); + mrepeat_plus_copy(make_tuple(I3, I3)); + nrepeat_minus_copy(make_tuple(I3, I2)); + nrepeat_minus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + mrepeat_plus_copy(make_tuple(I2, I2)); + mrepeat_plus_copy(make_tuple(I3, I2)); + nrepeat_plus_copy(make_tuple(I3, I3)); + mrepeat_minus_copy(make_tuple(I2, I3)); + mrepeat_minus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 4 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + nrepeat_plus_copy(make_tuple(I2, I1)); + mrepeat_plus_copy(make_tuple(I3, I1)); + nrepeat_minus_copy(make_tuple(I3, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + mrepeat_plus_copy(make_tuple(I2, I0)); + mrepeat_plus_copy(make_tuple(I3, I0)); + nrepeat_plus_copy(make_tuple(I3, I1)); + mrepeat_minus_copy(make_tuple(I2, I1)); + mrepeat_minus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 4) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + nrepeat_plus_copy(make_tuple(I0, I3)); + mrepeat_plus_copy(make_tuple(I1, I3)); + nrepeat_minus_copy(make_tuple(I1, I2)); + nrepeat_minus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + nrepeat_plus_copy(make_tuple(I0, I2)); + mrepeat_plus_copy(make_tuple(I1, I2)); + nrepeat_plus_copy(make_tuple(I1, I3)); + mrepeat_minus_copy(make_tuple(I0, I3)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + + if constexpr(CAccessOrderMRepeatNRepeat) + { + nrepeat_plus_copy(make_tuple(I0, I1)); + mrepeat_plus_copy(make_tuple(I1, I1)); + nrepeat_minus_copy(make_tuple(I1, I0)); + } + else + { + mrepeat_plus_copy(make_tuple(I1, I0)); + 
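// ---------------------------------------------------------------------------
// [Editor's note, illustrative only.] The copy schedule above walks the
// MRepeat x NRepeat grid in boustrophedon (zig-zag) fashion; e.g. for
// MRepeat == NRepeat == 4 with CAccessOrderMRepeatNRepeat:
//
//     (0,0)->(0,1)->(0,2)->(0,3)
//     (1,0)<-(1,1)<-(1,2)<-(1,3)
//     (2,0)->(2,1)->(2,2)->(2,3)
//     (3,0)<-(3,1)<-(3,2)<-(3,3)
//
// Consecutive repeats always differ by a single +/-1 step in one dimension,
// so MoveDstSliceWindow can update the destination coordinate incrementally
// instead of recomputing it from the thread origin for every repeat.
// ---------------------------------------------------------------------------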
nrepeat_plus_copy(make_tuple(I1, I1)); + mrepeat_minus_copy(make_tuple(I0, I1)); + } + } + else if constexpr(MRepeat == 2 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + mrepeat_plus_copy(make_tuple(I1, I0)); + } + else if constexpr(MRepeat == 1 && NRepeat == 2) + { + init_copy(make_tuple(I0, I0)); + nrepeat_plus_copy(make_tuple(I0, I1)); + } + else if constexpr(MRepeat == 1 && NRepeat == 1) + { + init_copy(make_tuple(I0, I0)); + } + } +#endif + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp index 530b66b0e3..4253431d5e 100644 --- a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp @@ -101,9 +101,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3 static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer data type is wrong"); + // static_assert(is_same>, + // remove_cv_t>>::value, + //"wrong! SrcBuffer data type is wrong"); // SrcDesc and src_slice_origin_idx are known at compile-time constexpr auto src_desc = remove_cv_t>{}; @@ -1407,7 +1407,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4 constexpr auto data_to_origin_disp_idx = ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; #endif - // src coordinate constexpr auto src_ref_to_data_disp_idx = src_ref_to_origin_disp_idx + data_to_origin_disp_idx; diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp new file mode 100644 index 0000000000..5fbc22c807 --- /dev/null +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -0,0 +1,802 @@ +#ifndef CK_XDLOPS_GEMM_HPP +#define CK_XDLOPS_GEMM_HPP + +#include "common_header.hpp" +#include "ConstantMatrixDescriptor.hpp" +#include "math.hpp" +#include "amd_xdlops.hpp" + +namespace ck { + +enum struct mfma_instr +{ + /// fp32 + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, // k reduction + mfma_f32_16x16x4xf32, // k reduction + /// fp16 + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + mfma_f32_32x32x8f16, // k reduction + mfma_f32_16x16x16f16, // k reduction + /// bfp16 + mfma_f32_32x32x2bf16, + mfma_f32_16x16x2bf16, + mfma_f32_4x4x2bf16, + mfma_f32_32x32x4bf16, // k reduction + mfma_f32_16x16x8bf16, // k reduction +}; + +template +struct mfma_info; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 1; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct 
mfma_info<mfma_instr::mfma_f32_32x32x2xf32> +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 1; + + template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset, class FloatA, class FloatB, class FloatC> + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x2f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info<mfma_instr::mfma_f32_16x16x4xf32> +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset, class FloatA, class FloatB, class FloatC> + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info<mfma_instr::mfma_f32_16x16x1xf32> +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 1; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 1; + + template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset, class FloatA, class FloatB, class FloatC> + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); + } +}; + +// treat 4x4x1 as a single-blk 4x64 mfma +template <> +struct mfma_info<mfma_instr::mfma_f32_4x4x1xf32> +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 1; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 1; + + template <index_t MPerXdlops, index_t NPerXdlops, index_t COffset, class FloatA, class FloatB, class FloatC> + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x1f32<MPerXdlops, NPerXdlops, COffset>::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info<mfma_instr::mfma_f32_32x32x4f16> +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr 
index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 8; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 16; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 4; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 4; + static constexpr index_t cycles = 8; + 
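// k_base = 4 for the fp16 variants: each lane feeds one half4_t (four fp16
// values) per mfma issue, so XdlopsGemm::Run below walks KPack in
// k_base-sized steps and static_asserts KPack % k_base == 0.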
static constexpr index_t k_base = 4; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + } +}; + +#if 0 +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 2; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 2; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_32x32x2bf16::run( + p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 4; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 32; + static constexpr index_t n = 32; + static constexpr index_t k = 4; + static constexpr index_t cycles = 64; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 8; + static constexpr index_t cycles = 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = wave_size / num_threads_blk; + static constexpr index_t num_output_blks = 4; + static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; + static constexpr index_t m = 16; + static constexpr index_t n = 16; + static constexpr index_t k = 2; + static constexpr index_t cycles 
= 32; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_16x16x2bf16(p_a, p_b, reg_c); + } +}; + +template <> +struct mfma_info +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_blk = 1; + static constexpr index_t num_regs_blk = group_size * num_groups_blk; + static constexpr index_t num_threads_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t num_regs_xdlops = 4; + static constexpr index_t m = 4; + static constexpr index_t n = 64; + static constexpr index_t k = 2; + static constexpr index_t cycles = 8; + static constexpr index_t k_base = 2; + + template + __device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const + { + const auto p_a = reinterpret_cast(a); + const auto p_b = reinterpret_cast(b); + + return intrin_mfma_f32_4x4x2bf16::run(p_a, p_b, reg_c); + } +}; +#endif + +template +struct xdlops_info +{ + static constexpr auto mfma_type = mfma_info{}; + + static constexpr index_t MPerXdlops = MPerXdlops_; + static constexpr index_t NPerXdlops = NPerXdlops_; + + static constexpr bool IsABroadcast() + { + static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); + return true; + } + + static constexpr bool IsKReduction() + { + return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); + } + + static constexpr index_t GetKPerXdlops() + { + return IsKReduction() ? mfma_type.num_input_blks : 1; + } + + static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } +}; + +template +struct XdlopsGemm +{ + template + static constexpr auto GetXdlopsInfo(); + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + +#if 0 + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + 
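// (Each GetXdlopsInfo() specialisation pairs one mfma_info with the
// MRepeats/NRepeats and MPerXdlops/NPerXdlops that tile the requested
// wave-level GEMM shape; IsKReduction() above holds exactly for the
// "k reduction" entries of the enum, i.e. num_output_blks == 1 with
// num_input_blks > 1.)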
return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetXdlopsInfo() + { + return xdlops_info{}; + } +#endif + + using CIndex = MultiIndex<2>; + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + + __host__ __device__ constexpr XdlopsGemm() + { + static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || + NPerXdlops == 64, + "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || + MPerXdlops == 64, + "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); + static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, + "m != num_input_blks * num_regs_blk"); + static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || + mfma_type.num_output_blks == 1, + "incorrect num_output_blks"); + static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, + "num_regs_blk incorrect"); + + static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / mfma_type.wave_size; + } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "base base_type must be float, half, ushort!"); + + static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); + + constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); + + static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { + constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); + constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); + + mfma_type.template run( + p_a_wave[Number{}], + p_b_wave[Number{}], + p_c_thread); + }); + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; + const index_t blk_id = laneId / mfma_type.num_threads_blk; + const index_t blk_td = laneId % mfma_type.num_threads_blk; + + index_t n_offset = blk_i * mfma_type.n + blk_td; + index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + + return CIndex{m_offset, n_offset}; + } + + static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; + static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; + static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; + static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + + static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); + static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); + static constexpr index_t 
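/* GetBeginOfThreadBlk above in plain numbers: a lane first resolves
   blk_id = lane / num_threads_blk and blk_td = lane % num_threads_blk, then
   owns the group_size-tall strip starting at
   m = xdlops_i * m + blk_id * group_size, n = blk_i * n + blk_td.
   For the 32x32x1 tile (num_threads_blk = 32, group_size = 4), lane 37 of a
   wave maps to blk_id = 1, blk_td = 5, i.e. C origin (m, n) = (4, 5) when
   xdlops_i = blk_i = 0. */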
KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + + static constexpr auto GetBlkId(const index_t lane_id) + { + return lane_id / mfma_type.num_threads_blk; + } + + static constexpr auto GetBlkTd(const index_t lane_id) + { + return lane_id % mfma_type.num_threads_blk; + } + + static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; + + struct CLayout + { + __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } + __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } + __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } + __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } + + __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } + + __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + } + }; + + __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp index da84b6ca7c..63d0e0529b 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing_v2.hpp @@ -268,6 +268,7 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, } else if constexpr(N == 8) { +#if 0 vector_type tmp; tmp.AsType()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4( @@ -280,6 +281,12 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource, 0); return tmp.AsType()(Number<0>{}); +#else + float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); +#endif } } else if constexpr(is_same::value) diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp new file mode 100644 index 0000000000..b373e27be3 --- /dev/null +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -0,0 +1,499 @@ +#ifndef CK_AMD_XDLOPS_HPP +#define CK_AMD_XDLOPS_HPP + +#include "float_type.hpp" + +namespace ck { + +// A, B, C, cbsz, abid, blgp +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16"); + +extern "C" __device__ float4_t 
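/* Note on the amd_buffer_addressing_v2 change above: the N == 8 fp16 path
   now issues a single 16-byte fp32x4 raw buffer load and bit-reinterprets
   it as eight halves (one buffer_load_dwordx4 instead of two fp16x4 loads).
   The portable idiom behind as_type<> is a size-checked memcpy between the
   two vector types, which compilers lower to a no-op register reinterpret:

       half8_t r;
       static_assert(sizeof(r) == sizeof(float4_t), "same byte width");
       __builtin_memcpy(&r, &tmp, sizeof(r));
       return r;
*/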
llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16"); + +extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16( + ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16"); + +extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( + ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16"); + +extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( + ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); + +template +struct intrin_mfma_f32_32x32x1f32; + +template +struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template +struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template +struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template +struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template +struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +{ + template + __device__ static 
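/* The last three integer arguments of every mfma intrinsic are cbsz, abid
   and blgp. cbsz/abid broadcast one block of the A operand across input
   blocks: the 64x64 macro-tile specialisation of 32x32x1f32 above issues
   the instruction twice with abid = 0 and abid = 1 to fill its two output
   blocks, the 32x64 case issues it once, and 16x16x1f32 uses cbsz = 2 to
   broadcast across its four blocks. blgp applies a lane-group permute to
   the B operand. */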
void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4f16; + +template +struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 1, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 1, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template +struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template +struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template +struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 2, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template +struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c(Number{}).template AsType()(Number<0>{}) = + 
llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 0, + 0); + reg_c(Number{}).template AsType()(Number<0>{}) = + llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, + reg_b, + reg_c[Number{}].template AsType()[Number<0>{}], + 4, + 1, + 0); + } +}; + +#if 0 +template +struct intrin_mfma_f32_32x32x2bf16; + +template +struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride> +{ + __device__ static c_vec32_4_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + reg_c.s.z = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0); + reg_c.s.w = + llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride> +{ + __device__ static c_vec32_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1); + + return reg_c; + } +}; + +template +struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride> +{ + __device__ static c_vec32_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0); + return reg_c; + } +}; + +__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + +__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec4_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0); + return reg_c; +} + +template +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c); + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = 
llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0); + return reg_c; +} + +template <> +__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a, + const ushort2_t* reg_b, + c_vec16_1_t::VecType reg_c) +{ + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4); + return reg_c; +} + +template +struct intrin_mfma_f32_4x4x2bf16; + +template <> +struct intrin_mfma_f32_4x4x2bf16<4, 64> +{ + __device__ static c_vec4_1_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + return reg_c; + } +}; + +template <> +struct intrin_mfma_f32_4x4x2bf16<8, 64> +{ + __device__ static c_vec4_2_t::VecType + run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c) + { + reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0); + reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0); + return reg_c; + } +}; + +#endif + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 609ae2b212..fd602e7f96 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -18,7 +18,7 @@ #define CK_AMD_GPU_GFX906 1 #elif 1 #define CK_AMD_GPU_GFX908 1 -#elif 1 +#elif 0 #define CK_AMD_GPU_GFX1030 1 #endif @@ -28,7 +28,7 @@ #endif // launch bounds -#define CK_USE_LAUNCH_BOUNDS 0 +#define CK_USE_LAUNCH_BOUNDS 1 #ifdef CK_USE_LAUNCH_BOUNDS #define CK_MAX_THREAD_PER_BLOCK 256 @@ -116,7 +116,7 @@ #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 // merge transformation use magic number division -#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 // hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. 
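The config.amd.hpp.in hunk above also flips CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION to 1, so Merge-transform index math divides by tensor lengths via a precomputed "magic" multiplier instead of a hardware integer divide. As a hedged host-side sketch of the usual mulhi-plus-shift scheme (a generic illustration with hypothetical names make_magic/magic_div, not necessarily CK's exact formulation):

#include <cassert>
#include <cstdint>

// Precompute (multiplier, shift) once per divisor; n / d then becomes a
// multiply-high, an add and a shift, which is cheap in per-thread index code.
struct MagicDiv
{
    uint32_t multiplier;
    uint32_t shift;
};

MagicDiv make_magic(uint32_t d) // valid for 1 <= d <= 2^31
{
    uint32_t shift = 0;
    while((uint32_t(1) << shift) < d)
        ++shift;
    const uint64_t mult = ((uint64_t(1) << 32) * ((uint64_t(1) << shift) - d)) / d + 1;
    return MagicDiv{uint32_t(mult), shift};
}

uint32_t magic_div(uint32_t n, MagicDiv m)
{
    const uint32_t hi = uint32_t((uint64_t(n) * m.multiplier) >> 32); // mulhi(n, multiplier)
    return uint32_t((uint64_t(hi) + n) >> m.shift);
}

int main()
{
    const uint32_t divisors[] = {1u, 3u, 7u, 28u, 1920u};
    for(uint32_t d : divisors)
        for(uint32_t n = 0; n < 200000u; ++n)
            assert(magic_div(n, make_magic(d)) == n / d);
}

On the device the multiply-high maps to a single v_mul_hi_u32, which is why this pays off inside the merge/unmerge coordinate updates.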
idx_diff_low_const must be diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp index 74cd600cae..2ff0c46e6d 100644 --- a/composable_kernel/include/utility/container_helper.hpp +++ b/composable_kernel/include/utility/container_helper.hpp @@ -174,8 +174,15 @@ __host__ __device__ constexpr auto container_reduce(const Container& x, { static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); - return container_reduce_impl( - x, reduce, init, Number{}, Number{}, Number{}); + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); + } + else + { + return init; + } } #endif diff --git a/composable_kernel/include/utility/float_type.amd.hpp.in b/composable_kernel/include/utility/float_type.amd.hpp.in index 44cf657cb1..f41bd6db23 100644 --- a/composable_kernel/include/utility/float_type.amd.hpp.in +++ b/composable_kernel/include/utility/float_type.amd.hpp.in @@ -618,6 +618,252 @@ struct vector_type } }; +template +struct vector_type +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + } +}; + +template +struct vector_type +{ + 
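/* Same pattern as the 128-lane specialisation above: the union below
   aliases one ext_vector_type(256) allocation with StaticallyIndexedArray
   views from 1x256 down to 256x1, and AsType<X>() hands back the view whose
   element type is X, so a register block can be read or written at any
   power-of-two vector width. */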
using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "wrong!"); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + } +}; + // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 11e87eca4c..368a955ab3 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -9,25 +9,25 @@ namespace ck { namespace math { -template +template struct scales { __host__ __device__ constexpr T operator()(T a) const { return s * a; } }; -template +template struct plus { __host__ __device__ constexpr T operator()(T a, T b) const { return a + b; } }; -template +template struct minus { __host__ __device__ constexpr T operator()(T a, T b) const { return a - b; } }; -template +template struct multiplies { __host__ __device__ constexpr T operator()(T a, T b) const { return a 
* b; } @@ -42,83 +42,111 @@ struct multiplies_v2 } }; -template +template struct maximize { __host__ __device__ constexpr T operator()(T a, T b) const { return a >= b ? a : b; } }; -template +template struct minimize { __host__ __device__ constexpr T operator()(T a, T b) const { return a <= b ? a : b; } }; -template +template struct integer_divide_ceiler { __host__ __device__ constexpr T operator()(T a, T b) const { static_assert(is_same{} || is_same{}, "wrong type"); - return (a + b - 1) / b; + return (a + b - Number<1>{}) / b; } }; -template +template __host__ __device__ constexpr auto integer_divide_floor(X x, Y y) { return x / y; } -template +template __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) { return (x + y - Number<1>{}) / y; } -template +template __host__ __device__ constexpr auto integer_least_multiple(X x, Y y) { return y * integer_divide_ceil(x, y); } -template +template __host__ __device__ constexpr T max(T x) { return x; } -template -__host__ __device__ constexpr T max(T x, Ts... xs) +template +__host__ __device__ constexpr T max(T x, T y) { - static_assert(sizeof...(xs) > 0, "not enough argument"); - - auto y = max(xs...); - - static_assert(is_same{}, "not the same type"); - return x > y ? x : y; } -template +template +__host__ __device__ constexpr index_t max(Number, index_t y) +{ + return X > y ? X : y; +} + +template +__host__ __device__ constexpr index_t max(index_t x, Number) +{ + return x > Y ? x : Y; +} + +template +__host__ __device__ constexpr auto max(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return max(x, max(ys...)); +} + +template __host__ __device__ constexpr T min(T x) { return x; } -template -__host__ __device__ constexpr T min(T x, Ts... xs) +template +__host__ __device__ constexpr T min(T x, T y) { - static_assert(sizeof...(xs) > 0, "not enough argument"); - - auto y = min(xs...); - - static_assert(is_same{}, "not the same type"); - return x < y ? x : y; } +template +__host__ __device__ constexpr index_t min(Number, index_t y) +{ + return X < y ? X : y; +} + +template +__host__ __device__ constexpr index_t min(index_t x, Number) +{ + return x < Y ? x : Y; +} + +template +__host__ __device__ constexpr auto min(X x, Ys... ys) +{ + static_assert(sizeof...(Ys) > 0, "not enough argument"); + + return min(x, min(ys...)); +} + // greatest common divisor, aka highest common factor __host__ __device__ constexpr index_t gcd(index_t x, index_t y) { @@ -171,13 +199,13 @@ __host__ __device__ constexpr auto lcm(X x, Ys... 
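/* The max/min rework above replaces the homogeneous variadic overloads,
   which static_asserted that every argument had the same type, with mixed
   Number<X>/index_t overloads decaying to a runtime index_t; e.g.
   max(Number<3>{}, some_runtime_index) now compiles, folding the constant
   side, and the variadic form reduces pairwise via max(x, max(ys...)). */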
ys) return lcm(x, lcm(ys...)); } -template +template struct equal { __host__ __device__ constexpr bool operator()(T x, T y) const { return x == y; } }; -template +template struct less { __host__ __device__ constexpr bool operator()(T x, T y) const { return x < y; } diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index f8b8bb62d4..15b73011b4 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -153,6 +153,8 @@ struct Tuple : detail::TupleImpl diff --git a/driver/conv_bwd_data_driver.cpp b/driver/conv_bwd_data_driver.cpp index cdb2526c75..63723f5f4f 100644 --- a/driver/conv_bwd_data_driver.cpp +++ b/driver/conv_bwd_data_driver.cpp @@ -19,7 +19,22 @@ int main(int argc, char* argv[]) { using namespace launcher; -#if 0 +#if 1 + // 1x1 filter, 14x14 image + constexpr index_t N = 1; + constexpr index_t C = 256; + constexpr index_t HI = 1; + constexpr index_t WI = 128; + constexpr index_t K = 16; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using LeftPads = Sequence<0, 0>; + using RightPads = Sequence<0, 0>; +#elif 0 constexpr index_t N = 64; constexpr index_t C = 256; constexpr index_t HI = 56; @@ -93,7 +108,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<0, 0>; using RightPads = Sequence<0, 0>; -#elif 0 +#elif 1 // 1x1 filter, 14x14 image constexpr index_t N = 128; constexpr index_t C = 512; @@ -153,7 +168,7 @@ int main(int argc, char* argv[]) using LeftPads = Sequence<2, 2>; using RightPads = Sequence<2, 2>; -#elif 1 +#elif 0 // 1x7 filter, 0x3 pad, 17x17 input constexpr index_t N = 128; constexpr index_t C = 128; @@ -245,7 +260,7 @@ int main(int argc, char* argv[]) device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw #elif 0 device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw -#elif 1 +#elif 0 device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw #elif 1 device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk diff --git a/driver/conv_bwd_data_driver_v2.cpp b/driver/conv_bwd_data_driver_v2.cpp new file mode 100644 index 0000000000..3c271a37ad --- /dev/null +++ b/driver/conv_bwd_data_driver_v2.cpp @@ -0,0 +1,345 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_data.hpp" +#include "device_tensor.hpp" +#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 + +enum ConvBackwardDataAlgo +{ + V4R1XDLNHWC, + V4R1R2XDLNHWC, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const 
ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + const index_t N = atoi(argv[7]); + const index_t K = atoi(argv[8]); + const index_t C = atoi(argv[9]); + const index_t Y = atoi(argv[10]); + const index_t X = atoi(argv[11]); + const index_t Hi = atoi(argv[12]); + const index_t Wi = atoi(argv[13]); + + const index_t conv_stride_h = atoi(argv[14]); + const index_t conv_stride_w = atoi(argv[15]); + const index_t conv_dilation_h = atoi(argv[16]); + const index_t conv_dilation_w = atoi(argv[17]); + const index_t in_left_pad_h = atoi(argv[18]); + const index_t in_left_pad_w = atoi(argv[19]); + const index_t in_right_pad_h = atoi(argv[20]); + const index_t in_right_pad_w = atoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvBackwardDataAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 1; + const index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 1 + constexpr index_t in_vector_size = 1; + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + constexpr index_t in_vector_size = 1; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + switch(layout) + { + case ConvTensorLayout::NCHW: + // NCHW + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + break; + case ConvTensorLayout::NHWC: + // NHWC + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = 
static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + break; + default: throw std::runtime_error("wrong! not implemented"); + } + + Tensor in_host(in_lengths_host); + Tensor in_device(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in_host.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(do_verification) + { + switch(init_method) + { + case 0: + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 1: + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + } + } + + auto f_make_for_device_nchw = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + auto f_make_for_device_nhwc = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto 
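/* USE_DYNAMIC_MODE picks between the two branches of these lambdas:
   runtime lengths (plain make_tuple of index_t values, so one binary serves
   every problem size) versus compile-time lengths (make_tuple of Number<>
   constants, letting descriptor index arithmetic fold at compile time for
   one fixed shape). */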
in_lengths_dev = + make_tuple(Number<N>{}, Number<Hi>{}, Number<Wi>{}, Number<C>{}); + const auto wei_lengths_dev = make_tuple(Number<K>{}, Number<Y>{}, Number<X>{}, Number<C>{}); + const auto out_lengths_dev = + make_tuple(Number<N>{}, Number<Ho>{}, Number<Wo>{}, Number<K>{}); + const auto conv_strides_dev = make_tuple(Number<conv_stride_h>{}, Number<conv_stride_w>{}); + const auto conv_dilations_dev = + make_tuple(Number<conv_dilation_h>{}, Number<conv_dilation_w>{}); + const auto in_left_pads_dev = make_tuple(Number<in_left_pad_h>{}, Number<in_left_pad_w>{}); + const auto in_right_pads_dev = + make_tuple(Number<in_right_pad_h>{}, Number<in_right_pad_w>{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + const auto nhwc_desc = f_make_for_device_nhwc(); + +#if USE_CONV_BWD_V4R1_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + +#if USE_CONV_BWD_V4R1R2_XDL_NHWC + if(algo == ConvBackwardDataAlgo::V4R1R2XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_data(in_host, + wei, + out, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(in_host, in_device); + + if(do_log) + { + LogRangeAsType<float>(std::cout << "out : ", out.mData, ",") << std::endl; + LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType<float>(std::cout << "in_host : ", in_host.mData, ",") << std::endl; + LogRangeAsType<float>(std::cout << "in_device: ", in_device.mData, ",") << std::endl; + } + } +} diff --git a/driver/conv_driver.cpp b/driver/conv_driver.cpp index 4b32c786b8..b116b21046 100644 --- a/driver/conv_driver.cpp +++ b/driver/conv_driver.cpp @@ -26,18 +26,32 @@ int main(int argc, char* argv[]) } const bool do_verification = atoi(argv[1]); - const int init_method = atoi(argv[2]); - const bool do_log = atoi(argv[3]); + const bool do_log = atoi(argv[2]); + const int init_method = atoi(argv[3]); const int nrepeat = atoi(argv[4]); #if 0 - constexpr index_t N = 8; - constexpr index_t C = 8; - constexpr index_t Hi = 4; - constexpr index_t Wi = 8; + constexpr index_t N = 256; + constexpr index_t C = 256; + constexpr index_t HI = 16; + constexpr index_t WI = 16; constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr index_t Y = 1; + constexpr index_t X = 1; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; +#elif 0 + constexpr index_t N = 1; + constexpr index_t C = 16; + constexpr index_t HI = 1080; + constexpr index_t WI = 1920; + constexpr index_t K = 16; + constexpr index_t Y = 1; + constexpr index_t X = 1; using ConvStrides = Sequence<1, 1>; 
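// The Ho/Wo computation repeated throughout these drivers is the standard
// convolution output-size formula with a dilated ("effective") filter extent.
// A minimal standalone sketch of the same arithmetic (out_size is a
// hypothetical helper for namespace scope, not part of this diff):
constexpr int out_size(int in, int pad_l, int pad_r, int filter, int dilation, int stride)
{
    // effective filter extent is (filter - 1) * dilation + 1
    return (in + pad_l + pad_r - ((filter - 1) * dilation + 1)) / stride + 1;
}
// For the 3x3, 71x71, stride-2, pad-1 configuration used by the static mode
// above: (71 + 1 + 1 - 3) / 2 + 1 = 36, so Ho = Wo = 36.
static_assert(out_size(71, 1, 1, 3, 1, 2) == 36, "expected Ho/Wo for the 71x71 case");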
using ConvDilations = Sequence<1, 1>; @@ -162,9 +176,9 @@ int main(int argc, char* argv[]) // 3x3, 71x71 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 128; + constexpr index_t HI = 71; + constexpr index_t WI = 71; + constexpr index_t K = 256; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -430,7 +444,7 @@ int main(int argc, char* argv[]) using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 14x14, stride 2 - constexpr index_t N = 128; + constexpr index_t N = 256; constexpr index_t C = 1024; constexpr index_t Hi = 14; constexpr index_t Wi = 14; @@ -445,7 +459,7 @@ int main(int argc, char* argv[]) using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 14x14 - constexpr index_t N = 128; + constexpr index_t N = 256; constexpr index_t C = 1024; constexpr index_t Hi = 14; constexpr index_t Wi = 14; @@ -636,6 +650,11 @@ int main(int argc, char* argv[]) using in_data_t = typename vector_type::type; using acc_data_t = float; using out_data_t = float; +#elif 1 + using in_data_t = half_t; + constexpr index_t in_vector_size = 1; + using acc_data_t = float; + using out_data_t = half_t; #elif 0 constexpr index_t in_vector_size = 1; using in_data_t = typename vector_type::type; diff --git a/driver/conv_driver_v2.cpp b/driver/conv_driver_v2.cpp index 693448ac25..530b779c2e 100644 --- a/driver/conv_driver_v2.cpp +++ b/driver/conv_driver_v2.cpp @@ -16,19 +16,31 @@ #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NHWC 0 -#define USE_CONV_FWD_V4R5_NCHW 1 +#define USE_CONV_FWD_V4R5_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 +#define USE_CONV_FWD_V4R4_XDL_NCHW 0 +#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo { - V4R4NCHW, - V4R4NHWC, - V4R5NCHW, - V5R1NCHW + V4R4NCHW, // 0 + V4R4NHWC, // 1 + V4R5NCHW, // 2 + V5R1NCHW, // 3 + V4R4XDLNCHW, // 4 + V4R4R2XDLNHWC, // 5 + V4R4R3XDLNHWC, // 6 + V4R4R4XDLNHWC // 7 }; int main(int argc, char* argv[]) @@ -97,21 +109,21 @@ int main(int argc, char* argv[]) const int nrepeat = atoi(argv[6]); constexpr index_t N = 128; - constexpr index_t C = 128; - constexpr index_t Hi = 17; - constexpr index_t Wi = 17; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 7; + constexpr index_t C = 192; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; - const index_t conv_stride_h = 1; - const index_t conv_stride_w = 1; + const index_t conv_stride_h = 2; + const index_t conv_stride_w = 2; const index_t conv_dilation_h = 1; const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 0; - const index_t in_left_pad_w = 3; - const index_t in_right_pad_h = 0; - const index_t in_right_pad_w = 3; + const index_t in_left_pad_h = 1; + const 
index_t in_left_pad_w = 1; + const index_t in_right_pad_h = 1; + const index_t in_right_pad_w = 1; const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1; @@ -120,11 +132,16 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #endif -#if 1 +#if 0 constexpr index_t in_vector_size = 1; using in_data_t = float; using acc_data_t = float; using out_data_t = float; +#elif 1 + constexpr index_t in_vector_size = 1; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 constexpr index_t in_vector_size = 16; using in_data_t = int8_t; @@ -384,6 +401,114 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_FWD_V4R4_XDL_NCHW + if(algo == ConvForwardAlgo::V4R4XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R2_XDL_NHWC + if(algo == ConvForwardAlgo::V4R4R2XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R3_XDL_NHWC + if(algo == ConvForwardAlgo::V4R4R3XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4R4_XDL_NHWC + if(algo == ConvForwardAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + if(do_verification) { host_direct_convolution(in, @@ -397,6 +522,7 @@ int main(int argc, char* argv[]) check_error(out_host, out_device); +#if 0 if(do_log) { LogRange(std::cout << "in : ", in.mData, ",") << std::endl; @@ -404,5 +530,6 @@ int main(int argc, char* argv[]) LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; } +#endif } } diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..de48a0ea82 --- /dev/null +++ b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,340 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using 
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk(wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<1, 3, 7, 0, 2, 4, 5, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(in_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + out_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + in_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + 
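// The perf figure computed a few lines below counts one multiply-accumulate as
// 2 flops, so backward-data convolution performs 2 * N*K*Ho*Wo * C*Y*X flops in
// total. Dividing by 1e9 gives GFLOP, and since ave_time is reported in
// milliseconds, GFLOP / ms is numerically equal to TFLOP/s, so no further unit
// conversion is needed. A minimal sketch of the same calculation (tflops is a
// hypothetical helper, not part of this diff):
inline float tflops(std::size_t n, std::size_t k, std::size_t ho, std::size_t wo,
                    std::size_t c, std::size_t y, std::size_t x, float ave_time_ms)
{
    // 2 flops per MAC, C*Y*X MACs per output element, N*K*Ho*Wo output elements
    const std::size_t flop = std::size_t(2) * n * k * ho * wo * c * y * x;
    return static_cast<float>(flop) / 1.0e9f / ave_time_ms; // GFLOP/ms == TFLOP/s
}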
const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..8332798690 --- /dev/null +++ b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,288 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + I0, + I0, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + 
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + + constexpr auto in_m0_m1_m2_n_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy +#if 0 + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, +#else + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, +#endif + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(in_m0_m1_m2_n_grid_iterator_hacks), + 
decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + true // CAccessOrderMRepeatNRepeat + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + in_m0_m1_m2_n_grid_iterator_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..5890b12e00 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,283 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto 
out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + 
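// These xdlops tile parameters are mutually constrained. Assuming the usual
// decomposition for these kernels (64-lane waves on AMD, with
// MWaves = GemmMPerBlock / (GemmMPerWave * MRepeat) and
// NWaves = GemmNPerBlock / (GemmNPerWave * NRepeat)), the waves of one
// workgroup must exactly cover the [GemmMPerBlock x GemmNPerBlock] tile. For
// the [256, 128, 4, 4] block selected in this branch (MPerWave = NPerWave = 64,
// with MRepeat = 2 and NRepeat = 1 just below), a compile-time check of that
// assumption:
static_assert((256 / (64 * 2)) * (128 / (64 * 1)) * 64 == 256,
              "MWaves * NWaves * wave_size must equal BlockSize");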
constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmKPack = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; +#endif + + const auto descs = +#if 1 + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad +#else + transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 +#endif + ( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + for(index_t i = 0; i < 5; ++i) + { +#if 0 + float ave_time = launch_kernel_dynamic_gemm_xdlops_v1 +#else + float ave_time = launch_kernel_dynamic_gemm_xdlops_v2 +#endif + , + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK, + GemmABlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_KPack, + false, // don't move back src coordinate after threadwise copy, which will be fused + // with MoveSrcSliceWindow() to save addr computation + Sequence<2, 3, 0, 1>, + 3, + GemmCThreadTransferDstScalarPerVector_GemmN1, + decltype(descs[I4]), + decltype(descs[I5]), + decltype(descs[I6]), + decltype(descs[I7]), + decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + descs[I0], + descs[I1], + descs[I2], + descs[I3], + descs[I4], + descs[I5], + descs[I6], + descs[I7], + descs[I8], + nrepeat); + + float perf = (float)calculate_convolution_flops( + 
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..9054c09d28 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,309 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_gemm_xdlops_v2r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 0 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t 
GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + 
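// Another invariant worth checking: in every block-transfer config in this
// diff, the per-thread slice lengths and the thread cluster lengths multiply
// elementwise to the block tile [GemmKPerBlock, GemmMPerBlock, GemmK1], and the
// cluster holds exactly BlockSize threads. For the [256, 128, 4, 4] config
// selected below, the A-side uses slice Sequence<1, 4, 4> against cluster
// Sequence<4, 64, 1>:
static_assert(1 * 4 == 4 && 4 * 64 == 256 && 4 * 1 == 4, // K0, M, K1 respectively
              "slice * cluster must equal [GemmKPerBlock, GemmMPerBlock, GemmK1]");
static_assert(4 * 64 * 1 == 256, "the thread cluster must contain BlockSize threads");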
constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<1, 0, 2>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + 
GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2>, + 3, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..130f7c97e2 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,212 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + 
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 0000000000..130f7c97e2
--- /dev/null
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp
@@ -0,0 +1,212 @@
+#include
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
+#include "driver_dynamic_gemm_xdlops_v2r2.hpp"
+
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk(
+    const InLengths& in_n_hi_wi_c_lengths,
+    const WeiLengths& wei_k_y_x_c_lengths,
+    const OutLengths& out_n_ho_wo_k_lengths,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    const Tensor<TInWei>& in_n_hi_wi_c,
+    const Tensor<TInWei>& wei_k_y_x_c,
+    Tensor<TOut>& out_n_ho_wo_k,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+
+    std::cout << __func__ << std::endl;
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+    constexpr auto I6 = Number<6>{};
+    constexpr auto I7 = Number<7>{};
+    constexpr auto I8 = Number<8>{};
+
+    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
+    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
+    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());
+
+    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
+    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
+    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
+
+    const auto in_n_hi_wi_c_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
+    const auto wei_k_y_x_c_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
+    const auto out_n_ho_wo_k_desc =
+        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
+
+#if 1
+    // [M, N, K0, K1] = [256, 128, 4, 8]
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 256;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerWave = 64;
+    constexpr index_t GemmNPerWave = 64;
+    constexpr index_t GemmK1 = 8;
+
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 1;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
+#endif
+
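One plausible reading of GemmK1 = 8 together with the Src/DstScalarPerVector = 8 choices above (an inference about intent, not stated in the code): with 2-byte elements such as fp16, eight contiguous K1 values are 16 bytes, i.e. exactly one 128-bit (dwordx4) vector access per thread. A tiny sketch of that arithmetic:

int main()
{
    constexpr unsigned elem_bytes = 2; // e.g. fp16 -- an assumption; TInWei is a template parameter
    constexpr unsigned gemm_k1    = 8; // innermost, contiguous slice of the K dimension

    // 8 halfs = 16 bytes = one 128-bit vector transaction per thread per access
    static_assert(gemm_k1 * elem_bytes == 16, "GemmK1 fills a full 128-bit vector");
    return 0;
}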
+    const auto descs =
+        transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc,
+                                                                          in_n_hi_wi_c_desc,
+                                                                          out_n_ho_wo_k_desc,
+                                                                          conv_strides,
+                                                                          conv_dilations,
+                                                                          in_left_pads,
+                                                                          in_right_pads,
+                                                                          Number<GemmK1>{});
+
+    const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
+    const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+    const auto out_gemmm_gemmn_grid_desc = descs[I2];
+
+    // HACK: hacks that control index calculation when iterating over A, B, C matrix
+    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple(
+        make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}),
+        make_tuple(
+            Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}));
+
+    constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}));
+
+    constexpr auto out_m0_m1_m2_n_grid_iterator_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 1, 0, 0>{}),
+                   make_tuple(Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 0, 0, 0>{},
+                              Sequence<0, 0, 2, 0, 0>{}));
+
+    constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0>{};
+
+    constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
+        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};
+
+    for(index_t i = 0; i < 5; ++i)
+    {
+        float ave_time = driver_dynamic_gemm_xdlops_v2r2<
+            BlockSize,
+            TInWei,
+            TAcc,
+            TOut,
+            InMemoryDataOperation::Set,
+            decltype(wei_gemmk0_gemmm_gemmk1_grid_desc),
+            decltype(in_gemmk0_gemmn_gemmk1_grid_desc),
+            decltype(out_gemmm_gemmn_grid_desc),
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerWave,
+            GemmNPerWave,
+            MRepeat,
+            NRepeat,
+            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
+            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
+            Sequence<1, 0, 2>,
+            Sequence<1, 0, 2>,
+            2,
+            GemmABlockTransferSrcScalarPerVector_GemmK1,
+            GemmABlockTransferDstScalarPerVector_GemmK1,
+            false, // don't move back src coordinate after threadwise copy
+            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
+            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
+            Sequence<1, 0, 2>,
+            Sequence<1, 0, 2>,
+            2,
+            GemmBBlockTransferSrcScalarPerVector_GemmK1,
+            GemmBBlockTransferDstScalarPerVector_GemmK1,
+            false, // don't move back src coordinate after threadwise copy
+            Sequence<2, 3, 0, 1>,
+            2,
+            GemmCThreadTransferDstScalarPerVector,
+            decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks),
+            decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks),
+            decltype(out_m0_m1_m2_n_grid_iterator_hacks),
+            decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks),
+            decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks)>(
+            static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
+            static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
+            static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
+            wei_gemmk0_gemmm_gemmk1_grid_desc,
+            in_gemmk0_gemmn_gemmk1_grid_desc,
+            out_gemmm_gemmn_grid_desc,
+            wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks,
+            in_gemmk0_gemmn_gemmk1_grid_iterator_hacks,
+            out_m0_m1_m2_n_grid_iterator_hacks,
+            wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks,
+            in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks,
+            nrepeat);
+
+        {
+            const auto N = out_n_ho_wo_k_lengths[I0];
+            const auto K = out_n_ho_wo_k_lengths[I3];
+            const auto C = wei_k_y_x_c_lengths[I3];
+
+            const auto Hi = in_n_hi_wi_c_lengths[I1];
+            const auto Wi = in_n_hi_wi_c_lengths[I2];
+
+            const auto Ho = out_n_ho_wo_k_lengths[I1];
+            const auto Wo = out_n_ho_wo_k_lengths[I2];
+
+            const auto Y = wei_k_y_x_c_lengths[I1];
+            const auto X = wei_k_y_x_c_lengths[I2];
+
+            float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
+                         (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
+                      << std::endl;
+        }
+    }
+
+    // copy result back to host
+    out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
+}
diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
new file mode 100644
index 0000000000..f030ed74eb
--- /dev/null
+++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp
@@ -0,0 +1,277 @@
+#include
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
+#include "driver_dynamic_gemm_xdlops_v2r3.hpp"
+
+template <typename TInWei,
+          typename TAcc,
+          typename TOut,
+          typename InLengths,
+          typename WeiLengths,
+          typename OutLengths,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk(
+    const InLengths& in_n_hi_wi_c_lengths,
+    const WeiLengths& wei_k_y_x_c_lengths,
+    const OutLengths& out_n_ho_wo_k_lengths,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    const Tensor<TInWei>& in_n_hi_wi_c,
+    const Tensor<TInWei>& wei_k_y_x_c,
+    Tensor<TOut>& out_n_ho_wo_k,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+
+    std::cout << __func__ << std::endl;
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+    constexpr auto I4 = Number<4>{};
+    constexpr auto I5 = Number<5>{};
+    constexpr auto I6 = Number<6>{};
+    constexpr auto I7 = Number<7>{};
+    constexpr auto I8 = Number<8>{};
+
+    DeviceMem
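For reference, the perf block above hard-codes the textbook forward-convolution FLOP count: each of the N * K * Ho * Wo outputs accumulates C * Y * X multiply-adds, i.e. two FLOPs per MAC. A standalone sketch (the helper name is illustrative), using one of the problem sizes enabled in script/run.sh:

#include <cstddef>
#include <iostream>

std::size_t conv_fwd_flops(std::size_t N, std::size_t K, std::size_t Ho, std::size_t Wo,
                           std::size_t C, std::size_t Y, std::size_t X)
{
    // each output element accumulates C*Y*X multiply-adds, i.e. 2*C*Y*X FLOPs
    return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
}

int main()
{
    // e.g. the 128x256x192, 3x3 filter, 36x36-output case from script/run.sh
    std::cout << conv_fwd_flops(128, 256, 36, 36, 192, 3, 3) << " FLOPs\n";
}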
in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr 
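A note on the [M, N, K0, K1] comments in these configuration blocks: the GEMM reduction dimension is split as K = K0 * K1, with K1 kept innermost and contiguous so it can be read in whole vectors. A sketch of the arithmetic, assuming the reduction length is C * Y * X as in the other forward transforms in this patch and ignoring any padding the *_pad transform may apply:

#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::size_t C = 192, Y = 3, X = 3; // example problem size from script/run.sh
    constexpr std::size_t GemmK1 = 8;

    constexpr std::size_t GemmK  = C * Y * X;    // total reduction length, 1728
    constexpr std::size_t GemmK0 = GemmK / GemmK1;

    static_assert(GemmK % GemmK1 == 0, "K must divide evenly into K0 * K1");
    std::cout << "K0 = " << GemmK0 << ", K1 = " << GemmK1 << std::endl; // K0 = 216
}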
index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + 
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..0890bf2e7d --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,364 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_xdlops_v2r3.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + 
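Each of these wrappers only stages data, runs the kernel nrepeat times, and copies the result back; correctness checking is left to the caller. A hypothetical host-side flow, using helper names this patch touches elsewhere (host_direct_convolution in driver/include/host_conv.hpp, check_error in driver/include/host_tensor.hpp) — the exact signatures below are assumptions, so the sketch is kept as an illustrative comment:

// Hypothetical usage sketch -- not part of this patch; signatures assumed.
// Tensor<float> in(...), wei(...);               // filled via GenerateTensorValue
// Tensor<float> out_device(...), out_host(...);
// device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk<...>(
//     in_lengths, wei_lengths, out_lengths, conv_strides, conv_dilations,
//     in_left_pads, in_right_pads, in, wei, out_device, nrepeat);
// host_direct_convolution(in, wei, out_host, conv_strides, conv_dilations,
//                         in_left_pads, in_right_pads /*, layout -- if supported */);
// check_error(out_host, out_device);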
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t 
GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + 
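How the enabled [128, 256] configuration above could decompose among wavefronts (a sketch; the 64-lane wave size and the relation MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave) are assumptions based on the parameter naming, not stated in this file):

#include <cstddef>

int main()
{
    constexpr std::size_t WaveSize  = 64;
    constexpr std::size_t BlockSize = 256, MPerBlock = 128, NPerBlock = 256;
    constexpr std::size_t MPerWave = 32, NPerWave = 32, MRepeat = 2, NRepeat = 4;

    constexpr std::size_t MWaves = MPerBlock / (MRepeat * MPerWave); // 2
    constexpr std::size_t NWaves = NPerBlock / (NRepeat * NPerWave); // 2

    // the wave grid must exactly account for every wavefront in the block
    static_assert(MWaves * NWaves == BlockSize / WaveSize, "4 waves per 256-thread block");
    return 0;
}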
out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_m0_m1_m2_n_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmK1, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks), + decltype(out_m0_m1_m2_n_grid_iterator_hacks), + 
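Note the role swap in v4r4r4: unlike the v4r4r2/r3 drivers above, the input tensor now supplies the GemmM side (descs[I0]) and the weights supply GemmN (descs[I1]), so GemmM = N * Ho * Wo and GemmN = K. A sketch of the resulting GEMM shape for one of the script/run.sh problems, assuming GemmK = C * Y * X before any padding:

#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::size_t N = 128, K = 256, C = 192, Ho = 36, Wo = 36, Y = 3, X = 3;

    constexpr std::size_t GemmM = N * Ho * Wo; // 165888 -- from the input/output side
    constexpr std::size_t GemmN = K;           // 256    -- from the weights
    constexpr std::size_t GemmK = C * Y * X;   // 1728   -- reduction dimension

    std::cout << GemmM << " x " << GemmN << " x " << GemmK << std::endl;
}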
decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + in_gemmk0_gemmm_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_iterator_hacks, + out_m0_m1_m2_n_grid_iterator_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/driver/include/host_conv.hpp b/driver/include/host_conv.hpp index 6d7d758df6..7f26cb42f7 100644 --- a/driver/include/host_conv.hpp +++ b/driver/include/host_conv.hpp @@ -1,13 +1,13 @@ #pragma once #include "host_tensor.hpp" -template +template void host_direct_convolution(const Tensor& in, const Tensor& wei, Tensor& out, @@ -88,7 +88,7 @@ void host_direct_convolution(const Tensor& in, } } -template +template void host_winograd_3x3_convolution(const Tensor& in_nchw, const Tensor& wei_kcyx, Tensor& out_nkhw, diff --git a/driver/include/host_conv_bwd_data.hpp b/driver/include/host_conv_bwd_data.hpp index fbcfcd004f..07617c3926 100644 --- a/driver/include/host_conv_bwd_data.hpp +++ b/driver/include/host_conv_bwd_data.hpp @@ -6,56 +6,62 @@ template -void host_direct_convolution_backward_data(Tensor& in_nchw, - const Tensor& wei_kcyx, - const Tensor& out_nkhw, - ConvStrides, - ConvDilations, - LeftPads, - RightPads) + typename InLeftPads, + typename InRightPads> +void host_direct_convolution_backward_data(Tensor& in, + const Tensor& wei, + const Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) { using namespace ck; - int N = in_nchw.mDesc.GetLengths()[0]; - int C = in_nchw.mDesc.GetLengths()[1]; - int HI = in_nchw.mDesc.GetLengths()[2]; - int WI = in_nchw.mDesc.GetLengths()[3]; + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - std::size_t K = wei_kcyx.mDesc.GetLengths()[0]; - std::size_t Y = wei_kcyx.mDesc.GetLengths()[2]; - std::size_t X = wei_kcyx.mDesc.GetLengths()[3]; + auto f_nchw = [&](auto n, auto c, auto hi, auto wi) { + std::size_t N = in.mDesc.GetLengths()[I0]; + std::size_t C = in.mDesc.GetLengths()[I1]; + std::size_t Hi = in.mDesc.GetLengths()[I2]; + std::size_t Wi = in.mDesc.GetLengths()[I3]; - std::size_t HO = out_nkhw.mDesc.GetLengths()[2]; - std::size_t WO = 
out_nkhw.mDesc.GetLengths()[3]; + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I2]; + std::size_t X = wei.mDesc.GetLengths()[I3]; + + std::size_t Ho = out.mDesc.GetLengths()[I2]; + std::size_t Wo = out.mDesc.GetLengths()[I3]; - auto f = [&](auto n, auto c, auto hi, auto wi) { double v = 0; for(int y = 0; y < Y; ++y) { - int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0]; + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; - if(h_tmp % ConvStrides{}[0] == 0) + if(h_tmp % conv_strides[I0] == 0) { - int ho = h_tmp / ConvStrides{}[0]; + int ho = h_tmp / conv_strides[I0]; - if(ho >= 0 && ho < HO) + if(ho >= 0 && ho < Ho) { for(int x = 0; x < X; ++x) { - int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1]; + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; - if(w_tmp % ConvStrides{}[1] == 0) + if(w_tmp % conv_strides[I1] == 0) { - int wo = w_tmp / ConvStrides{}[1]; + int wo = w_tmp / conv_strides[I1]; - if(wo >= 0 && wo < WO) + if(wo >= 0 && wo < Wo) { for(int k = 0; k < K; ++k) { - v += out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x); + v += out(n, k, ho, wo) * wei(k, c, y, x); } } } @@ -64,14 +70,74 @@ void host_direct_convolution_backward_data(Tensor& in_nchw, } } - in_nchw(n, c, hi, wi) = v; + in(n, c, hi, wi) = v; }; - auto f_par = make_ParallelTensorFunctor(f, - in_nchw.mDesc.GetLengths()[0], - in_nchw.mDesc.GetLengths()[1], - in_nchw.mDesc.GetLengths()[2], - in_nchw.mDesc.GetLengths()[3]); + auto f_nhwc = [&](auto n, auto hi, auto wi, auto c) { + std::size_t N = in.mDesc.GetLengths()[I0]; + std::size_t Hi = in.mDesc.GetLengths()[I1]; + std::size_t Wi = in.mDesc.GetLengths()[I2]; + std::size_t C = in.mDesc.GetLengths()[I3]; - f_par(std::thread::hardware_concurrency()); + std::size_t K = wei.mDesc.GetLengths()[I0]; + std::size_t Y = wei.mDesc.GetLengths()[I1]; + std::size_t X = wei.mDesc.GetLengths()[I2]; + + std::size_t Ho = out.mDesc.GetLengths()[I1]; + std::size_t Wo = out.mDesc.GetLengths()[I2]; + + double v = 0; + + for(int y = 0; y < Y; ++y) + { + int h_tmp = hi + in_left_pads[I0] - y * conv_dilations[I0]; + + if(h_tmp % conv_strides[I0] == 0) + { + int ho = h_tmp / conv_strides[I0]; + + if(ho >= 0 && ho < Ho) + { + for(int x = 0; x < X; ++x) + { + int w_tmp = wi + in_left_pads[I1] - x * conv_dilations[I1]; + + if(w_tmp % conv_strides[I1] == 0) + { + int wo = w_tmp / conv_strides[I1]; + + if(wo >= 0 && wo < Wo) + { + for(int k = 0; k < K; ++k) + { + v += out(n, ho, wo, k) * wei(k, y, x, c); + } + } + } + } + } + } + } + + in(n, hi, wi, c) = v; + }; + + switch(layout) + { + case ConvTensorLayout::NCHW: + make_ParallelTensorFunctor(f_nchw, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + case ConvTensorLayout::NHWC: + make_ParallelTensorFunctor(f_nhwc, + in.mDesc.GetLengths()[0], + in.mDesc.GetLengths()[1], + in.mDesc.GetLengths()[2], + in.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + default: throw std::runtime_error("wrong! 
not supported layout"); + } } diff --git a/driver/include/host_tensor.hpp b/driver/include/host_tensor.hpp index 64d0ee26d3..d4998d511f 100644 --- a/driver/include/host_tensor.hpp +++ b/driver/include/host_tensor.hpp @@ -9,7 +9,7 @@ #include #include -template +template std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) { bool first = true; @@ -24,12 +24,27 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) return os; } +template +std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << T{v}; + } + return os; +} + typedef enum { Half = 0, Float = 1, } DataType_t; -template +template struct DataType; template <> @@ -37,13 +52,13 @@ struct DataType : std::integral_constant { }; -template +template auto call_f_unpack_args_impl(F f, T args, std::index_sequence) { return f(std::get(args)...); } -template +template auto call_f_unpack_args(F f, T args) { constexpr std::size_t N = std::tuple_size{}; @@ -51,13 +66,13 @@ auto call_f_unpack_args(F f, T args) return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); } -template +template auto construct_f_unpack_args_impl(T args, std::index_sequence) { return F(std::get(args)...); } -template +template auto construct_f_unpack_args(F, T args) { constexpr std::size_t N = std::tuple_size{}; @@ -77,13 +92,13 @@ struct HostTensorDescriptor void CalculateStrides(); - template + template HostTensorDescriptor(const Range& lens) : mLens(lens.begin(), lens.end()) { this->CalculateStrides(); } - template + template HostTensorDescriptor(const Range1& lens, const Range2& strides) : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) { @@ -96,7 +111,7 @@ struct HostTensorDescriptor const std::vector& GetLengths() const; const std::vector& GetStrides() const; - template + template std::size_t GetOffsetFromMultiIndex(Is... is) const { assert(sizeof...(Is) == this->GetNumOfDimension()); @@ -111,7 +126,7 @@ struct HostTensorDescriptor struct joinable_thread : std::thread { - template + template joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) { } @@ -126,7 +141,7 @@ struct joinable_thread : std::thread } }; -template +template struct ParallelTensorFunctor { F mF; @@ -180,26 +195,26 @@ struct ParallelTensorFunctor } }; -template +template auto make_ParallelTensorFunctor(F f, Xs... xs) { return ParallelTensorFunctor(f, xs...); } -template +template struct Tensor { - template + template Tensor(std::initializer_list lens) : mDesc(lens), mData(mDesc.GetElementSpace()) { } - template + template Tensor(std::vector lens) : mDesc(lens), mData(mDesc.GetElementSpace()) { } - template + template Tensor(std::vector lens, std::vector strides) : mDesc(lens, strides), mData(mDesc.GetElementSpace()) { @@ -207,7 +222,7 @@ struct Tensor Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpace()) {} - template + template void GenerateTensorValue(G g, std::size_t num_thread = 1) { switch(mDesc.GetNumOfDimension()) @@ -247,13 +262,13 @@ struct Tensor } } - template + template T& operator()(Is... is) { return mData[mDesc.GetOffsetFromMultiIndex(is...)]; } - template + template const T& operator()(Is... 
is) const { return mData[mDesc.GetOffsetFromMultiIndex(is...)]; @@ -285,7 +300,7 @@ HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector s void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); -template +template void check_error(const Tensor& ref, const Tensor& result) { float error = 0; diff --git a/driver/include/host_tensor_generator.hpp b/driver/include/host_tensor_generator.hpp index d49d2d9122..a62045a182 100644 --- a/driver/include/host_tensor_generator.hpp +++ b/driver/include/host_tensor_generator.hpp @@ -1,13 +1,14 @@ #ifndef HOST_TENSOR_GENERATOR_HPP #define HOST_TENSOR_GENERATOR_HPP +#include #include "config.hpp" struct GeneratorTensor_1 { int value = 1; - template + template double operator()(Is... is) { return value; @@ -19,7 +20,7 @@ struct GeneratorTensor_2 int min_value = 0; int max_value = 1; - template + template double operator()(Is...) { return (std::rand() % (max_value - min_value)) + min_value; @@ -28,7 +29,7 @@ struct GeneratorTensor_2 struct GeneratorTensor_3 { - template + template double operator()(Is... is) { std::array dims = {{static_cast(is)...}}; @@ -41,7 +42,7 @@ struct GeneratorTensor_3 struct GeneratorTensor_Checkboard { - template + template double operator()(Ts... Xs) const { std::array dims = {{static_cast(Xs)...}}; diff --git a/script/cmake-rocm.sh b/script/cmake-rocm.sh index b94f45ea90..7a31c69dcb 100755 --- a/script/cmake-rocm.sh +++ b/script/cmake-rocm.sh @@ -3,24 +3,46 @@ rm -f CMakeCache.txt rm -f *.cmake rm -rf CMakeFiles -MY_PROJECT_SOURCE=../ +MY_PROJECT_SOURCE=../../../ MY_PROJECT_INSTALL=../install.dir cmake \ -D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ -D CMAKE_BUILD_TYPE=Release \ --D DEVICE_BACKEND="AMD" \ --D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \ +-D DEVICE_BACKEND=AMD \ +-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx908 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -gline-tables-only -save-temps=$CWD" \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ --D CMAKE_PREFIX_PATH="/opt/rocm" \ +-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ ${MY_PROJECT_SOURCE} -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx1030 -gline-tables-only -save-temps=$CWD -ftemplate-backtrace-limit=0 -mllvm -print-before=amdgpu-codegenprepare -mllvm -print-module-scope" \ -#-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -gline-tables-only -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -save-temps=$CWD" \ #-D CMAKE_CXX_FLAGS="-O3 --amdgpu-target=gfx906 -mllvm --amdgpu-enable-global-sgpr-addr -mllvm --amdgpu-spill-vgpr-to-agpr=0 -v -gline-tables-only -save-temps=$CWD" \ + +#CXX_FLAG_TMP=-Weverything +# -Wno-c++98-compat \ +# -Wno-c++98-compat-pedantic \ +# -Wno-conversion \ +# -Wno-double-promotion \ +# -Wno-exit-time-destructors \ +# -Wno-extra-semi \ +# -Wno-float-conversion \ +# -Wno-gnu-anonymous-struct \ +# -Wno-gnu-zero-variadic-macro-arguments \ +# -Wno-missing-noreturn \ +# 
-Wno-missing-prototypes \ +# -Wno-nested-anon-types \ +# -Wno-padded \ +# -Wno-return-std-move-in-c++11 \ +# -Wno-shorten-64-to-32 \ +# -Wno-sign-conversion \ +# -Wno-unknown-warning-option \ +# -Wno-unused-command-line-argument \ +# -Wno-weak-vtables \ +# -Wno-covered-switch-default \ +# -Wno-disabled-macro-expansion \ +# -Wno-undefined-reinterpret-cast + diff --git a/script/cmake-rocm3.1.sh b/script/cmake-rocm3.1.sh deleted file mode 100755 index c7bdb4f1c6..0000000000 --- a/script/cmake-rocm3.1.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -rm -f CMakeCache.txt -rm -f *.cmake -rm -rf CMakeFiles - -MY_PROJECT_SOURCE=../../../ -MY_PROJECT_INSTALL=../install.dir - -cmake \ --D CMAKE_INSTALL_PREFIX=${MY_PROJECT_INSTALL} \ --D CMAKE_BUILD_TYPE=Release \ --D DEVICE_BACKEND="AMD" \ --D CMAKE_CXX_FLAGS="--amdgpu-target=gfx906" \ --D CMAKE_CXX_COMPILER=/opt/rocm/hip/bin/hipcc \ --D CMAKE_PREFIX_PATH="/opt/rocm" \ --D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -${MY_PROJECT_SOURCE} - -#-D CMAKE_CXX_FLAGS="-gline-tables-only -v --amdgpu-target=gfx906" \ diff --git a/script/compile-rocm3.1.sh b/script/compile-rocm3.1.sh deleted file mode 100755 index 0aebc1dd66..0000000000 --- a/script/compile-rocm3.1.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - export KMOPTLLC="-mattr=+enable-ds128 -amdgpu-enable-global-sgpr-addr" - export KMDUMPISA=1 - export KMDUMPLLVM=1 - export KMDUMPDIR=$PWD - - make -j $1 -#/opt/rocm/hcc/bin/llvm-objdump -mcpu=gfx906 -source -line-numbers driver/dump-gfx906.isabin > driver/dump-gfx906.isabin.asm diff --git a/script/docker-rocm3.7.sh b/script/docker-rocm4.1.sh old mode 100644 new mode 100755 similarity index 89% rename from script/docker-rocm3.7.sh rename to script/docker-rocm4.1.sh index e9aab49447..61cc33c5b8 --- a/script/docker-rocm3.7.sh +++ b/script/docker-rocm4.1.sh @@ -8,7 +8,7 @@ docker run \ --group-add sudo \ -w /root/workspace \ -v $WORKSPACE:/root/workspace \ -asroy/tensorflow:rocm3.7-tf2.3-dev-omp \ +rocm/tensorflow:rocm4.1-tf1.15-dev \ /bin/bash #--network host \ diff --git a/script/run.sh b/script/run.sh index 4160d9c08f..1a76adb876 100755 --- a/script/run.sh +++ b/script/run.sh @@ -1,14 +1,19 @@ #!/bin/bash +## GPU visibility + export ROCR_VISIBLE_DEVICE=0 + export GPU_DEVICE_ORDINAL=0 - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +## Boost +#export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH +## Compiling export OLC_DEBUG_HIP_VERBOSE=1 export OLC_DEBUG_HIP_DUMP=1 export OLC_DEBUG_SAVE_TEMP_DIR=1 -#make -j conv_driver #make -j conv_driver_v2 +#make -j conv_bwd_data_driver_v2 make -j conv_driver_v2_olc rm -rf /root/_hip_binary_kernels_/ @@ -21,11 +26,21 @@ INIT=$4 LOG=$5 REPEAT=$6 -###################### layout algo verify init log repeat N__ K__ C__ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 384 192 3 3 35 35 2 2 1 1 0 0 0 0 -#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 1 7 17 17 1 1 1 1 0 3 0 3 -#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + ./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 +#./conv_driver_v2 
$LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 2048 3 3 14 14 1 1 1 1 1 1 1 1 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 - ./conv_driver_v2_olc $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#./conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + +#./conv_bwd_data_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + + ./conv_driver_v2_olc $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1
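
For reference, the enabled 71x71 case above runs 3x3 filters at stride 2 with 1-pixel padding; under the usual output-size convention Ho = (Hi + pad_l + pad_r - dilation * (Y - 1) - 1) / stride + 1 (an assumption about these drivers, not stated in the script) that gives a 36x36 output:

#include <iostream>

int conv_out_len(int in_len, int filter, int stride, int dilation, int pad_l, int pad_r)
{
    return (in_len + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    // N=128 K=256 C=192 Y=X=3 Hi=Wi=71, strides 2 2, dilations 1 1, pads 1 1 1 1
    std::cout << conv_out_len(71, 3, 2, 1, 1, 1) << std::endl; // prints 36
}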