diff --git a/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp b/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp new file mode 100644 index 0000000000..0252f9487a --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_contraction_v1r1.hpp @@ -0,0 +1,292 @@ +#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP +#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R1_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_contraction_v1r1.hpp" + +namespace ck { + +template +__host__ float +driver_dynamic_contraction_v1r1(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc, + const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc, + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseContraction = GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + CGlobalMemoryDataOperation, + AGKGM0GM1GridDesc, + BGKGN0GN1GridDesc, + CGM0GM1GN0GN1GridDesc, + GM1PerBlockGM11, + GN1PerBlockGN11, + KPerBlock, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM10, + M1N1ThreadClusterN10, + M1N1ThreadClusterM11, + M1N1ThreadClusterN11, + ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, + ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_GM11, + 
AThreadTransferSrcResetCoordinateAfterRun, + BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, + BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_GN11, + BThreadTransferSrcResetCoordinateAfterRun, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + const auto K = a_gk_gm0_gm1_grid_desc.GetLength(I0); + + if(!GridwiseContraction::CheckValidity( + a_gk_gm0_gm1_grid_desc, b_gk_gn0_gn1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc)) + { + throw std::runtime_error( + "wrong! GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting"); + } + + const auto a_gk_gm0_gm10_gm11_grid_desc = + GridwiseContraction::MakeAGKGM0GM10GM11GridDescriptor(a_gk_gm0_gm1_grid_desc); + const auto b_gk_gn0_gn10_gn11_grid_desc = + GridwiseContraction::MakeBGKGN0GN10GN11GridDescriptor(b_gk_gn0_gn1_grid_desc); + + using AGKGM0GM10GM11GridDesc = decltype(a_gk_gm0_gm10_gm11_grid_desc); + using BGKGN0GN10GN11GridDesc = decltype(b_gk_gn0_gn10_gn11_grid_desc); + + // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc + const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = + GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); + + using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); + + // c_blockid_to_gm10_gn10_block_cluster_adaptor + const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = + GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); + + using CBlockIdToGM10GN10BlockClusterAdaptor = + decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor); + + const index_t grid_size = 
GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc); + + const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = + GridwiseContraction::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_gk_gm0_gm10_gm11_grid_desc{" << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0) + << ", " << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I1) << ", " + << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I2) << ", " + << a_gk_gm0_gm10_gm11_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_gk_gn0_gn10_gn11_grid_desc{" << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I0) + << ", " << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I1) << ", " + << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I2) << ", " + << b_gk_gn0_gn10_gn11_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl; + } + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk_gm0_gm10_gm11_grid_desc, + b_gk_gn0_gn10_gn11_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = 
kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk_gm0_gm10_gm11_grid_desc, + b_gk_gn0_gn10_gn11_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk_gm0_gm10_gm11_grid_desc, + b_gk_gn0_gn10_gn11_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk_gm0_gm10_gm11_grid_desc, + b_gk_gn0_gn10_gn11_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + + return ave_time; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp deleted file mode 100644 index 4151fb72c3..0000000000 --- a/composable_kernel/include/driver/driver_dynamic_gemm_v1.hpp +++ /dev/null @@ -1,388 +0,0 @@ -#ifndef CK_DRIVER_DYNAMIC_GEMM_V1 -#define CK_DRIVER_DYNAMIC_GEMM_V1 - -#include 
"common_header.hpp" -#include "dynamic_tensor_descriptor.hpp" -#include "dynamic_tensor_descriptor_helper.hpp" -#include "gridwise_dynamic_gemm.hpp" -#include "gridwise_operation_wrapper.hpp" - -namespace ck { - -template -__host__ float launch_kernel_dynamic_gemm_v1(const FloatAB* p_a_global, - const FloatAB* p_b_global, - FloatC* p_c_global, - const AGlobalDesc& a_k_m_global_desc, - const BGlobalDesc& b_k_n_global_desc, - const CGlobalDesc& c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc& c_block_cluster_desc, - AGlobalIteratorHacks, - BGlobalIteratorHacks, - CGlobalIteratorHacks, - AGlobalMoveSliceWindowIteratorHacks, - BGlobalMoveSliceWindowIteratorHacks, - index_t nrepeat) - -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - const auto M = a_k_m_global_desc.GetLength(I1); - const auto N = b_k_n_global_desc.GetLength(I1); - const auto K = a_k_m_global_desc.GetLength(I0); - - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) - { - throw std::runtime_error("wrong! GEMM size no divisible"); - } - - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; - - if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0)) - { - throw std::runtime_error("wrong! 
GEMM size no divisible"); - } - - // GEMM - using gridwise_gemm = - GridwiseDynamicGemm_km_kn_m0m1n0n1_v1; - - const auto GridSize = (M / MPerBlock) * (N / NPerBlock); - - const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; - - const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; - -#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - else - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel(kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - 
p_b_global, - p_c_global, - a_k_m_global_desc, - b_k_n_global_desc, - c_m0_m1_n0_n1_global_desc, - c_block_cluster_desc); - } - - return ave_time; -#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER - DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); - DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); - DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); - DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); - - a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); - b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); - c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); - c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); - - float ave_time = 0; - - if(has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(has_main_k_block_loop && !has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - true, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void 
__CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else if(!has_main_k_block_loop && has_double_tail_k_block_loop) - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - true>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - else - { - const auto kernel = kernel_dynamic_gemm_v1, - remove_reference_t, - remove_reference_t, - remove_reference_t, - false, - false>; - - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(GridSize), - dim3(BlockSize), - 0, - 0, - p_a_global, - p_b_global, - p_c_global, - (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), - (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); - } - - return ave_time; -#endif -} - -} // namespace ck -#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp new file mode 100644 index 0000000000..1b52d368fe --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_v1r1.hpp @@ -0,0 +1,387 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_V1 +#define CK_DRIVER_DYNAMIC_GEMM_V1 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_v1r1.hpp" + +namespace ck { + +template +__host__ float launch_kernel_dynamic_gemm_v1r1(const FloatAB* p_a_global, + 
const FloatAB* p_b_global, + FloatC* p_c_global, + const AGlobalDesc& a_k_m_global_desc, + const BGlobalDesc& b_k_n_global_desc, + const CGlobalDesc& c_m0_m1_n0_n1_global_desc, + const CBlockClusterDesc& c_block_cluster_desc, + AGlobalIteratorHacks, + BGlobalIteratorHacks, + CGlobalIteratorHacks, + AGlobalMoveSliceWindowIteratorHacks, + BGlobalMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto M = a_k_m_global_desc.GetLength(I1); + const auto N = b_k_n_global_desc.GetLength(I1); + const auto K = a_k_m_global_desc.GetLength(I0); + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + if(!(MPerBlock % M1 == 0 && NPerBlock % N1 == 0)) + { + throw std::runtime_error("wrong! GEMM size no divisible"); + } + + // GEMM + using gridwise_gemm = + GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1; + + const auto GridSize = (M / MPerBlock) * (N / NPerBlock); + + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + 
remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + else + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + a_k_m_global_desc, + b_k_n_global_desc, + c_m0_m1_n0_n1_global_desc, + c_block_cluster_desc); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k_m_global_desc_device_buf(sizeof(AGlobalDesc)); + DeviceMem b_k_n_global_desc_device_buf(sizeof(BGlobalDesc)); + DeviceMem c_m0_m1_n0_n1_global_desc_device_buf(sizeof(CGlobalDesc)); + DeviceMem c_block_cluster_desc_device_buf(sizeof(c_block_cluster_desc)); + + a_k_m_global_desc_device_buf.ToDevice(&a_k_m_global_desc); + b_k_n_global_desc_device_buf.ToDevice(&b_k_n_global_desc); + c_m0_m1_n0_n1_global_desc_device_buf.ToDevice(&c_m0_m1_n0_n1_global_desc); + c_block_cluster_desc_device_buf.ToDevice(&c_block_cluster_desc); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = 
launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = kernel_dynamic_gemm_v1r1, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(GridSize), + dim3(BlockSize), + 0, + 0, + p_a_global, + 
p_b_global, + p_c_global, + (void __CONSTANT__*)a_k_m_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k_n_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m1_n0_n1_global_desc_device_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_block_cluster_desc_device_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp new file mode 100644 index 0000000000..527360d6b2 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_v1r2.hpp @@ -0,0 +1,285 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2 +#define CK_DRIVER_DYNAMIC_GEMM_V1R2 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_v1r2.hpp" + +namespace ck { + +template +__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemm_km_kn_mn_v1r2; + + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting"); + } + + const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc); + const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc); + + using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc); + using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K); + + { + std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", " + << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", " + << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2) + << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + + float ave_time = 0; + + 
if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_v1r2, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +} + +} // namespace ck +#endif diff --git 
a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp index 16e25335fe..404129365f 100644 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -10,11 +10,7 @@ namespace ck { // GemmM = K // GemmN = N * Ho * Wo // GemmK = C * Y * X -template {}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - - const auto GemmM0 = GemmM / Number{}; - const auto GemmN0 = GemmN / Number{}; - - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over wei_gemmk_gemmm_global tensor - constexpr auto wei_gemmk_gemmm_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over in_gemmk_gemmn_global tensor - constexpr auto 
in_gemmk_gemmn_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); - - constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; - - // hack to control index calculation when iterating over out_gemmm0_gemmm1_gemmn0_gemmn1_global - // tensor hack for NKHW format - constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk_gemmm_global_desc, - in_gemmk_gemmn_global_desc, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk_gemmm_global_iterator_hacks, - in_gemmk_gemmn_global_iterator_hacks, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, - wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, - in_gemmk_gemmn_global_move_slice_window_iterator_hacks); + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); } -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template {}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - - const auto GemmM0 = GemmM / Number{}; - const auto GemmN0 = GemmN / Number{}; - - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( - 
out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto wei_gemmk_gemmm_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto in_gemmk_gemmn_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 1, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 2, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 2>{})); - - constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 1, 2>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk_gemmm_global_desc, - in_gemmk_gemmn_global_desc, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk_gemmm_global_iterator_hacks, - in_gemmk_gemmn_global_iterator_hacks, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, - wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, - 
in_gemmk_gemmn_global_move_slice_window_iterator_hacks); + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); } -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template {}, Sequence<0, 2>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - - const auto GemmM0 = GemmM / Number{}; - const auto GemmN0 = GemmN / Number{}; - - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto wei_gemmk_gemmm_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto in_gemmk_gemmn_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 1, 0>{}, Sequence<0, 0, 1>{}), - make_tuple(Sequence<0, 2, 0>{}, Sequence<0, 0, 2>{})); - - constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 1, 2>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - constexpr 
auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk_gemmm_global_desc, - in_gemmk_gemmn_global_desc, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk_gemmm_global_iterator_hacks, - in_gemmk_gemmn_global_iterator_hacks, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, - wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, - in_gemmk_gemmn_global_move_slice_window_iterator_hacks); + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); } } // namespace ck diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 905efaabd7..987b3460c1 100644 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -10,11 +10,7 @@ namespace ck { // GemmM = K // GemmN = N * Ho * Wo // GemmK = C * Y * X -template {}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - - const auto GemmM0 = GemmM / Number{}; - const auto GemmN0 = GemmN / Number{}; - - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = 
transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over a_k_m_global tensor - constexpr auto wei_gemmk_gemmm_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto in_gemmk_gemmn_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>{})); - - constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - // hack for NKHW format - constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - return make_tuple(wei_gemmk_gemmm_global_desc, - in_gemmk_gemmn_global_desc, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk_gemmm_global_iterator_hacks, - in_gemmk_gemmn_global_iterator_hacks, - 
out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, - wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, - in_gemmk_gemmn_global_move_slice_window_iterator_hacks); + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); } -// GemmM = K -// GemmN = N * Ho * Wo -// GemmK = C * Y * X -template {}, Sequence<1>{}), make_tuple(Sequence<1>{}, Sequence<0>{})); - const auto GemmM = out_gemmm_gemmn_global_desc.GetLength(I0); - const auto GemmN = out_gemmm_gemmn_global_desc.GetLength(I1); - const auto GemmK = wei_gemmk_gemmm_global_desc.GetLength(I0); - - assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 && GemmK % GemmKPerBlock == 0); - - const auto GemmM0 = GemmM / Number{}; - const auto GemmN0 = GemmN / Number{}; - - const auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc = transform_dynamic_tensor_descriptor( - out_gemmm_gemmn_global_desc, - make_tuple(make_unmerge_transform(make_tuple(GemmM0, GemmM1)), - make_unmerge_transform(make_tuple(GemmN0, GemmN1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - // out_gemm_block_cluster_desc - const auto out_gemm_block_cluster_desc = make_cluster_descriptor_v2( - make_tuple(GemmM / Number{}, GemmN / Number{})); - - // hack to control index calculation when iterating over wei_gemmk_gemmm_global_iterator_hacks - // tensor - constexpr auto wei_gemmk_gemmm_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over b_k_n_global tensor - constexpr auto in_gemmk_gemmn_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{})); - - constexpr auto 
in_gemmk_gemmn_global_move_slice_window_iterator_hacks = Sequence<0, 0, 0>{}; - - // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor - constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{})); - - return make_tuple(wei_gemmk_gemmm_global_desc, - in_gemmk_gemmn_global_desc, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_desc, - out_gemm_block_cluster_desc, - wei_gemmk_gemmm_global_iterator_hacks, - in_gemmk_gemmn_global_iterator_hacks, - out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks, - wei_gemmk_gemmm_global_move_slice_window_iterator_hacks, - in_gemmk_gemmn_global_move_slice_window_iterator_hacks); + return make_tuple( + wei_gemmk_gemmm_global_desc, in_gemmk_gemmn_global_desc, out_gemmm_gemmn_global_desc); } } // namespace ck diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..ff2d4254c6 --- /dev/null +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp @@ -0,0 +1,125 @@ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmN = N * Ho * Wo +// GemmK = C * Y * X +template +__host__ __device__ constexpr auto +transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& 
wei_k_c_y_x_grid_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + // weight tensor + const auto wei_gk_gm0_gm1_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<0>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + constexpr auto N0 = Number{}; + const auto N1 = N / N0; + + const auto in_n0_n1_c_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + const auto in_gk_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( + in_n0_n1_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_pass_through_transform(N0), + make_merge_transform(make_tuple(N1, Ho, Wo))), + make_tuple(Sequence<2, 3, 5>{}, Sequence<0>{}, Sequence<1, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // output tensor + const auto out_n_k_howo_grid_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); + + const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( + out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(Number{}, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( + out_n0_n1_1_k_howo_grid_desc, + make_tuple(make_pass_through_transform(I1), + make_pass_through_transform(K), + make_pass_through_transform(Number{}), + make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), + make_tuple(Sequence<2>{}, Sequence<3>{}, 
Sequence<0>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return make_tuple( + wei_gk_gm0_gm1_grid_desc, in_gk_gn0_gn1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp index 23748dad59..145099095f 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp @@ -1164,6 +1164,165 @@ struct DynamicMerge_v2_magic_division } }; +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividend is uint32_t and its value is within 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividend has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value needs to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value needs to be +// non-negative. 
+template +struct DynamicMerge_v2r2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = decltype( + container_reverse_exclusive_scan(LowLengths{}, math::multiplies_v2{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies_v2{}, Number<1>{}))); + + using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsScanMagicDivisorShift = decltype( + generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_; + LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr DynamicMerge_v2r2_magic_division() = default; + + __host__ __device__ constexpr DynamicMerge_v2r2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies_v2{}, Number<1>{})}, + low_lengths_scan_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); }, + Number{})}, + low_lengths_scan_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies_v2{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& 
GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + index_t idx_low_old = idx_low[i]; + + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + idx_diff_low(i) = idx_low[i] - idx_low_old; + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_diff_low(Number{}) = tmp - idx_low[Number{}]; + + idx_low(Number{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const 
UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("DynamicMerge_v2r2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan "); + print_multi_index(low_lengths_scan_); + printf("low_lengths_scan_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_scan_magic_divisor_multiplier_); + printf("low_lengths_scan_magic_divisor_shift_ "); + print_multi_index(low_lengths_scan_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + template struct DynamicUnMerge { diff --git a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp index 342be83d17..b27f0507c8 100644 --- a/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/dynamic_multi_index_transform_helper.hpp @@ -56,8 +56,19 @@ __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_le #if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION return DynamicMerge_v1_carry_check{low_lengths}; #else +#if 1 return DynamicMerge_v2_magic_division{low_lengths}; +#else + return DynamicMerge_v2r2_magic_division{low_lengths}; #endif +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ + return DynamicMerge_v2_magic_division{low_lengths}; } template diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp index 03c2fccb2e..9d809576b8 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp @@ -308,6 +308,19 @@ transform_dynamic_tensor_descriptor(const 
OldTensorDescriptor& old_tensor_desc, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss) { + // sanity check + { + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + // lower dimension's hidden idss // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of // sequences) diff --git a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp index 385edab1c0..9b7db43664 100644 --- a/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp +++ b/composable_kernel/include/tensor_description/dynamic_tensor_descriptor_helper.hpp @@ -115,26 +115,30 @@ template __host__ __device__ constexpr auto make_dynamic_naive_tensor_descriptor_aligned_v2(const Tuple& lengths, Align align) { + constexpr auto I1 = Number<1>{}; + constexpr index_t N = sizeof...(Lengths); + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number{}], align); + auto strides = generate_tuple( [&](auto i) { if constexpr(i.value == N - 1) { - return Number<1>{}; + return I1; } else if constexpr(i.value == N - 2) { - return math::lcm(lengths[Number{}], align); + return Number{}; } else { return container_reduce(lengths, math::multiplies_v2{}, - math::lcm(lengths[Number{}], align), - i, + Number{}, + i + I1, Number{}, - Number<1>{}); + I1); } }, Number{}); diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp index 54362c45b9..ca116ef17e 100644 --- 
a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer.hpp @@ -31,8 +31,8 @@ template + bool ThreadTransferSrcResetCoordinateAfterRun, + bool ThreadTransferDstResetCoordinateAfterRun> struct BlockwiseDynamicTensorSliceTransfer_v4 { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp index 2bb6187cf0..7e2f924b58 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp @@ -29,10 +29,10 @@ template {}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); // thread position 4-d thread space + // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10, + // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, + // M1N1ThreadClusterN10 * M1N1ThreadClusterN11} constexpr auto adaptor1 = make_single_stage_tensor_adaptor( make_tuple( make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(MLevel1ThreadCluster, MLevel0ThreadCluster)), + make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)), make_freeze_transform(make_multi_index(0)), - make_unmerge_transform(make_tuple(NLevel1ThreadCluster, NLevel0ThreadCluster))), + make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); // 4-d thread space to 1-d thread space + // upper: {BlockSize} + // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10, + // M1N1ThreadClusterN11} constexpr auto adaptor2 = make_single_stage_tensor_adaptor( - 
make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster, - NLevel1ThreadCluster, - MLevel0ThreadCluster, - NLevel0ThreadCluster))), + make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10, + M1N1ThreadClusterN10, + M1N1ThreadClusterM11, + M1N1ThreadClusterN11))), make_tuple(Sequence<0, 2, 1, 3>{}), make_tuple(Sequence<0>{})); @@ -221,10 +229,10 @@ template {}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{})); // 4-d thread space to 1-d thread space constexpr auto adaptor2 = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(MLevel1ThreadCluster, - NLevel1ThreadCluster, - MLevel0ThreadCluster, - NLevel0ThreadCluster))), + make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10, + M1N1ThreadClusterN10, + M1N1ThreadClusterM11, + M1N1ThreadClusterN11))), make_tuple(Sequence<0, 2, 1, 3>{}), make_tuple(Sequence<0>{})); diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp new file mode 100644 index 0000000000..97fbc0bbaf --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp @@ -0,0 +1,396 @@ +#ifndef CK_BLOCKWISE_GEMM_V2R2_HPP +#define CK_BLOCKWISE_GEMM_V2R2_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_gemm_v2.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visible to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AKMBlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BKNBlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. 
It will do 2x2 pipelined read and fma (ABBA optimization) +template ::type = false> +struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K = AKMBlockDesc{}.GetLength(I0); + static constexpr index_t M = AKMBlockDesc{}.GetLength(I1); + static constexpr index_t N = BKNBlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc) + { + const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( + AKMBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& b_k_n_block_desc) + { + const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( + BKNBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_block_desc; + } + + __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{}); + static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); + + public: + __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + 
a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0), + "wrong! K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); + } + + __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } + + __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, + const 
ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k_m0_m1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k_n0_n1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + 
make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K, M0, M1] + static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto 
b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M11, + 1>; + + using BThreadCopy = + ThreadwiseDynamicTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp new file mode 100644 index 0000000000..05d070b94c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r1.hpp @@ -0,0 +1,680 @@ +#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP +#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R1_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_v2r2.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_contraction_v1r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGKGM0GM10GM11GridDesc a_gk_gm0_gm10_gm11_grid_desc, + const BGKGN0GN10GN11GridDesc b_gk_gn0_gn10_gn11_grid_desc, + const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + const CBlockIdToGM10GN10BlockClusterAdaptor + c_blockid_to_gm10_gn10_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + 
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_gk_gm0_gm10_gm11_grid_desc, + b_gk_gn0_gn10_gn11_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseDynamicContraction_km0m1_kn0n1_m0m1n0n1_v1r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0); + static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_gk_gm0_gm10_gm11_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_gk_gn0_gn10_gn11_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGKGM0GM1GridDesc& 
a_gk_gm0_gm1_grid_desc, + const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc, + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong!"); + + const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2); + const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2); + const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) && + GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) && + GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) && + GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) && + GM0 == a_gk_gm0_gm1_grid_desc.GetLength(I1) && + GM1 == a_gk_gm0_gm1_grid_desc.GetLength(I2) && + GN0 == b_gk_gn0_gn1_grid_desc.GetLength(I1) && + GN1 == b_gk_gn0_gn1_grid_desc.GetLength(I2) && + GK == b_gk_gn0_gn1_grid_desc.GetLength(I0)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK % KPerBlock == 0)); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr index_t GM11 = GM1PerBlockGM11; + constexpr index_t GN11 = GN1PerBlockGN11; + + const index_t GM10 = GM1 / GM11; + const index_t GN10 = GN1 / GN11; + + const index_t grid_size = GM10 * GN10; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK) + { + const bool has_main_k_block_loop = (GK + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK) + { + const bool has_double_tail_k_block_loop = (GK / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static 
constexpr auto + MakeAGKGM0GM10GM11GridDescriptor(const AGKGM0GM1GridDesc& a_gk_gm0_gm1_grid_desc) + { + const auto GK = a_gk_gm0_gm1_grid_desc.GetLength(I0); + const auto GM1 = a_gk_gm0_gm1_grid_desc.GetLength(I2); + + const auto GM11 = Number{}; + const auto GM10 = GM1 / GM11; + + const auto a_gk_gm0_gm10_gm11_grid_desc = transform_dynamic_tensor_descriptor( + a_gk_gm0_gm1_grid_desc, + make_tuple(make_pass_through_transform(GK), + make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + return a_gk_gm0_gm10_gm11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBGKGN0GN10GN11GridDescriptor(const BGKGN0GN1GridDesc& b_gk_gn0_gn1_grid_desc) + { + const auto GK = b_gk_gn0_gn1_grid_desc.GetLength(I0); + const auto GN1 = b_gk_gn0_gn1_grid_desc.GetLength(I2); + + const auto GN11 = Number{}; + const auto GN10 = GN1 / GN11; + + const auto b_gk_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( + b_gk_gn0_gn1_grid_desc, + make_tuple(make_pass_through_transform(GK), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{})); + + return b_gk_gn0_gn10_gn11_grid_desc; + } + + __host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor( + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + constexpr auto BM = GM0 * GM11; + constexpr auto BN = GN0 * GN11; + + constexpr auto BM1 = + Number{}; + constexpr auto BN1 = + Number{}; + + constexpr auto BM0 = BM / BM1; + 
constexpr auto BN0 = BN / BN1; + + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( + c_gm0_gm1_gn0_gn1_grid_desc, + make_tuple(make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( + c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_merge_transform(make_tuple(GM0, GM11)), + make_pass_through_transform(GN10), + make_merge_transform(make_tuple(GN0, GN11))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor( + c_gm10_bm_gn10_bn_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_unmerge_transform(make_tuple(BM0, BM1)), + make_pass_through_transform(GN10), + make_unmerge_transform(make_tuple(BN0, BN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc; + } + + __host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor( + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(GM10, GN10))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_gm10_gn10_block_cluster_adaptor; + } + + using AGKGM0GM10GM11GridDesc = decltype(MakeAGKGM0GM10GM11GridDescriptor(AGKGM0GM1GridDesc{})); + using BGKGN0GN10GN11GridDesc = decltype(MakeBGKGN0GN10GN11GridDescriptor(BGKGN0GN1GridDesc{})); + using CGM10BM0BM1GN10BN0BN1GridDesc = + decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{})); + using CBlockIdToGM10GN10BlockClusterAdaptor = + decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGKGM0GM10GM11GridDesc& a_gk_gm0_gm10_gm11_grid_desc, + const BGKGN0GN10GN11GridDesc& b_gk_gn0_gn10_gn11_grid_desc, + const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_gk_gm0_gm10_gm11_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_gk_gn0_gn10_gn11_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize()); + + const auto GK = a_gk_gm0_gm10_gm11_grid_desc.GetLength(I0); + + // divide block work by [GM10, GN10] + const auto c_gm10_gn10_block_cluster_idx = + c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); + const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); + + // 
lds max alignment + // part of them should be moved into blockwise-gemm + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_gk_gm0_gm10_gm11_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_gk_gn0_gn10_gn11_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}), max_lds_align); + + // A matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto a_gk_bm_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0 * Number{}), max_lds_align); + + // B matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto b_gk_bn_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0 * Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< + BlockSize, + InMemoryDataOperation::Set, + Sequence, + ABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, + ABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_gk_gm0_gm10_gm11_grid_desc), + decltype(a_gk_gm0_gm10_gm11_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_GM11, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_gk_gm0_gm10_gm11_grid_desc, + make_multi_index(0, 0, igm10, 0), + a_gk_gm0_gm10_gm11_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4< + BlockSize, + 
InMemoryDataOperation::Set, + Sequence, + BBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, + BBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_gk_gn0_gn10_gn11_grid_desc), + decltype(b_gk_gn0_gn10_gn11_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_GN11, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_gk_gn0_gn10_gn11_grid_desc, + make_multi_index(0, 0, ign10, 0), + b_gk_gn0_gn10_gn11_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS + // b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS + // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_bm0_bm1_bn0_bn1_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_bm0_bm1_bn0_bn1_thread_desc, + 
make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_gk_gm0_gm10_gm11_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_gk_gn0_gn10_gn11_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + 
a_blockwise_copy.MoveSrcSliceWindow( + a_gk_gm0_gm10_gm11_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_gk_gn0_gn10_gn11_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow( + a_gk_gm0_gm10_gm11_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_gk_gn0_gn10_gn11_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < 
GK - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_gk_gm0_gm10_gm11_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_gk_gn0_gn10_gn11_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_gk_gm0_gm10_gm11_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_gk_gn0_gn10_gn11_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_gk_gm0_gm10_gm11_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_gk_gn0_gn10_gn11_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t M11 = + M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; + constexpr index_t N11 = + N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + + constexpr index_t M10 = GM1PerBlockGM11 / M11; + constexpr index_t N10 = GN1PerBlockGN11 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc = + 
make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc), + decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc), + Sequence<1, + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0], + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1], + 1, + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2], + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + make_multi_index(igm10, + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0], + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1], + ign10, + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2], + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])} + .Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp similarity index 86% rename from composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp rename to composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp index 49e9df297d..8e8af1a12a 100644 --- a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r1.hpp @@ -27,13 +27,13 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_v1(const 
FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const AGlobalDesc a_k_m_global_desc, - const BGlobalDesc b_k_n_global_desc, - const CGlobalDesc c_m0_m1_n0_n1_global_desc, - const CBlockClusterDesc c_block_cluster_desc) + kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global, + const FloatB* __restrict__ p_b_global, + FloatC* __restrict__ p_c_global, + const AGlobalDesc a_k_m_global_desc, + const BGlobalDesc b_k_n_global_desc, + const CGlobalDesc c_m0_m1_n0_n1_global_desc, + const CBlockClusterDesc c_block_cluster_desc) { GridwiseGemm::Run(p_a_global, p_b_global, @@ -63,13 +63,13 @@ __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) #endif - kernel_dynamic_gemm_v1(const FloatA* __restrict__ p_a_global, - const FloatB* __restrict__ p_b_global, - FloatC* __restrict__ p_c_global, - const void __CONSTANT__* p_a_k_m_global_desc, - const void __CONSTANT__* p_b_k_n_global_desc, - const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, - const void __CONSTANT__* p_c_block_cluster_desc) + kernel_dynamic_gemm_v1r1(const FloatA* __restrict__ p_a_global, + const FloatB* __restrict__ p_b_global, + FloatC* __restrict__ p_c_global, + const void __CONSTANT__* p_a_k_m_global_desc, + const void __CONSTANT__* p_b_k_n_global_desc, + const void __CONSTANT__* p_c_m0_m1_n0_n1_global_desc, + const void __CONSTANT__* p_c_block_cluster_desc) { // first cast void __CONSTANT__ void* to void* // second cast void* to Desc* @@ -108,13 +108,13 @@ template -struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 +struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r1 { __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { constexpr auto max_lds_align = math::lcm(Number{}, Number{}, - Number{}, - Number{}); + Number{}, + Number{}); // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -210,8 +210,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 
// lds max alignment constexpr auto max_lds_align = math::lcm(Number{}, Number{}, - Number{}, - Number{}); + Number{}, + Number{}); // A matrix in LDS memory, dst of blockwise copy // be careful of LDS alignment @@ -284,34 +284,39 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check - static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 && - NPerBlock % (NPerThread * NLevel0Cluster * NLevel1Cluster) == 0, - "wrong!"); + static_assert( + MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 && + NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0, + "wrong!"); - constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster); - constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster); + constexpr index_t M0PerThread = + MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10); + constexpr index_t N0PerThread = + NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10); constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor( a_k_m_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple( - Number{}, Number{}))), + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple( + Number{}, + Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{})); constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor( b_k_n_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple( - Number{}, Number{}))), + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple( + Number{}, + Number{}))), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1, 2>{})); constexpr auto 
c_m0_m1_n0_n1_thread_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple( - Number{}, Number{}, Number{}, Number{})); + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(Number{}, + Number{}, + Number{}, + Number{})); const auto blockwise_gemm = BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2{}; + M1N1ThreadClusterM10, + M1N1ThreadClusterN10, + M1N1ThreadClusterM11, + M1N1ThreadClusterN11, + M1PerThread, + N1PerThread>{}; // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -345,9 +350,10 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 auto c_thread_buf = make_static_buffer( c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize()); - ThreadwiseDynamicTensorSliceSet_v1>{} + ThreadwiseDynamicTensorSliceSet_v1< + FloatAcc, + decltype(c_m0_m1_n0_n1_thread_desc), + Sequence>{} .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0}); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0); @@ -479,8 +485,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 // output: register to global memory { - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{}; @@ -493,7 +499,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1 FloatC, decltype(c_m0_m1_n0_n1_thread_desc), decltype(c_m0_m1_n0_n1_global_desc), - Sequence, + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp new file mode 100644 index 0000000000..525f1bcf25 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp @@ -0,0 +1,621 
@@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_v2r2.hpp" +#include "blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AKM0M1GridDesc a_k_m0_m1_grid_desc, + const BKN0N1GridDesc b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k_m0_m1_grid_desc, + b_k_n0_n1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseDynamicGemm_km_kn_mn_v1r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + 
make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool CheckValidity(const AKMGridDesc& a_k_m_grid_desc, + const BKNGridDesc& b_k_n_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K == b_k_n_grid_desc.GetLength(I0)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const bool has_main_k_block_loop = (K + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K) + { + const bool has_double_tail_k_block_loop = (K / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAKM0M1GridDescriptor(const AKMGridDesc& 
a_k_m_grid_desc) + { + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k_m0_m1_grid_desc = transform_dynamic_tensor_descriptor( + a_k_m_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(M0, M1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1GridDescriptor(const BKNGridDesc& b_k_n_grid_desc) + { + const auto K = b_k_n_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k_n0_n1_grid_desc = transform_dynamic_tensor_descriptor( + b_k_n_grid_desc, + make_tuple(make_pass_through_transform(K), make_unmerge_transform(make_tuple(N0, N1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + 
MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AKM0M1GridDesc = decltype(MakeAKM0M1GridDescriptor(AKMGridDesc{})); + using BKN0N1GridDesc = decltype(MakeBKN0N1GridDescriptor(BKNGridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AKM0M1GridDesc& a_k_m0_m1_grid_desc, + const BKN0N1GridDesc& b_k_n0_n1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + const auto K = a_k_m0_m1_grid_desc.GetLength(I0); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data 
into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(Number{}, + Number{}, + Number{}, + Number{}); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}), max_lds_align); + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k_m0_m1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k_n0_n1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}), max_lds_align); + + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K_M0_M1, + ABlockTransferThreadClusterLengths_K_M0_M1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k_m0_m1_grid_desc), + decltype(a_k_m0_m1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_M1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>(a_k_m0_m1_grid_desc, + make_multi_index(0, im0, 0), + a_k_m0_m1_block_desc, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseDynamicTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K_N0_N1, + 
BBlockTransferThreadClusterLengths_K_N0_N1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k_n0_n1_grid_desc), + decltype(b_k_n0_n1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_N1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_k_n0_n1_grid_desc, + make_multi_index(0, in0, 0), + b_k_n0_n1_block_desc, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2{}; + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_m10_m11_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m0_m1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n0_n1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = 
make_multi_index(KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k_m0_m1_global_iterator_hacks = AGridIteratorHacks{}; + constexpr auto b_k_n0_n1_global_iterator_hacks = BGridIteratorHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k_m0_m1_global_move_slice_window_iterator_hack = + AGridMoveSliceWindowIteratorHacks{}; + constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack = + BGridMoveSliceWindowIteratorHacks{}; + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = + make_dynamic_buffer(p_a_block_double + a_block_aligned_space_size, + a_k_m0_m1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = + make_dynamic_buffer(p_b_block_double + b_block_aligned_space_size, + b_k_n0_n1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow( + a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + 
__syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow( + a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow( + b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k_m0_m1_grid_desc, + a_block_slice_copy_step, + a_k_m0_m1_global_move_slice_window_iterator_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k_n0_n1_grid_desc, + b_block_slice_copy_step, + 
b_k_n0_n1_global_move_slice_window_iterator_hack); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_k_m0_m1_grid_desc, a_global_buf, a_k_m0_m1_global_iterator_hacks); + b_blockwise_copy.RunRead( + b_k_n0_n1_grid_desc, b_global_buf, b_k_n0_n1_global_iterator_hacks); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k_m0_m1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k_n0_n1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t M11 = + M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; + constexpr index_t N11 = + N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + + constexpr index_t M10 = MPerBlockM1 / M11; + constexpr index_t N10 = NPerBlockN1 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + 
c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/array_multi_index.hpp b/composable_kernel/include/utility/array_multi_index.hpp similarity index 100% rename from composable_kernel/include/tensor_description/array_multi_index.hpp rename to composable_kernel/include/utility/array_multi_index.hpp diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index fc05646008..309571b061 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -6,6 +6,7 @@ #include "container_helper.hpp" #include "statically_indexed_array.hpp" #include "container_element_picker.hpp" +#include "multi_index.hpp" #include "data_type.hpp" #include "float_type.hpp" #include "functional.hpp" diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index 89ea12fee3..c70cfd96db 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -14,7 +14,7 @@ #define CK_DEVICE_BACKEND_AMD 1 // GPU ID -#if 0 +#if 1 #define 
CK_AMD_GPU_GFX906 1 #elif 0 #define CK_AMD_GPU_GFX908 1 @@ -28,7 +28,7 @@ #endif // launch bounds -#define CK_USE_LAUNCH_BOUNDS 1 +#define CK_USE_LAUNCH_BOUNDS 0 #ifdef CK_USE_LAUNCH_BOUNDS #define CK_MAX_THREAD_PER_BLOCK 256 diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp index 8e29e75348..3239205fda 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -118,6 +118,7 @@ struct MagicDivision return (tmp + dividend) >> shift; } +#if 1 // debug // HACK: magic division for int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // non-negative for result to be correct @@ -127,8 +128,25 @@ struct MagicDivision { uint32_t dividend_u32 = as_type(dividend_i32); uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32; - return (tmp + dividend_i32) >> shift; + return (tmp + dividend_u32) >> shift; } +#else + // the inline ASM is producing wrong result + __host__ __device__ static int32_t + DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) + { + uint32_t r; + asm volatile("\n \ + v_mul_hi_u32 %0, %1, %2 \n \ + v_add_u32_e32 %0, %1, %0 \n \ + v_lshrrev_b32_e32 %0, %3, %0 \n \ + " + : "=v"(r) + : "v"(as_type(dividend_i32)), "s"(multiplier), "s"(shift)); + + return as_type(r); + } +#endif }; } // namespace ck diff --git a/composable_kernel/include/utility/math.hpp b/composable_kernel/include/utility/math.hpp index 639d4157e6..11e87eca4c 100644 --- a/composable_kernel/include/utility/math.hpp +++ b/composable_kernel/include/utility/math.hpp @@ -74,7 +74,7 @@ __host__ __device__ constexpr auto integer_divide_floor(X x, Y y) template __host__ __device__ constexpr auto integer_divide_ceil(X x, Y y) { - return (x + y - 1) / y; + return (x + y - Number<1>{}) / y; } template diff --git a/composable_kernel/include/tensor_description/multi_index.hpp 
b/composable_kernel/include/utility/multi_index.hpp similarity index 100% rename from composable_kernel/include/tensor_description/multi_index.hpp rename to composable_kernel/include/utility/multi_index.hpp diff --git a/composable_kernel/include/utility/sequence_helper.hpp b/composable_kernel/include/utility/sequence_helper.hpp index 706b231792..ccedfc3e6f 100644 --- a/composable_kernel/include/utility/sequence_helper.hpp +++ b/composable_kernel/include/utility/sequence_helper.hpp @@ -26,5 +26,11 @@ __host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number) typename arithmetic_sequence_gen<0, N, 1>::type{}); } +template +__host__ __device__ constexpr auto to_sequence(Tuple...>) +{ + return Sequence{}; +} + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp similarity index 96% rename from composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp rename to composable_kernel/include/utility/statically_indexed_array_multi_index.hpp index ff1df4bd10..9e96f06d73 100644 --- a/composable_kernel/include/tensor_description/statically_indexed_array_multi_index.hpp +++ b/composable_kernel/include/utility/statically_indexed_array_multi_index.hpp @@ -99,7 +99,8 @@ __host__ __device__ void print_multi_index(const Tuple& x) printf("{"); printf("MultiIndex, "); printf("size %d,", index_t{sizeof...(Xs)}); - static_for<0, sizeof...(Xs), 1>{}([&](auto i) { printf("%d ", index_t{x.At(i)}); }); + static_for<0, sizeof...(Xs), 1>{}( + [&](auto i) { printf("%d ", static_cast(x.At(i))); }); printf("}"); } diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 10bb32f938..6b91ab986d 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -16,6 +16,7 @@ install(TARGETS host LIBRARY DESTINATION lib) if(DEVICE_BACKEND STREQUAL "AMD") set(CONV_SOURCE src/conv_driver.cpp) + 
set(CONV_V2_SOURCE src/conv_driver_v2.cpp) set(CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp) elseif(DEVICE_BACKEND STREQUAL "NVIDIA") set(CONV_SOURCE src/conv_driver.cu) @@ -23,7 +24,9 @@ elseif(DEVICE_BACKEND STREQUAL "NVIDIA") endif() add_executable(conv_driver ${CONV_SOURCE}) +add_executable(conv_driver_v2 ${CONV_V2_SOURCE}) add_executable(conv_bwd_data_driver ${CONV_BWD_DATA_SOURCE}) target_link_libraries(conv_driver PRIVATE host) +target_link_libraries(conv_driver_v2 PRIVATE host) target_link_libraries(conv_bwd_data_driver PRIVATE host) diff --git a/driver/include/conv_common.hpp b/driver/include/conv_common.hpp index c4020928f3..2b89bc876e 100644 --- a/driver/include/conv_common.hpp +++ b/driver/include/conv_common.hpp @@ -2,15 +2,25 @@ #define CONV_COMMON_HPP #include "tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor.hpp" + +enum ConvTensorLayout +{ + NCHW, + NHWC, + CHWN, + NCHWc, + NHWCc +}; template + class LeftPads, + class RightPads> constexpr auto get_convolution_output_default_4d_tensor_descriptor( - InDesc, WeiDesc, ConvStrides, ConvDilations, LowerPads, UpperPads) + InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads) { using namespace ck; @@ -35,21 +45,69 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor( constexpr index_t Y = wei_desc.GetLength(I2); constexpr index_t X = wei_desc.GetLength(I3); - constexpr index_t HPadLow = LowerPads{}.Get(I0); - constexpr index_t WPadLow = LowerPads{}.Get(I1); + constexpr index_t LeftPadH = LeftPads{}.Get(I0); + constexpr index_t LeftPadW = LeftPads{}.Get(I1); - constexpr index_t HPadUp = UpperPads{}.Get(I0); - constexpr index_t WPadUp = UpperPads{}.Get(I1); + constexpr index_t RightPadH = RightPads{}.Get(I0); + constexpr index_t RightPadW = RightPads{}.Get(I1); constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - constexpr index_t Ho = (Hi + HPadLow + HPadUp - YEff) / ConvStrides{}[0] + 1; - 
constexpr index_t Wo = (Wi + WPadLow + WPadUp - XEff) / ConvStrides{}[1] + 1; + constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1; + constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1; return make_native_tensor_descriptor_packed(Sequence{}); } +template +constexpr auto get_convolution_output_default_4d_tensor_descriptor( + const ck::DynamicTensorDescriptor& in_desc, + const ck::DynamicTensorDescriptor& wei_desc, + const ConvStrides& conv_strides, + const ConvDilations conv_dilations, + const LeftPads& left_pads, + const RightPads& right_pads) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + assert(in_desc.GetNumOfDimension() == 4); + assert(wei_desc.GetNumOfDimension() == 4); + assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1)); + + const auto N = in_desc.GetLength(I0); + const auto Hi = in_desc.GetLength(I2); + const auto Wi = in_desc.GetLength(I3); + + const auto K = wei_desc.GetLength(I0); + const auto Y = wei_desc.GetLength(I2); + const auto X = wei_desc.GetLength(I3); + + const auto LeftPadH = left_pads[I0]; + const auto LeftPadW = left_pads[I1]; + + const auto RightPadH = right_pads[I0]; + const auto RightPadW = right_pads[I1]; + + const auto YEff = (Y - I1) * conv_dilations[I0] + I1; + const auto XEff = (X - I1) * conv_dilations[I1] + I1; + + const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; + const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; + + return make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)); +} + template constexpr std::size_t calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, const OutDesc& out_desc) diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp 
b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index 1a707237fd..f305cb9ae2 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -2,30 +2,29 @@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "driver_dynamic_gemm_v1.hpp" +#include "driver_dynamic_gemm_v1r2.hpp" -template +template void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( - InDesc, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, const Tensor& in_n_c_hi_wi, - WeiDesc, const Tensor& wei_k_c_y_x, - OutDesc, Tensor& out_n_k_ho_wo, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, ck::index_t nrepeat) { using namespace ck; @@ -50,505 +49,155 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); -#if 1 - // run-time variables const auto in_n_c_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(InDesc::GetLengths())); + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); const auto wei_k_c_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(WeiDesc::GetLengths())); + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(OutDesc::GetLengths())); + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); - const auto conv_strides = to_multi_index(ConvStrides{}); - 
const auto conv_dilations = to_multi_index(ConvDilations{}); - const auto in_left_pads = to_multi_index(InLeftPads{}); - const auto in_right_pads = to_multi_index(InRightPads{}); -#else - // compile-time variables - const auto in_n_c_hi_wi_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(InDesc::GetLengths())); - const auto wei_k_c_y_x_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(WeiDesc::GetLengths())); - const auto out_n_k_ho_wo_desc = make_dynamic_naive_tensor_descriptor_packed_v2( - sequence_to_tuple_of_number(OutDesc::GetLengths())); - - const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); - const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); - const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); - const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); -#endif - -#if 0 - // cdata = 16, BlockSize = 64, 16x64x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 2; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t ThreadGemmDataPerReadM = 2; - constexpr index_t ThreadGemmDataPerReadN = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; - - constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; -#elif 0 - // cdata = 32, BlockSize 64, 16x128x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t ThreadGemmDataPerReadM = 2; - constexpr index_t ThreadGemmDataPerReadN = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; -#elif 0 - // cdata = 64, BlockSize 64, 16x256x2 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 16; - - constexpr index_t 
ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; -#elif 0 - // cdata = 64, BlockSize 64, 16x256x4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 16; - - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; 
-#elif 0 - // cdata = 16, BlockSize = 64, 16x64x4 - // GemmBBlockCopySrcDataPerRead_GemmN = 4 - // GemmCThreadCopyDstDataPerWrite_GemmN1 = 2 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 64; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 2; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t ThreadGemmDataPerReadM = 2; - constexpr index_t ThreadGemmDataPerReadN = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 16>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 2; -#elif 0 - // cdata = 32, BlockSize = 64, 16x128x4 - // GemmBBlockCopySrcDataPerRead_GemmN = 4 - // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 - constexpr index_t BlockSize = 64; - - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; - - constexpr index_t ThreadGemmDataPerReadM 
= 2; - constexpr index_t ThreadGemmDataPerReadN = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 32>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; -#elif 0 - // cdata = 64, BlockSize = 128, 32x256x8 - constexpr index_t BlockSize = 128; - - constexpr index_t GemmMPerBlock = 32; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 16; - - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - // cdata 
= 64, BlockSize = 256, 128x128x2 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 2; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<1, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - // cdata = 64, BlockSize = 256, 128x128x4 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using 
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 +#if 1 // cdata = 64, BlockSize = 256, 128x128x8 // b thread copy 4x1 constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 
1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_N1 = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x8 - // b thread copy 2x2 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // cdata = 64, BlockSize = 256, 128x128x16 - // GemmBBlockCopySrcDataPerRead_GemmN = 4 - // GemmCThreadCopyDstDataPerWrite_GemmN1 = 4 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - 
constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 16; - - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; - - constexpr index_t GemmMLevel0Cluster = 4; - constexpr index_t GemmNLevel0Cluster = 4; - constexpr index_t GemmMLevel1Cluster = 4; - constexpr index_t GemmNLevel1Cluster = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; - - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 1; #endif - constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad -#elif 0 - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad -#else - transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1 -#endif - (wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); + transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto 
wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}, + Sequence<0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{}, + Sequence<0, 0, 2, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + + const auto wei_gemmk_gemmm_grid_desc = descs[I0]; + const auto in_gemmk_gemmn_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; for(index_t i = 0; i < 5; ++i) { - float ave_time = launch_kernel_dynamic_gemm_v1< + float ave_time = driver_dynamic_gemm_v1r2< BlockSize, - typename vector_type::type, + TInWei, TAcc, TOut, InMemoryDataOperation::Set, - decltype(descs[I0]), - decltype(descs[I1]), - decltype(descs[I2]), - decltype(descs[I3]), - GemmMPerBlock, - GemmNPerBlock, + decltype(wei_gemmk_gemmm_grid_desc), + 
decltype(in_gemmk_gemmn_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, + GemmM1PerThreadM111, + GemmN1PerThreadN111, GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_K_M0_M1, + GemmABlockTransferThreadClusterLengths_K_M0_M1, + Sequence<2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder + Sequence<2, 1, 0>, // ABlockTransferSrcAccessOrder + 0, // ABlockTransferSrcVectorDim + GemmABlockTransferSrcScalarPerVector_K, + GemmABlockTransferDstScalarPerVector_M1, false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<0, 1>, - Sequence<0, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast::type*>( - wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); + 
GemmBBlockTransferThreadSliceLengths_K_N0_N1, + GemmBBlockTransferThreadClusterLengths_K_N0_N1, + Sequence<0, 1, 2>, // BBlockTransferThreadClusterArrangeOrder + Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + GemmBBlockTransferSrcScalarPerVector_N1, + GemmBBlockTransferDstScalarPerVector_N1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gemmk_gemmm_grid_desc, + in_gemmk_gemmn_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, + nrepeat); float perf = (float)calculate_convolution_flops( in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp index 1fff630d4c..c3640d675c 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp @@ -2,30 +2,29 
@@ #include "device.hpp" #include "host_tensor.hpp" #include "transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk.hpp" -#include "driver_dynamic_gemm_v1.hpp" +#include "driver_dynamic_gemm_v1r2.hpp" -template +template void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( - InDesc, - const Tensor& in_n_c_hi_wi, - WeiDesc, - const Tensor& wei_k_c_y_x, - OutDesc, - Tensor& out_n_k_ho_wo, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, ck::index_t nrepeat) { using namespace ck; @@ -42,73 +41,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( constexpr auto I7 = Number<7>{}; constexpr auto I8 = Number<8>{}; - constexpr auto N = OutDesc::GetLengths()[I0]; - constexpr auto K = OutDesc::GetLengths()[I1]; - constexpr auto C = WeiDesc::GetLengths()[I1]; - - constexpr auto Hi = InDesc::GetLengths()[I2]; - constexpr auto Wi = InDesc::GetLengths()[I3]; - - constexpr auto Ho = OutDesc::GetLengths()[I2]; - constexpr auto Wo = OutDesc::GetLengths()[I3]; - - constexpr auto Y = WeiDesc::GetLengths()[I2]; - constexpr auto X = WeiDesc::GetLengths()[I3]; - - constexpr auto C0 = C / Number{}; - constexpr auto C1 = Number{}; - -#if 1 - // run-time variables - constexpr auto in_n_hi_wi_c0_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0)); - constexpr auto wei_k_y_x_c0_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, Y, X, C0)); - constexpr auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Ho, Wo, K)); - - const auto conv_strides = to_multi_index(ConvStrides{}); - 
const auto conv_dilations = to_multi_index(ConvDilations{}); - const auto in_left_pads = to_multi_index(InLeftPads{}); - const auto in_right_pads = to_multi_index(InRightPads{}); -#else - // compile-time variables - constexpr auto in_n_hi_wi_c0_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Hi, Wi, C0)); - constexpr auto wei_k_y_x_c0_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, Y, X, C0)); - constexpr auto out_n_ho_wo_k_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, Ho, Wo, K)); - - const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); - const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); - const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{}); - const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); -#endif - - Tensor in_n_hi_wi_c( - make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); - Tensor wei_k_y_x_c( - make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); - Tensor out_n_ho_wo_k( - make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence{}))); - - auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) { - in_n_hi_wi_c(n, hi, wi, c) = in_n_c_hi_wi(n, c, hi, wi); - }; - - auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) { - wei_k_y_x_c(k, y, x, c) = wei_k_c_y_x(k, c, y, x); - }; - - auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) { - out_n_ho_wo_k(n, ho, wo, k) = out_n_k_ho_wo(n, k, ho, wo); - }; - - make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(); - make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(); - make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(); - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); @@ 
-117,357 +49,472 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk( wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + #if 0 // cdata = 16, BlockSize = 64, 16x64x4 constexpr index_t BlockSize = 64; - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmMPerBlockM1 = 16; + constexpr index_t GemmNPerBlockN1 = 64; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 2; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 2; + constexpr index_t GemmN1PerThreadN111 = 2; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 2; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t ThreadGemmDataPerReadM = 2; - constexpr index_t ThreadGemmDataPerReadN = 2; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr 
index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2; #elif 0 // cdata = 32, BlockSize = 64, 16x128x4 constexpr index_t BlockSize = 64; - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; + constexpr index_t GemmMPerBlockM1 = 16; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerThread = 2; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 2; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 2; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; - constexpr index_t ThreadGemmDataPerReadM = 2; - constexpr index_t ThreadGemmDataPerReadN = 4; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; + using 
GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 2; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 2; #elif 0 // cdata = 64, BlockSize = 64, 16x256x2 constexpr index_t BlockSize = 64; - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 2; + constexpr index_t GemmMPerBlockM1 = 16; + constexpr index_t GemmNPerBlockN1 = 256; + constexpr index_t GemmKPerBlock = 2; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 1; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmM11N11ThreadClusterM1101 = 
1; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 2; + constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 16>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 16>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<2, 1, 4>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 2; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #elif 0 // cdata = 64, BlockSize = 64, 16x256x4 constexpr index_t BlockSize = 64; - constexpr index_t GemmMPerBlock = 16; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; + constexpr index_t GemmMPerBlockM1 = 16; + constexpr index_t GemmNPerBlockN1 = 256; + constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr 
index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 1; - constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 1; + constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 16>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 4>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 64>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #elif 0 // cdata = 64, BlockSize = 128, 32x256x4 constexpr index_t BlockSize = 128; - constexpr index_t 
GemmMPerBlock = 32; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; + constexpr index_t GemmMPerBlockM1 = 32; + constexpr index_t GemmNPerBlockN1 = 256; + constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 2; + constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 1; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 1; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t 
GemmBBlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #elif 0 // cdata = 64, BlockSize = 128, 32x256x8 constexpr index_t BlockSize = 128; - constexpr index_t GemmMPerBlock = 32; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerBlockM1 = 32; + constexpr index_t GemmNPerBlockN1 = 256; + constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 2; - constexpr index_t GemmNLevel1Cluster = 16; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 2; + constexpr index_t GemmM11N11ThreadClusterN1100 = 16; - constexpr index_t ThreadGemmDataPerReadM = 4; - constexpr index_t ThreadGemmDataPerReadN = 4; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<2, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 32>; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<2, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<4, 32>; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 2; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 1; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 2>; + using 
GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<1, 1, 128>; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 2>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<1, 128>; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #elif 1 // cdata = 64, BlockSize = 256, 128x128x8 constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<2, 1, 128>; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmM = 1; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 1; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #elif 1 // cdata = 64, BlockSize = 256, 128x128x16 constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 16; + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 16; - constexpr index_t GemmMPerThread = 4; - constexpr index_t GemmNPerThread = 4; - constexpr index_t GemmKPerThread = 1; + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; - constexpr index_t GemmMLevel0Cluster = 2; - constexpr index_t GemmNLevel0Cluster = 2; - constexpr index_t GemmMLevel1Cluster = 8; - constexpr index_t GemmNLevel1Cluster = 8; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; - using GemmABlockTransferThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>; - using GemmABlockTransferThreadClusterLengths_GemmK_GemmM = 
Sequence<4, 64>; + using GemmABlockTransferThreadSliceLengths_K_M0_M1 = Sequence<4, 1, 2>; + using GemmABlockTransferThreadClusterLengths_K_M0_M1 = Sequence<4, 1, 64>; - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferSrcScalarPerVector_K = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_M1 = 2; - using GemmBBlockTransferThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>; - using GemmBBlockTransferThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>; + using GemmBBlockTransferThreadSliceLengths_K_N0_N1 = Sequence<8, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K_N0_N1 = Sequence<2, 1, 128>; - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferSrcScalarPerVector_K = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_N1 = 1; - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector_M11 = 4; #endif - constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; - constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster; - - const auto descs = #if 1 - transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad + const auto descs = + transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + +#if 0 + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; #else - transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1 + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}; #endif - (wei_k_y_x_c0_desc, - in_n_hi_wi_c0_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); + +#else + const auto descs = + transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1(wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto in_gemmk_gemmn0_gemmn1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), + make_tuple( + Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto out_gemmm0_gemmm1_gemmn0_gemmn1_global_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 
0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0>{})); + + constexpr auto wei_gemmk_gemmm_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk_gemmn_global_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; +#endif + + const auto wei_gemmk_gemmm_grid_desc = descs[I0]; + const auto in_gemmk_gemmn_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; for(index_t i = 0; i < 5; ++i) { - float ave_time = launch_kernel_dynamic_gemm_v1< + float ave_time = driver_dynamic_gemm_v1r2< BlockSize, - typename vector_type::type, + TInWei, TAcc, TOut, InMemoryDataOperation::Set, - decltype(descs[I0]), - decltype(descs[I1]), - decltype(descs[I2]), - decltype(descs[I3]), - GemmMPerBlock, - GemmNPerBlock, + decltype(wei_gemmk_gemmm_grid_desc), + decltype(in_gemmk_gemmn_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, GemmKPerBlock, - GemmMPerThread, - GemmNPerThread, + GemmM1PerThreadM111, + GemmN1PerThreadN111, GemmKPerThread, - GemmMLevel0Cluster, - GemmNLevel0Cluster, - GemmMLevel1Cluster, - GemmNLevel1Cluster, - GemmABlockTransferThreadSliceLengths_GemmK_GemmM, - GemmABlockTransferThreadClusterLengths_GemmK_GemmM, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_GemmM, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_K_M0_M1, + GemmABlockTransferThreadClusterLengths_K_M0_M1, + Sequence<1, 2, 0>, // ABlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0>, // ABlockTransferSrcAccessOrder + 0, // ABlockTransferSrcVectorDim + GemmABlockTransferSrcScalarPerVector_K, + 
GemmABlockTransferDstScalarPerVector_M1, false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK_GemmN, - GemmBBlockTransferThreadClusterLengths_GemmK_GemmN, - Sequence<1, 0>, - Sequence<1, 0>, - 0, - GemmBBlockTransferSrcScalarPerVector_GemmK, - GemmBBlockTransferDstScalarPerVector_GemmN, - false, // don't move back src coordinate after threadwise copy, which will be fused with - // MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 1, - GemmCThreadTransferDstScalarPerVector_GemmM1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast::type*>( - wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast::type*>( - in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); + GemmBBlockTransferThreadSliceLengths_K_N0_N1, + GemmBBlockTransferThreadClusterLengths_K_N0_N1, + Sequence<1, 2, 0>, // BBlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0>, // BBlockTransferSrcAccessOrder + 0, // BBlockTransferSrcVectorDim + GemmBBlockTransferSrcScalarPerVector_K, + GemmBBlockTransferDstScalarPerVector_N1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 2, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_M11, + decltype(wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_iterator_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), + decltype(wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks), + decltype(in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + 
static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + wei_gemmk_gemmm_grid_desc, + in_gemmk_gemmn_grid_desc, + out_gemmm_gemmn_grid_desc, + wei_gemmk_gemmm0_gemmn1_grid_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_iterator_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, + wei_gemmk_gemmm0_gemmm1_grid_move_slice_window_iterator_hacks, + in_gemmk_gemmn0_gemmn1_grid_move_slice_window_iterator_hacks, + nrepeat); - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } } // copy result back to host out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); - - auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) { - out_n_k_ho_wo(n, k, ho, wo) = out_n_ho_wo_k(n, ho, wo, k); - }; - - make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(); } diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..d00314c8d9 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp @@ -0,0 +1,240 @@ 
+#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r5_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_contraction_v1r1.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 1 + // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 4, 32] = [1, 128, 4, 32] + constexpr index_t BlockSize = 256; + + constexpr index_t N0 = 4; + + constexpr index_t GemmGM1PerBlockGM11 = 128; + constexpr index_t GemmGN1PerBlockGN11 = 32; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr 
index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + + using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1; + + using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 1, 1, 32>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; +#elif 1 + // cdata = 64, BlockSize = 256, [8, 1, 128] * [8, 8, 16] = [1, 128, 8, 16] + constexpr index_t BlockSize = 256; + + constexpr index_t N0 = 8; + + constexpr index_t GemmGM1PerBlockGM11 = 128; + constexpr index_t GemmGN1PerBlockGN11 = 16; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + + using GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11 = Sequence<2, 1, 1, 128>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GK = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GM11 = 1; + + using GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11 = Sequence<1, 4, 1, 1>; + using 
GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11 = Sequence<8, 2, 1, 16>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GN11 = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GN11 = 1; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; +#endif + + const auto descs = transform_forward_convolution_into_contraction_v4r5_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads); + + const auto wei_gk_gm0_gm1_grid_desc = descs[I0]; + const auto in_gk_gn0_gn1_grid_desc = descs[I1]; + const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_gk_gm0_gm10_gm11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0>{})); + + constexpr auto in_gk_gn0_gn10_gn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + + constexpr auto out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 
0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); + + constexpr auto wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0>{}; + + constexpr auto in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_contraction_v1r1< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(wei_gk_gm0_gm1_grid_desc), + decltype(in_gk_gn0_gn1_grid_desc), + decltype(out_gm0_gm1_gn0_gn1_grid_desc), + GemmGM1PerBlockGM11, + GemmGN1PerBlockGN11, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_GK_GM0_GM10_GM11, + GemmABlockTransferThreadClusterLengths_GK_GM0_GM10_GM11, + Sequence<3, 2, 1, 0>, // ABlockTransferThreadClusterArrangeOrder + Sequence<3, 2, 1, 0>, // ABlockTransferSrcAccessOrder + 0, // ABlockTransferSrcVectorDim + GemmABlockTransferSrcScalarPerVector_GK, + GemmABlockTransferDstScalarPerVector_GM11, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GK_GN0_GN10_GN11, + 
GemmBBlockTransferThreadClusterLengths_GK_GN0_GN10_GN11, + Sequence<0, 3, 2, 1>, // BBlockTransferThreadClusterArrangeOrder + Sequence<0, 3, 2, 1>, // BBlockTransferSrcAccessOrder + 3, // BBlockTransferSrcVectorDim + GemmBBlockTransferSrcScalarPerVector_GN11, + GemmBBlockTransferDstScalarPerVector_GN11, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_BN1, + decltype(wei_gk_gm0_gm10_gm11_grid_iterator_hacks), + decltype(in_gk_gn0_gn10_gn11_grid_iterator_hacks), + decltype(out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks), + decltype(wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks), + decltype(in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gk_gm0_gm1_grid_desc, + in_gk_gn0_gn1_grid_desc, + out_gm0_gm1_gn0_gn1_grid_desc, + wei_gk_gm0_gm10_gm11_grid_iterator_hacks, + in_gk_gn0_gn10_gn11_grid_iterator_hacks, + out_gm10_bm0_bm1_gn10_bn0_bn1_grid_iterator_hacks, + wei_gk_gm0_gm10_gm11_grid_move_slice_window_iterator_hacks, + in_gk_gn0_gn10_gn11_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp index 08573e418c..cb2e7e5264 100644 --- 
a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp @@ -4,97 +4,64 @@ #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp" -template + typename TAcc, + typename TOut, + typename InLengths, + typename WeiLengths, + typename OutLengths, + typename ConvStrides, + typename ConvDilations, + typename InLeftPads, + typename InRightPads> void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( - InDesc, + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, const Tensor& in_n_c_hi_wi, - WeiDesc, const Tensor& wei_k_c_y_x, - OutDesc, Tensor& out_n_k_ho_wo, - ConvStrides, - ConvDilations, - InLeftPads, - InRightPads, ck::index_t nrepeat) { using namespace ck; - std::cout << "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw" - << std::endl; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + std::cout << __func__ << std::endl; constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto N = OutDesc::GetLengths()[I0]; - constexpr auto K = OutDesc::GetLengths()[I1]; - constexpr auto C = WeiDesc::GetLengths()[I1]; + const auto N = out_n_k_ho_wo_lengths[I0]; + const auto K = out_n_k_ho_wo_lengths[I1]; + const auto C = wei_k_c_y_x_lengths[I1]; - constexpr auto Hi = 
InDesc::GetLengths()[I2]; - constexpr auto Wi = InDesc::GetLengths()[I3]; + const auto Hi = in_n_c_hi_wi_lengths[I2]; + const auto Wi = in_n_c_hi_wi_lengths[I3]; - constexpr auto Ho = OutDesc::GetLengths()[I2]; - constexpr auto Wo = OutDesc::GetLengths()[I3]; + const auto Ho = out_n_k_ho_wo_lengths[I2]; + const auto Wo = out_n_k_ho_wo_lengths[I3]; - constexpr auto Y = WeiDesc::GetLengths()[I2]; - constexpr auto X = WeiDesc::GetLengths()[I3]; + const auto Y = wei_k_c_y_x_lengths[I2]; + const auto X = wei_k_c_y_x_lengths[I3]; - constexpr auto C0 = C / Number{}; - constexpr auto C1 = Number{}; + const auto C0 = C / Number{}; + const auto C1 = Number{}; - constexpr auto K0 = K / Number{}; - constexpr auto K1 = Number{}; + const auto K0 = K / Number{}; + const auto K1 = Number{}; -#if 0 - // run-time variables - const auto in_n_c0_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, C0, Hi, Wi)); - const auto wei_k_c0_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(K, C0, Y, X)); - const auto out_n_k0_ho_wo_k1_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, K0, Ho, Wo, K1)); - - const auto conv_strides = to_multi_index(ConvStrides{}); - const auto conv_dilations = to_multi_index(ConvDilations{}); - const auto in_left_pads = to_multi_index(InLeftPads{}); - const auto in_right_pads = to_multi_index(InRightPads{}); -#else - // compile-time variables - const auto in_n_c0_hi_wi_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); - const auto wei_k_c0_y_x_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); - const auto out_n_k0_ho_wo_k1_desc = - make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); - - const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{}); - const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{}); - const auto in_left_pads = 
sequence_to_tuple_of_number(InLeftPads{}); - const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{}); -#endif - - Tensor in_n_c0_hi_wi_c1(make_HostTensorDescriptor( - make_native_tensor_descriptor_packed(Sequence{}))); - Tensor wei_k_c0_y_x_c1(make_HostTensorDescriptor( - make_native_tensor_descriptor_packed(Sequence{}))); - Tensor out_n_k0_ho_wo_k1(make_HostTensorDescriptor( - make_native_tensor_descriptor_packed(Sequence{}))); + Tensor in_n_c0_hi_wi_c1( + HostTensorDescriptor(std::initializer_list{N, C0, Hi, Wi, C1})); + Tensor wei_k_c0_y_x_c1( + HostTensorDescriptor(std::initializer_list{K, C0, Y, X, C1})); + Tensor out_n_k0_ho_wo_k1( + HostTensorDescriptor(std::initializer_list{N, K0, Ho, Wo, K1})); auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) { in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) = @@ -109,17 +76,30 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( make_ParallelTensorFunctor(f_nchw2nc0hwc1, N, Hi, Wi, C)(); make_ParallelTensorFunctor(f_kcyx2kc0yxc1, K, Y, X, C)(); - in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + DeviceMem in_n_c0_hi_wi_c1_device_buf(sizeof(TInWei) * + in_n_c0_hi_wi_c1.mDesc.GetElementSpace()); + DeviceMem wei_k_c0_y_x_c1_device_buf(sizeof(TInWei) * wei_k_c0_y_x_c1.mDesc.GetElementSpace()); + DeviceMem out_n_k0_ho_wo_k1_device_buf(sizeof(TOut) * + out_n_k0_ho_wo_k1.mDesc.GetElementSpace()); + + in_n_c0_hi_wi_c1_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data()); + wei_k_c0_y_x_c1_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data()); + + const auto in_n_c0_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi)); + const auto wei_k_c0_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X)); + const auto out_n_k0_ho_wo_k1_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)); #if 1 // 
cdata = 64, BlockSize = 64, 16x8x32x4 constexpr index_t BlockSize = 64; - constexpr index_t KPerBlock = K; + constexpr index_t KPerBlock = 16; constexpr index_t HoPerBlock = 8; constexpr index_t WoPerBlock = 32; - constexpr index_t EPerBlock = C0; + constexpr index_t EPerBlock = 1; constexpr index_t KPerThread = KPerBlock; constexpr index_t HoPerThread = 2; @@ -134,7 +114,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( constexpr index_t BThreadTransferSrcScalarPerVector_W = 1; - constexpr index_t CThreadTransferDstScalarPerVector_W = K1; + constexpr index_t CThreadTransferDstScalarPerVector_W = 16; static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, ""); #else @@ -165,17 +145,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( constexpr auto conv_driver = #if 0 - DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad< + DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad #else - DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad< + DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad #endif - BlockSize, - typename vector_type::type, TAcc, TOut, KPerBlock, - HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread, - EPerThread, ABlockTransferThreadSliceLengths_E_K, - ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E, - ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W, - CThreadTransferDstScalarPerVector_W > {}; + ::type, + TAcc, + TOut, + KPerBlock, + HoPerBlock, + WoPerBlock, + EPerBlock, + KPerThread, + HoPerThread, + WoPerThread, + EPerThread, + ABlockTransferThreadSliceLengths_E_K, + ABlockTransferThreadClusterLengths_E_K, + ABlockTransferSrcScalarPerVector_E, + ABlockTransferDstScalarPerVector_K, + BThreadTransferSrcScalarPerVector_W, + CThreadTransferDstScalarPerVector_W>{}; conv_driver.Run(wei_k_c0_y_x_desc, in_n_c0_hi_wi_desc, @@ -185,12 +176,12 @@ void 
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( in_left_pads, in_right_pads, static_cast::type*>( - wei_k_c_y_x_device_buf.GetDeviceBuffer()), + wei_k_c0_y_x_c1_device_buf.GetDeviceBuffer()), static_cast::type*>( - in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer())); + in_n_c0_hi_wi_c1_device_buf.GetDeviceBuffer()), + static_cast(out_n_k0_ho_wo_k1_device_buf.GetDeviceBuffer())); - out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); + out_n_k0_ho_wo_k1_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data()); auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) { out_n_k_ho_wo(n, k, ho, wo) = diff --git a/driver/include/host_conv.hpp b/driver/include/host_conv.hpp index d7cf3bbb00..6d7d758df6 100644 --- a/driver/include/host_conv.hpp +++ b/driver/include/host_conv.hpp @@ -6,58 +6,94 @@ template -void host_direct_convolution(const Tensor& in_nchw, - const Tensor& wei_kcyx, - Tensor& out_nkhw, - ConvStrides, - ConvDilations, - LowerPads, - UpperPads) + class InLeftPads, + class InRightPads> +void host_direct_convolution(const Tensor& in, + const Tensor& wei, + Tensor& out, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) { using namespace ck; - index_t h_pad_low = LowerPads{}.Get(Number<0>{}); - index_t w_pad_low = LowerPads{}.Get(Number<1>{}); + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; - auto f = [&](auto n, auto k, auto ho, auto wo) { + auto f_nchw = [&](auto n, auto k, auto ho, auto wo) { double v = 0; - for(int c = 0; c < wei_kcyx.mDesc.GetLengths()[1]; ++c) + for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y) + for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) { 
- int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low; - for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x) + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) { - int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low; - if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 && - wi < in_nchw.mDesc.GetLengths()[3]) + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) { - v += static_cast(in_nchw(n, c, hi, wi)) * - static_cast(wei_kcyx(k, c, y, x)); + v += static_cast(in(n, c, hi, wi)) * + static_cast(wei(k, c, y, x)); } } } } - out_nkhw(n, k, ho, wo) = v; + out(n, k, ho, wo) = v; }; - auto f_par = make_ParallelTensorFunctor(f, - out_nkhw.mDesc.GetLengths()[0], - out_nkhw.mDesc.GetLengths()[1], - out_nkhw.mDesc.GetLengths()[2], - out_nkhw.mDesc.GetLengths()[3]); + auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) { + double v = 0; + for(int c = 0; c < wei.mDesc.GetLengths()[3]; ++c) + { + for(int y = 0; y < wei.mDesc.GetLengths()[1]; ++y) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int x = 0; x < wei.mDesc.GetLengths()[2]; ++x) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(wei(k, y, x, c)); + } + } + } + } + out(n, ho, wo, k) = v; + }; - f_par(std::thread::hardware_concurrency()); + switch(layout) + { + case ConvTensorLayout::NCHW: + make_ParallelTensorFunctor(f_nchw, + out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + case ConvTensorLayout::NHWC: + make_ParallelTensorFunctor(f_nhwc, + 
out.mDesc.GetLengths()[0], + out.mDesc.GetLengths()[1], + out.mDesc.GetLengths()[2], + out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + break; + default: throw std::runtime_error("wrong! not supported layout"); + } } -template +template void host_winograd_3x3_convolution(const Tensor& in_nchw, const Tensor& wei_kcyx, Tensor& out_nkhw, - LowerPads, - UpperPads) + InLeftPads, + InRightPads) { using namespace ck; @@ -76,8 +112,8 @@ void host_winograd_3x3_convolution(const Tensor& in_nchw, std::size_t HO = out_nkhw.mDesc.GetLengths()[2]; std::size_t WO = out_nkhw.mDesc.GetLengths()[3]; - index_t h_pad_low = LowerPads{}.Get(Number<0>{}); - index_t w_pad_low = LowerPads{}.Get(Number<1>{}); + index_t h_pad_low = InLeftPads{}.Get(Number<0>{}); + index_t w_pad_low = InLeftPads{}.Get(Number<1>{}); std::size_t HiPerTile = HoPerTile + Y - 1; std::size_t WiPerTile = WoPerTile + X - 1; diff --git a/driver/include/host_tensor.hpp b/driver/include/host_tensor.hpp index ac6df6f931..64d0ee26d3 100644 --- a/driver/include/host_tensor.hpp +++ b/driver/include/host_tensor.hpp @@ -271,19 +271,20 @@ struct Tensor std::vector mData; }; -void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout) +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) { - os << "dim " << desc.GetNumOfDimension() << ", "; - - os << "lengths {"; - LogRange(os, desc.GetLengths(), ", "); - os << "}, "; - - os << "strides {"; - LogRange(os, desc.GetStrides(), ", "); - os << "}" << std::endl; + this->CalculateStrides(); } +template +HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) + : mLens(lens), mStrides(strides) +{ +} + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os = std::cout); + template void check_error(const Tensor& ref, const Tensor& result) { diff --git a/driver/include/host_tensor_generator.hpp b/driver/include/host_tensor_generator.hpp index 
84ff1bfff2..d49d2d9122 100644 --- a/driver/include/host_tensor_generator.hpp +++ b/driver/include/host_tensor_generator.hpp @@ -44,7 +44,7 @@ struct GeneratorTensor_Checkboard template double operator()(Ts... Xs) const { - std::array dims = {{Xs...}}; + std::array dims = {{static_cast(Xs)...}}; return std::accumulate(dims.begin(), dims.end(), true, diff --git a/driver/src/conv_driver.cpp b/driver/src/conv_driver.cpp index 00f83d2bca..4b32c786b8 100644 --- a/driver/src/conv_driver.cpp +++ b/driver/src/conv_driver.cpp @@ -14,19 +14,41 @@ #include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" -#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" int main(int argc, char* argv[]) { using namespace ck; + if(argc != 5) + { + printf("arg1: do_verification, arg2: init_method, arg3: do_log, arg4: nrepeat\n"); + exit(1); + } + + const bool do_verification = atoi(argv[1]); + const int init_method = atoi(argv[2]); + const bool do_log = atoi(argv[3]); + const int nrepeat = atoi(argv[4]); + #if 0 + constexpr index_t N = 8; + constexpr index_t C = 8; + constexpr index_t Hi = 4; + constexpr index_t Wi = 8; + constexpr index_t K = 256; + constexpr index_t Y = 3; + constexpr index_t X = 3; + + using ConvStrides = Sequence<1, 1>; + using ConvDilations = Sequence<1, 1>; + + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; +#elif 0 constexpr index_t N = 1; constexpr index_t C = 16; - constexpr index_t HI = 1080; - constexpr index_t WI = 1920; + constexpr index_t Hi = 540; + constexpr index_t Wi = 960; constexpr index_t K = 16; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -34,13 +56,13 @@ 
int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; - constexpr index_t HI = 540; - constexpr index_t WI = 960; + constexpr index_t Hi = 270; + constexpr index_t Wi = 480; constexpr index_t K = 16; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -48,27 +70,13 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; - constexpr index_t HI = 270; - constexpr index_t WI = 480; - constexpr index_t K = 16; - constexpr index_t Y = 1; - constexpr index_t X = 1; - - using ConvStrides = Sequence<1, 1>; - using ConvDilations = Sequence<1, 1>; - - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 0 - constexpr index_t N = 1; - constexpr index_t C = 16; - constexpr index_t HI = 1080; - constexpr index_t WI = 1920; + constexpr index_t Hi = 1080; + constexpr index_t Wi = 1920; constexpr index_t K = 16; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -76,13 +84,13 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 1; - constexpr index_t HI = 1024; - constexpr index_t WI = 2048; + constexpr index_t Hi = 1024; + constexpr index_t Wi = 2048; constexpr index_t K = 4; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -90,13 +98,13 @@ int main(int argc, char* argv[]) using ConvStrides 
= Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; - constexpr index_t HI = 540; - constexpr index_t WI = 960; + constexpr index_t Hi = 540; + constexpr index_t Wi = 960; constexpr index_t K = 16; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -104,13 +112,13 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 constexpr index_t N = 1; constexpr index_t C = 16; - constexpr index_t HI = 270; - constexpr index_t WI = 480; + constexpr index_t Hi = 270; + constexpr index_t Wi = 480; constexpr index_t K = 16; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -118,14 +126,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 // 3x3, 36x36, stride 2 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t HI = 37; - constexpr index_t WI = 37; + constexpr index_t Hi = 37; + constexpr index_t Wi = 37; constexpr index_t K = 384; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -133,14 +141,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 35x35, stride 2 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t HI = 35; - constexpr index_t WI = 35; + constexpr index_t Hi = 35; + 
constexpr index_t Wi = 35; constexpr index_t K = 384; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -148,14 +156,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; -#elif 1 + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; +#elif 0 // 3x3, 71x71 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t HI = 71; - constexpr index_t WI = 71; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -163,14 +171,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; -#elif 1 + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; +#elif 0 // 1x1, 8x8 constexpr index_t N = 128; constexpr index_t C = 1536; - constexpr index_t HI = 8; - constexpr index_t WI = 8; + constexpr index_t Hi = 8; + constexpr index_t Wi = 8; constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -178,14 +186,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 73x73 constexpr index_t N = 128; constexpr index_t C = 160; - constexpr index_t HI = 73; - constexpr index_t WI = 73; + constexpr index_t Hi = 73; + constexpr index_t Wi = 73; constexpr index_t K = 64; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -193,14 +201,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using 
InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 35x35 constexpr index_t N = 128; constexpr index_t C = 96; - constexpr index_t HI = 35; - constexpr index_t WI = 35; + constexpr index_t Hi = 35; + constexpr index_t Wi = 35; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -208,14 +216,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; -#elif 1 + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; +#elif 0 // 3x3, 71x71 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t HI = 71; - constexpr index_t WI = 71; + constexpr index_t Hi = 71; + constexpr index_t Wi = 71; constexpr index_t K = 192; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -223,14 +231,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 // 7x1, 17x17 constexpr index_t N = 128; constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; + constexpr index_t Hi = 17; + constexpr index_t Wi = 17; constexpr index_t K = 128; constexpr index_t Y = 7; constexpr index_t X = 1; @@ -238,14 +246,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; -#elif 0 + using InLeftPads = Sequence<3, 0>; + using InRightPads = Sequence<3, 0>; +#elif 1 // 1x7, 17x17 constexpr index_t N = 128; constexpr index_t C = 128; - constexpr index_t HI = 17; - constexpr index_t WI = 17; + constexpr index_t Hi = 17; + constexpr index_t Wi = 17; constexpr index_t K = 128; constexpr index_t Y = 1; constexpr index_t X = 7; @@ -253,14 +261,14 @@ int main(int argc, char* 
argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 3>; - using RightPads = Sequence<0, 3>; + using InLeftPads = Sequence<0, 3>; + using InRightPads = Sequence<0, 3>; #elif 0 // 3x3, 299x299 stride=2 constexpr index_t N = 128; constexpr index_t C = 3; - constexpr index_t HI = 299; - constexpr index_t WI = 299; + constexpr index_t Hi = 299; + constexpr index_t Wi = 299; constexpr index_t K = 32; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -268,14 +276,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 147x147 constexpr index_t N = 128; constexpr index_t C = 128; - constexpr index_t HI = 147; - constexpr index_t WI = 147; + constexpr index_t Hi = 147; + constexpr index_t Wi = 147; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -283,14 +291,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 // 3x3, 149x149 constexpr index_t N = 128; constexpr index_t C = 32; - constexpr index_t HI = 149; - constexpr index_t WI = 149; + constexpr index_t Hi = 149; + constexpr index_t Wi = 149; constexpr index_t K = 32; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -298,14 +306,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 17x17, stride 2 constexpr index_t N = 128; constexpr index_t C = 192; - constexpr index_t HI = 17; - 
constexpr index_t WI = 17; + constexpr index_t Hi = 17; + constexpr index_t Wi = 17; constexpr index_t K = 192; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -313,14 +321,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 35x35 constexpr index_t N = 128; constexpr index_t C = 384; - constexpr index_t HI = 35; - constexpr index_t WI = 35; + constexpr index_t Hi = 35; + constexpr index_t Wi = 35; constexpr index_t K = 96; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -328,14 +336,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 35x35, stride 2 constexpr index_t N = 128; constexpr index_t C = 288; - constexpr index_t HI = 35; - constexpr index_t WI = 35; + constexpr index_t Hi = 35; + constexpr index_t Wi = 35; constexpr index_t K = 384; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -343,14 +351,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x3, 8x8 constexpr index_t N = 128; constexpr index_t C = 384; - constexpr index_t HI = 8; - constexpr index_t WI = 8; + constexpr index_t Hi = 8; + constexpr index_t Wi = 8; constexpr index_t K = 448; constexpr index_t Y = 1; constexpr index_t X = 3; @@ -358,14 +366,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 1>; - using RightPads = Sequence<0, 1>; + 
using InLeftPads = Sequence<0, 1>; + using InRightPads = Sequence<0, 1>; #elif 0 // 3x1, 8x8 constexpr index_t N = 128; constexpr index_t C = 448; - constexpr index_t HI = 8; - constexpr index_t WI = 8; + constexpr index_t Hi = 8; + constexpr index_t Wi = 8; constexpr index_t K = 512; constexpr index_t Y = 3; constexpr index_t X = 1; @@ -373,14 +381,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 0>; - using RightPads = Sequence<1, 0>; + using InLeftPads = Sequence<1, 0>; + using InRightPads = Sequence<1, 0>; #elif 0 // 3x3, 147x147 constexpr index_t N = 128; constexpr index_t C = 64; - constexpr index_t HI = 147; - constexpr index_t WI = 147; + constexpr index_t Hi = 147; + constexpr index_t Wi = 147; constexpr index_t K = 96; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -388,14 +396,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 7x1, 73x73 constexpr index_t N = 128; constexpr index_t C = 64; - constexpr index_t HI = 73; - constexpr index_t WI = 73; + constexpr index_t Hi = 73; + constexpr index_t Wi = 73; constexpr index_t K = 64; constexpr index_t Y = 7; constexpr index_t X = 1; @@ -403,14 +411,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<3, 0>; - using RightPads = Sequence<3, 0>; + using InLeftPads = Sequence<3, 0>; + using InRightPads = Sequence<3, 0>; #elif 0 // 3x3, 73x73 constexpr index_t N = 128; constexpr index_t C = 64; - constexpr index_t HI = 73; - constexpr index_t WI = 73; + constexpr index_t Hi = 73; + constexpr index_t Wi = 73; constexpr index_t K = 96; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -418,14 +426,14 @@ int 
main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 14x14, stride 2 constexpr index_t N = 128; constexpr index_t C = 1024; - constexpr index_t HI = 14; - constexpr index_t WI = 14; + constexpr index_t Hi = 14; + constexpr index_t Wi = 14; constexpr index_t K = 2048; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -433,14 +441,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 14x14 constexpr index_t N = 128; constexpr index_t C = 1024; - constexpr index_t HI = 14; - constexpr index_t WI = 14; + constexpr index_t Hi = 14; + constexpr index_t Wi = 14; constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -448,14 +456,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 14x14, stride 2 constexpr index_t N = 128; constexpr index_t C = 1024; - constexpr index_t HI = 14; - constexpr index_t WI = 14; + constexpr index_t Hi = 14; + constexpr index_t Wi = 14; constexpr index_t K = 512; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -463,14 +471,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 1 // 3x3, 28x28 constexpr index_t N = 128; constexpr index_t C = 128; - constexpr 
index_t HI = 28; - constexpr index_t WI = 28; + constexpr index_t Hi = 28; + constexpr index_t Wi = 28; constexpr index_t K = 128; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -478,14 +486,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 1 // 3x3, 14x14 constexpr index_t N = 128; constexpr index_t C = 256; - constexpr index_t HI = 14; - constexpr index_t WI = 14; + constexpr index_t Hi = 14; + constexpr index_t Wi = 14; constexpr index_t K = 256; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -493,14 +501,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 // 1x1, 56x56, stride 2 constexpr index_t N = 128; constexpr index_t C = 256; - constexpr index_t HI = 56; - constexpr index_t WI = 56; + constexpr index_t Hi = 56; + constexpr index_t Wi = 56; constexpr index_t K = 128; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -508,14 +516,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 7x7, 230x230 stride=2 constexpr index_t N = 128; constexpr index_t C = 3; - constexpr index_t HI = 230; - constexpr index_t WI = 230; + constexpr index_t Hi = 230; + constexpr index_t Wi = 230; constexpr index_t K = 64; constexpr index_t Y = 7; constexpr index_t X = 7; @@ -523,14 +531,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; 
- using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 28x28, stride = 2 constexpr index_t N = 128; constexpr index_t C = 512; - constexpr index_t HI = 28; - constexpr index_t WI = 28; + constexpr index_t Hi = 28; + constexpr index_t Wi = 28; constexpr index_t K = 1024; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -538,14 +546,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 1x1, 28x28, stride 2 constexpr index_t N = 128; constexpr index_t C = 512; - constexpr index_t HI = 28; - constexpr index_t WI = 28; + constexpr index_t Hi = 28; + constexpr index_t Wi = 28; constexpr index_t K = 256; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -553,14 +561,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<2, 2>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 1 // 1x1, 7x7 constexpr index_t N = 128; constexpr index_t C = 512; - constexpr index_t HI = 7; - constexpr index_t WI = 7; + constexpr index_t Hi = 7; + constexpr index_t Wi = 7; constexpr index_t K = 2048; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -568,14 +576,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 7x7 constexpr index_t N = 128; constexpr index_t C = 512; - constexpr index_t HI = 7; - constexpr index_t WI = 7; + constexpr index_t Hi = 7; + constexpr index_t Wi = 7; constexpr index_t K = 512; constexpr index_t Y = 3; 
constexpr index_t X = 3; @@ -583,14 +591,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #elif 0 // 1x1, 56x56 constexpr index_t N = 128; constexpr index_t C = 64; - constexpr index_t HI = 56; - constexpr index_t WI = 56; + constexpr index_t Hi = 56; + constexpr index_t Wi = 56; constexpr index_t K = 64; constexpr index_t Y = 1; constexpr index_t X = 1; @@ -598,14 +606,14 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<0, 0>; - using RightPads = Sequence<0, 0>; + using InLeftPads = Sequence<0, 0>; + using InRightPads = Sequence<0, 0>; #elif 0 // 3x3, 56x56 constexpr index_t N = 128; constexpr index_t C = 64; - constexpr index_t HI = 56; - constexpr index_t WI = 56; + constexpr index_t Hi = 56; + constexpr index_t Wi = 56; constexpr index_t K = 64; constexpr index_t Y = 3; constexpr index_t X = 3; @@ -613,82 +621,86 @@ int main(int argc, char* argv[]) using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<1, 1>; + using InLeftPads = Sequence<1, 1>; + using InRightPads = Sequence<1, 1>; #endif - auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); - auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); - auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor( - in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{}); + constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1; + constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1; - ostream_tensor_descriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); - ostream_tensor_descriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: "); - 
ostream_tensor_descriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - print_array("LeftPads", to_multi_index(LeftPads{})); - print_array("RightPads", to_multi_index(RightPads{})); - print_array("ConvStrides", to_multi_index(ConvStrides{})); - print_array("ConvDilations", to_multi_index(ConvDilations{})); + constexpr index_t Ho = (Hi + InLeftPads{}[0] + InRightPads{}[0] - YEff) / ConvStrides{}[0] + 1; + constexpr index_t Wo = (Wi + InLeftPads{}[1] + InRightPads{}[1] - XEff) / ConvStrides{}[1] + 1; #if 1 - using in_data_t = float; constexpr index_t in_vector_size = 1; + using in_data_t = typename vector_type::type; using acc_data_t = float; using out_data_t = float; #elif 0 - using in_data_t = float; constexpr index_t in_vector_size = 1; + using in_data_t = typename vector_type::type; using acc_data_t = float; using out_data_t = int8_t; #elif 1 - using in_data_t = int8_t; constexpr index_t in_vector_size = 16; + using in_data_t = typename vector_type::type; using acc_data_t = int32_t; using out_data_t = int8_t; #endif - Tensor in_nchw(make_HostTensorDescriptor(in_nchw_desc)); - Tensor wei_kcyx(make_HostTensorDescriptor(wei_kcyx_desc)); - Tensor out_nkhw_host(make_HostTensorDescriptor(out_nkhw_desc)); - Tensor out_nkhw_device(make_HostTensorDescriptor(out_nkhw_desc)); + Tensor in_nchw(HostTensorDescriptor(std::initializer_list{N, C, Hi, Wi})); + Tensor wei_kcyx(HostTensorDescriptor(std::initializer_list{K, C, Y, X})); + Tensor out_nkhw_host( + HostTensorDescriptor(std::initializer_list{N, K, Ho, Wo})); + Tensor out_nkhw_device( + HostTensorDescriptor(std::initializer_list{N, K, Ho, Wo})); + + ostream_HostTensorDescriptor(in_nchw.mDesc, std::cout << "in_nchw_desc: "); + ostream_HostTensorDescriptor(wei_kcyx.mDesc, std::cout << "wei_kcyx_desc: "); + ostream_HostTensorDescriptor(out_nkhw_host.mDesc, std::cout << "out_nkhw_desc: "); + + print_array("InLeftPads", InLeftPads{}); + print_array("InRightPads", InRightPads{}); + print_array("ConvStrides", ConvStrides{}); 
+ print_array("ConvDilations", ConvDilations{}); std::size_t num_thread = std::thread::hardware_concurrency(); - if(argc != 4) - { - printf("arg1: do_verification, arg2: do_log, arg3: nrepeat\n"); - exit(1); - } - - bool do_verification = atoi(argv[1]); - bool do_log = atoi(argv[2]); - index_t nrepeat = atoi(argv[3]); - if(do_verification) { -#if 0 - in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); -#elif 0 - in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); -#elif 0 - in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); -#elif 1 - in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); -#elif 0 - in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + switch(init_method) + { + case 0: + in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 1: + in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 2: + in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 3: + in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in_nchw.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); - auto gen_wei = [](auto... is) { - return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...); - }; - wei_kcyx.GenerateTensorValue(gen_wei, num_thread); -#endif + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + wei_kcyx.GenerateTensorValue(gen_wei, num_thread); + } } -#if 0 + constexpr auto in_nchw_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence{}); + constexpr auto out_nkhw_desc = make_native_tensor_descriptor_packed(Sequence{}); + +#if 1 device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc, @@ -697,8 +709,8 @@ int main(int argc, char* argv[]) out_nkhw_device, ConvStrides{}, ConvDilations{}, - LeftPads{}, - RightPads{}, + InLeftPads{}, + InRightPads{}, nrepeat); #elif 0 device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc, @@ -709,8 +721,8 @@ int main(int argc, char* argv[]) out_nkhw_device, ConvStrides{}, ConvDilations{}, - LeftPads{}, - RightPads{}, + InLeftPads{}, + InRightPads{}, nrepeat); #elif 0 device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(in_nchw_desc, @@ -721,58 +733,9 @@ int main(int argc, char* argv[]) out_nkhw_device, ConvStrides{}, ConvDilations{}, - LeftPads{}, - RightPads{}, + InLeftPads{}, + InRightPads{}, nrepeat); -#elif 0 - device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( - in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); -#elif 1 - device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk - - (in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); -#elif 1 - device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw( - in_nchw_desc, - in_nchw, - wei_kcyx_desc, - wei_kcyx, - out_nkhw_desc, - out_nkhw_device, - ConvStrides{}, - ConvDilations{}, - LeftPads{}, - RightPads{}, - nrepeat); #endif if(do_verification) @@ -782,8 +745,8 @@ int main(int argc, char* argv[]) 
out_nkhw_host, ConvStrides{}, ConvDilations{}, - LeftPads{}, - RightPads{}); + InLeftPads{}, + InRightPads{}); check_error(out_nkhw_host, out_nkhw_device); diff --git a/driver/src/conv_driver_v2.cpp b/driver/src/conv_driver_v2.cpp new file mode 100644 index 0000000000..3d57b4c15c --- /dev/null +++ b/driver/src/conv_driver_v2.cpp @@ -0,0 +1,410 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv.hpp" +#include "device_tensor.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_FWD_V4R4_NCHW 1 +#define USE_CONV_FWD_V4R4_NHWC 1 +#define USE_CONV_FWD_V4R5_NCHW 1 +#define USE_CONV_FWD_V5R1_NCHW 0 + +enum ConvForwardAlgo +{ + V4R4NCHW, + V4R4NHWC, + V4R5NCHW, + V5R1NCHW +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvForwardAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool
do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + const index_t N = atoi(argv[7]); + const index_t K = atoi(argv[8]); + const index_t C = atoi(argv[9]); + const index_t Y = atoi(argv[10]); + const index_t X = atoi(argv[11]); + const index_t Hi = atoi(argv[12]); + const index_t Wi = atoi(argv[13]); + + const index_t conv_stride_h = atoi(argv[14]); + const index_t conv_stride_w = atoi(argv[15]); + const index_t conv_dilation_h = atoi(argv[16]); + const index_t conv_dilation_w = atoi(argv[17]); + const index_t in_left_pad_h = atoi(argv[18]); + const index_t in_left_pad_w = atoi(argv[19]); + const index_t in_right_pad_h = atoi(argv[20]); + const index_t in_right_pad_w = atoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(atoi(argv[1])); + const ConvForwardAlgo algo = static_cast(atoi(argv[2])); + const bool do_verification = atoi(argv[3]); + const int init_method = atoi(argv[4]); + const bool do_log = atoi(argv[5]); + const int nrepeat = atoi(argv[6]); + + constexpr index_t N = 128; + constexpr index_t C = 128; + constexpr index_t Hi = 17; + constexpr index_t Wi = 17; + constexpr index_t K = 128; + constexpr index_t Y = 1; + constexpr index_t X = 7; + + const index_t conv_stride_h = 1; + const index_t conv_stride_w = 1; + const index_t conv_dilation_h = 1; + const index_t conv_dilation_w = 1; + const index_t in_left_pad_h = 0; + const index_t in_left_pad_w = 3; + const index_t in_right_pad_h = 0; + const index_t in_right_pad_w = 3; + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) *
conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#endif + +#if 1 + constexpr index_t in_vector_size = 1; + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + constexpr index_t in_vector_size = 16; + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + switch(layout) + { + case ConvTensorLayout::NCHW: + // NCHW + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + break; + case ConvTensorLayout::NHWC: + // NHWC + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + break; + default: throw std::runtime_error("wrong! 
not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei(wei_lengths_host); + Tensor out_host(out_lengths_host); + Tensor out_device(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out_host.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + if(do_verification) + { + switch(init_method) + { + case 0: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_wei = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + wei.GenerateTensorValue(gen_wei, num_thread); + } + } + + auto f_make_for_device_nchw = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + auto f_make_for_device_nhwc = [&]() { +#if USE_DYNAMIC_MODE + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); +#else + const auto in_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto wei_lengths_dev = 
make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto out_lengths_dev = + make_tuple(Number{}, Number{}, Number{}, Number{}); + const auto conv_strides_dev = make_tuple(Number{}, Number{}); + const auto conv_dilations_dev = + make_tuple(Number{}, Number{}); + const auto in_left_pads_dev = make_tuple(Number{}, Number{}); + const auto in_right_pads_dev = + make_tuple(Number{}, Number{}); +#endif + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + const auto nhwc_desc = f_make_for_device_nhwc(); + +#if USE_CONV_FWD_V4R4_NCHW + if(algo == ConvForwardAlgo::V4R4NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R4_NHWC + if(algo == ConvForwardAlgo::V4R4NHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V4R5_NCHW + if(algo == ConvForwardAlgo::V4R5NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + +#if USE_CONV_FWD_V5R1_NCHW + if(algo == ConvForwardAlgo::V5R1NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution(in, + wei, + out_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(out_host, out_device); + + if(do_log) + { + LogRange(std::cout << "in : ", in.mData, ",") << std::endl; + LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + } + } +} diff --git a/driver/src/host_tensor.cpp b/driver/src/host_tensor.cpp index 10f358d2f6..e840baf7f5 100644 --- a/driver/src/host_tensor.cpp +++ b/driver/src/host_tensor.cpp @@ -3,18 +3,6 @@ #include "host_tensor.hpp" -template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) -{ - this->CalculateStrides(); -} - -template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) - : mLens(lens), mStrides(strides) -{ -} - void HostTensorDescriptor::CalculateStrides() { mStrides.clear(); @@ -45,3 +33,16 @@ std::size_t HostTensorDescriptor::GetElementSpace() const const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } const std::vector& 
HostTensorDescriptor::GetStrides() const { return mStrides; } + +void ostream_HostTensorDescriptor(const HostTensorDescriptor& desc, std::ostream& os) +{ + os << "dim " << desc.GetNumOfDimension() << ", "; + + os << "lengths {"; + LogRange(os, desc.GetLengths(), ", "); + os << "}, "; + + os << "strides {"; + LogRange(os, desc.GetStrides(), ", "); + os << "}" << std::endl; +} diff --git a/script/run.sh b/script/run.sh new file mode 100755 index 0000000000..75caa16a8c --- /dev/null +++ b/script/run.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +#make -j conv_driver + make -j conv_driver_v2 + +LAYOUT=$1 +ALGO=$2 +VERIFY=$3 +INIT=$4 +LOG=$5 +REPEAT=$6 + +###################### layout algo verify init log repeat N__ K__ C__ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 384 192 3 3 35 35 2 2 1 1 0 0 0 0 +#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 1 7 17 17 1 1 1 1 0 3 0 3 +#driver/conv_driver_v2 $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1