diff --git a/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp b/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp new file mode 100644 index 0000000000..20d7c2ef89 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_contraction_v1r2.hpp @@ -0,0 +1,292 @@ +#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP +#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_contraction_v1r2.hpp" + +namespace ck { + +template +__host__ float +driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc, + const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc, + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2< + BlockSize, + FloatAB, + FloatAcc, + FloatC, + CGlobalMemoryDataOperation, + AGKGM0GM1GridDesc, + BGKGN0GN1GridDesc, + CGM0GM1GN0GN1GridDesc, + GM1PerBlockGM11, + GN1PerBlockGN11, + KPerBlock, + M1PerThread, + N1PerThread, + KPerThread, + M1N1ThreadClusterM10, + M1N1ThreadClusterN10, + M1N1ThreadClusterM11, + M1N1ThreadClusterN11, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferSrcVectorTensorContiguousDimOrder, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferSrcVectorTensorContiguousDimOrder, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks>; + + const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); + + if(!GridwiseContraction::CheckValidity( + a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc)) + { + throw std::runtime_error( + "wrong! 
GridwiseDynamicContraction_km_kn0n1_mn0n1_v1r1 has invalid setting"); + } + + const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = + GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc); + const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = + GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc); + + using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc); + using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc); + + // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc + const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = + GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc); + + using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc); + + // c_blockid_to_gm10_gn10_block_cluster_adaptor + const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = + GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc); + + using CBlockIdToGM10GN10BlockClusterAdaptor = + decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor); + + const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc); + + const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0); + + const bool has_double_tail_k_block_loop = + GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0); + + { + std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{" + << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", " + << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", " + << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", " + << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", " + << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; + + std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{" + << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", " + << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", " + << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", " + << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", " + << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl; + + std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", " + << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl; + } + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk0_gm0_gm10_gm11_gk1_grid_desc, + b_gk0_gn0_gn10_gn11_gk1_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, 
+ nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk0_gm0_gm10_gm11_gk1_grid_desc, + b_gk0_gn0_gn10_gn11_gk1_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk0_gm0_gm10_gm11_gk1_grid_desc, + b_gk0_gn0_gn10_gn11_gk1_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + else + { + const auto kernel = kernel_dynamic_contraction_v1r1< + GridwiseContraction, + FloatAB, + FloatC, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_gk0_gm0_gm10_gm11_gk1_grid_desc, + b_gk0_gn0_gn10_gn11_gk1_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor); + } + + return ave_time; +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp b/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp new file mode 100644 index 0000000000..be65fd34c8 --- /dev/null +++ b/composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp @@ -0,0 +1,416 @@ +#ifndef CK_DRIVER_DYNAMIC_GEMM_v1r3 +#define CK_DRIVER_DYNAMIC_GEMM_v1r3 + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "gridwise_dynamic_gemm_v1r3.hpp" + +namespace ck { + +template +__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + AGridIteratorHacks, + BGridIteratorHacks, + CGridIteratorHacks, + AGridMoveSliceWindowIteratorHacks, + BGridMoveSliceWindowIteratorHacks, + index_t nrepeat) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + + // GEMM + using GridwiseGemm = + GridwiseDynamicGemm_km_kn_mn_v1r3; + + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + { + throw std::runtime_error("wrong! 
GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting"); + } + + const auto a_k0_m0_m1_k1_grid_desc = + GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc); + + using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc); + using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc); + + // c_m0_m10_m11_n0_n10_n11_grid_desc + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); + + using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc); + + // c_blockid_to_m0_n0_block_cluster_adaptor + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc); + + using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor); + + const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + { + std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", " + << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", " + << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", " + << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl; + } + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + 
b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + else + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor); + } + + return ave_time; +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER + DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc)); + DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc)); + DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc)); + DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf( + sizeof(CBlockIdToM0N0BlockClusterAdaptor)); + + a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc); + b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc); + c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc); + c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice( + &c_blockid_to_m0_n0_block_cluster_adaptor); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + else + { + const auto kernel = + kernel_dynamic_gemm_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel( + kernel, + nrepeat, + 
dim3(grid_size), + dim3(BlockSize), + 0, + 0, + p_a_grid, + p_b_grid, + p_c_grid, + (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(), + (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer()); + } + + return ave_time; +#endif +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp index 0b0d8d961e..ad9d99f4e7 100644 --- a/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/kernel_algorithm/transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -12,7 +12,7 @@ namespace ck { // C: out // GemmM = N * Ho * Wo // GemmN = K -// GemmK = C * Y * X +// GemmK = Y * X * C template +__host__ __device__ constexpr auto +transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad( + const DynamicTensorDescriptor& wei_k_c_y_x_grid_desc, + const DynamicTensorDescriptor& in_n_c_hi_wi_grid_desc, + const DynamicTensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + constexpr auto N0 = Number{}; + constexpr auto C0 = Number{}; + + const auto N1 = N / N0; + const auto C1 = C / C0; + + // weight tensor + const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor( + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)), + make_tuple(make_unmerge_transform(make_tuple(I1, K)), + make_unmerge_transform(make_tuple(C0, C1 * Y * X))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto 
in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1)), + make_unmerge_transform(make_tuple(C0, C1)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{})); + + const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor( + in_n0_n1_c0_c1_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C1, Y, X)), + make_pass_through_transform(N0), + make_merge_transform(make_tuple(N1, Ho, Wo)), + make_pass_through_transform(C0)), + make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // output tensor + const auto out_n_k_howo_grid_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo)); + + const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor( + out_n_k_howo_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(Number{}, N1)), + make_unmerge_transform(make_tuple(I1, K)), + make_pass_through_transform(Ho * Wo)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor( + out_n0_n1_1_k_howo_grid_desc, + make_tuple(make_pass_through_transform(I1), + make_pass_through_transform(K), + make_pass_through_transform(Number{}), + make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))), + make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return make_tuple( + wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..9e2902e991 --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_dynamic_tensor_slice_transfer_v2.hpp @@ -0,0 +1,158 @@ +#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "cluster_descriptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseDynamicTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_reference_t>::GetNumOfDimension() && + nDim == remove_reference_t>::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowIteratorHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, src_move_slice_window_iterator_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); 
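+    // thread_cluster_desc_ maps the flat thread id onto the ThreadClusterLengths grid
+    // (reordered by ThreadClusterArrangeOrder); the constructor multiplies the resulting
+    // cluster index by ThreadSliceLengths to locate each thread's sub-slice of the block
+    // slice window.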
+ + using ThreadwiseTransfer = + ThreadwiseDynamicTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp new file mode 100644 index 0000000000..c57134762b --- /dev/null +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_v2r3.hpp @@ -0,0 +1,398 @@ +#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP +#define CK_BLOCKWISE_GEMM_V2R3_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_gemm_v2.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AK0MK1BlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BK0NK1BlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template ::type = false> +struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t M = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t N = BK0NK1BlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc) + { + const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor( + a_k0_m_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_k0_m0_m1_k1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc) + { + const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor( + b_k0_n_k1_block_desc, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_k0_n0_n1_k1_block_desc; + } + 
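+    // The two block-descriptor makers above only split the block's M (resp. N) dimension:
+    // [K0, M, K1] -> [K0, M0, M1, K1] with M = M0 * M1 (and N = N0 * N1 for B).
+    // For example, assuming M = 128 and M1 = M100 * M101 * M11 = 64, the unmerge gives
+    // M0 = 2, matching the M0 == 2 / N0 == 2 restriction asserted in the constructor.
+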
+ __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k0_m0_m1_k1_block_desc_ = + MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{}); + static constexpr auto b_k0_n0_n1_k1_block_desc_ = + MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{}); + + public: + __device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0)); + } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_k0_m0_m1_k1_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_k0_n0_n1_k1_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_k0_n0_n1_k1_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_k0_m0_m1_k1_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + make_tuple(k, I0, I0, I0), + a_block_buf, + a_k0_m0_m1_k1_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, 
I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + make_tuple(k, I0, I0, I0), + b_block_buf, + b_k0_n0_n1_k1_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, + make_tuple(k, I1, I0, I0), + b_block_buf, + b_k0_n0_n1_k1_thread_desc_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, + make_tuple(k, I1, I0, I0), + a_block_buf, + a_k0_m0_m1_k1_thread_desc_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K0, M0, M1, K1] + static constexpr auto a_k0_m0_m1_k1_thread_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{}, Number{})); + + // B[K0, N0, N1, K1] + static constexpr auto b_k0_n0_n1_k1_thread_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_k0_m0_m1_k1_block_desc_), + decltype(a_k0_m0_m1_k1_thread_desc_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, M1PerThreadM11, K1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_k0_n0_n1_k1_block_desc_), + decltype(b_k0_n0_n1_k1_thread_desc_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, N1PerThreadN11, K1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp new file mode 100644 index 0000000000..af25458b6c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_contraction_v1r2.hpp @@ -0,0 +1,675 @@ +#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP +#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_v2r2.hpp" +#include 
"blockwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_transfer.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_contraction_v1r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc, + const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc, + const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + const CBlockIdToGM10GN10BlockClusterAdaptor + c_blockid_to_gm10_gn10_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseContraction::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_gk0_gm0_gm10_gm11_gk1_grid_desc, + b_gk0_gn0_gn10_gn11_gk1_grid_desc, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_blockid_to_gm10_gn10_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // GM0 and GN0 need to known at compile-time + static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0); + static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2); + static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc, + const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc, + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
GM0 and GN0 need to be known at compile-time"); + + const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2); + const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2); + const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) && + GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) && + GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) && + GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) && + GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) && + GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) && + GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) && + GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) && + GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) && + GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) && + (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 && GK0 % KPerBlock == 0)); + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr index_t GM11 = GM1PerBlockGM11; + constexpr index_t GN11 = GN1PerBlockGN11; + + const index_t GM10 = GM1 / GM11; + const index_t GN10 = GN1 / GN11; + + const index_t grid_size = GM10 * GN10; + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0) + { + const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0) + { + const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc) + { + const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0); + const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2); + + const auto GM11 = Number{}; + const auto GM10 = GM1 / GM11; + + const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor( + a_gk0_gm0_gm1_gk1_grid_desc, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return a_gk0_gm0_gm10_gm11_gk1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc) + { + const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0); + const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2); + + const auto GN11 = Number{}; + const auto GN10 = GN1 / GN11; + + const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor( + b_gk0_gn0_gn1_gk1_grid_desc, + make_tuple(make_pass_through_transform(GK0), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11)), + make_pass_through_transform(GK1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return b_gk0_gn0_gn10_gn11_gk1_grid_desc; + } + + __host__ __device__ 
static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor( + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + constexpr auto BM = GM0 * GM11; + constexpr auto BN = GN0 * GN11; + + constexpr auto BM1 = + Number{}; + constexpr auto BN1 = + Number{}; + + constexpr auto BM0 = BM / BM1; + constexpr auto BN0 = BN / BN1; + + const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor( + c_gm0_gm1_gn0_gn1_grid_desc, + make_tuple(make_pass_through_transform(GM0), + make_unmerge_transform(make_tuple(GM10, GM11)), + make_pass_through_transform(GN0), + make_unmerge_transform(make_tuple(GN10, GN11))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor( + c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_merge_transform(make_tuple(GM0, GM11)), + make_pass_through_transform(GN10), + make_merge_transform(make_tuple(GN0, GN11))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor( + c_gm10_bm_gn10_bn_grid_desc, + make_tuple(make_pass_through_transform(GM10), + make_unmerge_transform(make_tuple(BM0, BM1)), + make_pass_through_transform(GN10), + make_unmerge_transform(make_tuple(BN0, BN1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{})); + + return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc; + } + + __host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor( + const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc) + { + const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1); + const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3); + + constexpr auto GM11 = Number{}; + constexpr auto GN11 = Number{}; + + const auto GM10 = GM1 / GM11; + const auto GN10 = GN1 / GN11; + + const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(GM10, GN10))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_gm10_gn10_block_cluster_adaptor; + } + + using AGK0GM0GM10GM11GK1GridDesc = + decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{})); + using BGK0GN0GN10GN11GK1GridDesc = + decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{})); + using CGM10BM0BM1GN10BN0BN1GridDesc = + decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{})); + using CBlockIdToGM10GN10BlockClusterAdaptor = + decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc, + const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc, + const 
CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize()); + + const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0); + + // divide block work by [GM10, GN10] + const auto c_gm10_gn10_block_cluster_idx = + c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]); + const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]); + + // lds max alignment + // TODO: part of them should be moved into blockwise-gemm + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = GK1; + + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0, I1, Number{}, GK1), + max_lds_align); + + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc = + make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0, I1, Number{}, GK1), + max_lds_align); + + // A matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GM0 * Number{}, GK1), max_lds_align); + + // B matrix in LDS memory for blockwise GEMM + // be careful of LDS alignment + constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, GN0 * Number{}, GK1), max_lds_align); + + static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() == + a_gk0_bm_gk1_block_desc.GetElementSpaceSize() && + b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() == + b_gk0_bn_gk1_block_desc.GetElementSpaceSize(), + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperation::Set, + Sequence, + ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc), + decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + make_multi_index(0, 0, igm10, 0, 0), + a_gk0_gm0_gm10_gm11_gk1_block_desc, + make_multi_index(0, 0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = 
BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperation::Set, + Sequence, + BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc), + decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + make_multi_index(0, 0, ign10, 0, 0), + b_gk0_gn0_gn10_gn11_gk1_block_desc, + make_multi_index(0, 0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS + // b_mtx[KPerBlocl, GN1PerBlockGN11] is in LDS + // c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2{}; + + constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto c_bm0_bm1_bn0_bn1_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_bm0_bm1_bn0_bn1_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead( + a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + 
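+            // Each do-while iteration below advances 2 * KPerBlock along GK0: the "even"
+            // half prefetches the next slice from global memory and stores it into the odd
+            // LDS buffers while the blockwise GEMM consumes the even buffers; the "odd"
+            // half swaps the roles, overlapping global loads with computation.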
index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < GK0 - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead( + a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); 
+ } + + // output: register to global memory + { + constexpr index_t M11 = + M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; + constexpr index_t N11 = + N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + + constexpr index_t M10 = GM1PerBlockGM11 / M11; + constexpr index_t N10 = GN1PerBlockGN11 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc), + decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc), + Sequence<1, + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0], + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1], + 1, + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2], + c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + make_multi_index(igm10, + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0], + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1], + ign10, + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2], + c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])} + .Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp new file mode 100644 index 0000000000..c070cac826 --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r3.hpp @@ -0,0 +1,669 @@ +#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP +#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP + +#include "common_header.hpp" +#include "dynamic_multi_index_transform_helper.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" +#include "blockwise_gemm_v2r3.hpp" +#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp" +#include "threadwise_dynamic_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_v1r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + 
c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +// pass tensor descriptor by __CONSTANT__ void pointer +// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to +// non-modifiable parameter address space, so compiler can enable corresponding optimization +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_dynamic_gemm_v1r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void __CONSTANT__* p_a_k0_m0_m1_k1_grid_desc, + const void __CONSTANT__* p_b_k0_n0_n1_k1_grid_desc, + const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc, + const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor) +{ + // first cast void __CONSTANT__ void* to void* + // second cast void* to Desc* + // the copy constructor of tensor descriptor doesn't take address_space(4) + const auto a_k0_m0_m1_k1_grid_desc = + *reinterpret_cast((const void*)p_a_k0_m0_m1_k1_grid_desc); + const auto b_k0_n0_n1_k1_grid_desc = + *reinterpret_cast((const void*)p_b_k0_n0_n1_k1_grid_desc); + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = + *reinterpret_cast( + (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + *reinterpret_cast( + (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m0_m1_k1_grid_desc, + b_k0_n0_n1_k1_grid_desc, + c_m0_m10_m11_n0_n10_n11_grid_desc, + c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant{}, + integral_constant{}); +} +#endif + +template +struct GridwiseDynamicGemm_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2); + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, + const BK0NK1GridDesc& b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && + (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc) + { + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor( + a_k0_m_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_k0_m0_m1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc) + { + const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor( + b_k0_n_k1_grid_desc, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, 
Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_k0_n0_n1_k1_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_m0_m10_m11_n0_n10_n11_grid_desc; + } + + __host__ __device__ static constexpr auto + MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto c_blockid_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return c_blockid_to_m0_n0_block_cluster_adaptor; + } + + using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{})); + using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{})); + using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{})); + using CBlockIdToM0N0BlockClusterAdaptor = + decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc, + const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc, + const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc, + const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex( + make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperation::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m0_m1_k1_grid_desc), + decltype(a_k0_m0_m1_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_k0_m0_m1_k1_grid_desc, + make_multi_index(0, im0, 0, 0), + a_k0_m0_m1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1< + BlockSize, + InMemoryDataOperation::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n0_n1_k1_grid_desc), + decltype(b_k0_n0_n1_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_k0_n0_n1_k1_grid_desc, + make_multi_index(0, in0, 0, 0), + b_k0_n0_n1_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlockM1] is in LDS + // b_mtx[KPerBlocl, NPerBlockN1] is in LDS + // c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths(); + + constexpr auto 
c_m10_m11_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize()); + + ThreadwiseDynamicTensorSliceSet_v1{} + .Run(c_m10_m11_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + FloatAcc{0}); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + auto a_block_odd_buf = + make_dynamic_buffer(p_a_block_double + a_block_aligned_space_size, + a_k0_m0_m1_k1_block_desc.GetElementSpaceSize()); + auto b_block_odd_buf = + make_dynamic_buffer(p_b_block_double + b_block_aligned_space_size, + b_k0_n0_n1_k1_block_desc.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead( + a_k0_m0_m1_k1_grid_desc, a_global_buf, 
AGridIteratorHacks{}); + b_blockwise_copy.RunRead( + b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf); + + k_block_data_begin += 2 * KPerBlock; + } while(k_block_data_begin < K0 - 2 * KPerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, + a_block_slice_copy_step, + AGridMoveSliceWindowIteratorHacks{}); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, + b_block_slice_copy_step, + BGridMoveSliceWindowIteratorHacks{}); + + __syncthreads(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{}); + b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{}); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf); + + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr index_t M11 = + M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101; + constexpr index_t N11 = + N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101; + + constexpr index_t M10 = MPerBlockM1 / M11; + constexpr index_t N10 = NPerBlockN1 / N11; + + constexpr index_t M111 = M1PerThreadM111; + constexpr index_t N111 = N1PerThreadN111; + + constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc = + make_dynamic_naive_tensor_descriptor_packed_v2( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id()); + + ThreadwiseDynamicTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), + decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_m0_m10_m11_n0_n10_n11_grid_desc, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} + .Run(c_m0_m10_m11_n0_n10_n11_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + 
c_m0_m10_m11_n0_n10_n11_grid_desc, + c_grid_buf, + CGridIteratorHacks{}); + } + } +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp new file mode 100644 index 0000000000..331c9fc201 --- /dev/null +++ b/composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer_v2.hpp @@ -0,0 +1,799 @@ +#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP +#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP + +#include "common_header.hpp" +#include "dynamic_tensor_descriptor.hpp" +#include "dynamic_tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseDynamicTensorSliceTransfer_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 && + SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0, + "wrong!"); + }); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const SrcIteratorHacks& src_iterator_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or + SrcBuffer::GetAddressSpace() == AddressSpace::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward iterators + const auto src_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, forward_step, src_iterator_hacks[I0][i]); + }, + Number{}); + + // make backward iterators + const auto src_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + }); + + return make_dynamic_tensor_coordinate_iterator( + src_desc, backward_step, src_iterator_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + auto src_data_idx = + container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + + return src_data_idx; + }(); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector to buffer_ + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx); + + buffer_(Number{}) = + src_vector.template AsType()[Number{}]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep()); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + const DstIteratorHacks& dst_iterator_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or + DstBuffer::GetAddressSpace() == AddressSpace::Lds, + "wrong!"); + + static_assert(is_same>, + remove_cv_t>>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + // tensor descriptor for dst_vector + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(dst_vector_tensor_lengths, + DstVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + DstVectorTensorContiguousDimOrder{}); + + constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); + + // dst access order and lengths + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward iterators + const auto dst_forward_iterators = generate_tuple( + [&](auto i) { + Index forward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0; + }); + + const auto forward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, forward_step, dst_iterator_hacks[I0][i]); + + return forward_iterator; + }, + Number{}); + + // make backward iterators + const auto dst_backward_iterators = generate_tuple( + [&](auto i) { + Index backward_step; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + }); + + const auto backward_iterator = make_dynamic_tensor_coordinate_iterator( + dst_desc, backward_step, dst_iterator_hacks[I1][i]); + + return backward_iterator; + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + auto dst_data_idx = + container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + + return dst_data_idx; + }(); + + vector_type_maker_t dst_vector; + + // copy data from buffer_ to dst_vector (also cast from SrcData to DstData) + static_ford{}([&](auto dst_vector_idx_) { + constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx); + + constexpr index_t dst_vector_offset = + dst_vector_desc.CalculateOffset(dst_vector_idx); + + dst_vector.template AsType()(Number{}) = + type_convert{}(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_vector)::type; + + // copy data from dst_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]); + } + else + { + move_dynamic_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_iterator = + make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep()); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_iterator_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_iterator_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_iterator_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto 
forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + + return src_data_idx; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; }); + + return reset_src_data_step; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep; + + forward_sweep(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep(i) = tmp % 2 == 0; + }); + + return forward_sweep; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + + return dst_data_idx; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_dynamic_tensor_coordinate_iterator( + src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack); + + move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx); + + move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-iterator +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template < + typename SrcData, + typename DstData, + typename SrcDesc, + typename DstDesc, + typename SliceLengths, + typename DimAccessOrder, + typename SrcVectorTensorLengths, + typename SrcVectorTensorContiguousDimOrder, + typename std::enable_if::type = false> +struct ThreadwiseDynamicTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! 
SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert( + is_known_at_compile_time< + remove_cv_t>>::value && + is_known_at_compile_time>>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cv_t>{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies_v2{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2( + sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_iterator = + make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_dynamic_tensor_coordinate( + src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t 
dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert{}( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator( + src_desc, to_multi_index(src_slice_move_step_idx)); + + move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp index aa22cf4aa9..748a888999 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp @@ -6,140 +6,6 @@ namespace ck { -// C[M, N] += transpose(A[K, M]) * B[K, N] -// Element of matrix can be vectorized data -// Assume: -// 1. ADesc, BDesc, CDesc are known at compile-time -// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time -template ::type = false> -struct ThreadwiseGemm_km_kn_mn_v1r1 -{ - template - __device__ static void Run(const ABuffer& a_buf, - AOriginIdx, - const BBuffer& b_buf, - BOriginIdx, - CBuffer& c_buf, - COriginIdx) - { - static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && - CDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! 
inconsistent type"); - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - constexpr auto M = CDesc{}.GetLength(I0); - constexpr auto N = CDesc{}.GetLength(I1); - constexpr auto K = ADesc{}.GetLength(I0); - - constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); - constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); - constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); - - static_for<0, K, 1>{}([&](auto k) { - static_for<0, M, 1>{}([&](auto m) { - constexpr index_t a_offset = - ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m)); - -#if 0 - if constexpr(N == 2) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1)); - - amd_assembly_outer_product_1x2(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{})); - } - else if constexpr(N == 4) - { - constexpr index_t b_offset_0 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0)); - constexpr index_t b_offset_1 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1)); - constexpr index_t b_offset_2 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2)); - constexpr index_t b_offset_3 = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3)); - - constexpr index_t c_offset_0 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0)); - constexpr index_t c_offset_1 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1)); - constexpr index_t c_offset_2 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2)); - constexpr index_t c_offset_3 = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3)); - - amd_assembly_outer_product_1x4(a_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - b_buf[Number{}], - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{}), - c_buf(Number{})); - } - else -#endif - { - static_for<0, N, 1>{}([&](auto n) { - - constexpr index_t b_offset = - BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n)); - constexpr index_t c_offset = - CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n)); - - amd_assembly_inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); - }); - } - }); - }); - } -}; - // C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1] // Tensor element can be vectorized data // Assume: @@ -227,9 +93,124 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1 constexpr index_t c_offset = CDesc{}.CalculateOffset( c_origin_idx + make_multi_index(m0, m1, n0, n1)); - amd_assembly_inner_product(a_buf[Number{}], - b_buf[Number{}], - c_buf(Number{})); + amd_inner_product_dlop( + a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1] +// Tensor element can be vectorized data +// Assume: +// 1. ADesc, BDesc, CDesc are known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1() + { + static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && + CDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLenghts, MLengths and NLengths + + // TODO remove this restriction + static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert( + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value && + is_known_at_compile_time>>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert(is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + is_same>, + remove_cv_t>>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr index_t K0 = KLengths{}[I0]; + constexpr index_t K1 = KLengths{}[I1]; + constexpr index_t M0 = MLengths{}[I0]; + constexpr index_t M1 = MLengths{}[I1]; + constexpr index_t N0 = NLengths{}[I0]; + constexpr index_t N1 = NLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, K0, 1>{}([&](auto k0) { + static_for<0, M0, 1>{}([&](auto m0) { + static_for<0, M1, 1>{}([&](auto m1) { + static_for<0, N0, 1>{}([&](auto n0) { + static_for<0, N1, 1>{}([&](auto n1) { + + vector_type a_vec; + vector_type b_vec; + + static_for<0, K1, 1>{}([&](auto k1) { + constexpr index_t a_offset = ADesc{}.CalculateOffset( + a_origin_idx + make_multi_index(k0, m0, m1, k1)); + + constexpr index_t b_offset = BDesc{}.CalculateOffset( + b_origin_idx + make_multi_index(k0, n0, n1, k1)); + + a_vec.template AsType()(k1) = a_buf[Number{}]; + + b_vec.template AsType()(k1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = CDesc{}.CalculateOffset( + c_origin_idx + make_multi_index(m0, m1, n0, n1)); + + amd_inner_product_dlop( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); }); }); }); diff --git a/composable_kernel/include/utility/amd_dlop.hpp b/composable_kernel/include/utility/amd_dlop.hpp new file mode 100644 index 0000000000..c67cfc7118 --- /dev/null +++ b/composable_kernel/include/utility/amd_dlop.hpp @@ -0,0 +1,146 @@ +#ifndef CK_AMD_DLOP_HPP +#define CK_AMD_DLOP_HPP + +#include "float_type.hpp" + +namespace ck { + +template +__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c); + +template <> +__device__ void +amd_inner_product_dlop(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +#if CK_USE_AMD_DLOP +template <> +__device__ void +amd_inner_product_dlop(const half2_t& a, const half2_t& b, float& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_sdot2(a, b, c, false); +#endif +} + +template <> +__device__ void +amd_inner_product_dlop(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + 
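+    // The wider-vector overloads reduce recursively: a half4 dot product is
+    // evaluated as two half2 dot products, each accumulated into c through the
+    // half2 specialization above.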
amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +amd_inner_product_dlop(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void amd_inner_product_dlop(const int8x4_t& a, + const int8x4_t& b, + int32_t& c) +{ +#if CK_USE_AMD_DLOP_INLINE_ASM + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + " + : "=v"(c) + : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); +#endif +} + +template <> +__device__ void amd_inner_product_dlop(const int8x8_t& a, + const int8x8_t& b, + int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void amd_inner_product_dlop(const int8x16_t& a, + const int8x16_t& b, + int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_inner_product_dlop(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + amd_inner_product_dlop(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} +#endif // CK_USE_AMD_DLOP + +} // namespace ck +#endif diff --git a/composable_kernel/include/utility/amd_inline_asm.hpp b/composable_kernel/include/utility/amd_inline_asm.hpp index b5d2e4e38e..c4b167a128 100644 --- a/composable_kernel/include/utility/amd_inline_asm.hpp +++ b/composable_kernel/include/utility/amd_inline_asm.hpp @@ -5,94 +5,16 @@ namespace ck { -// c += inner_product(a, b) -__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c) -{ -#if CK_USE_AMD_V_FMAC_F32 - asm volatile("\n \ - v_fmac_f32 %0, %1, %2 \n \ - " - : "=v"(c) - : "v"(a), "v"(b), "0"(c)); -#else - asm volatile("\n \ - v_mac_f32 %0, %1, %2 \n \ - " - : "=v"(c) - : "v"(a), "v"(b), "0"(c)); -#endif -} - -__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) -{ -#if 1 - asm volatile("\n \ - v_dot4_i32_i8 %0, %1, %2, %0\n \ - " - : "=v"(c) - : "v"(as_type(a)), "v"(as_type(b)), "0"(c)); -#else - c = __builtin_amdgcn_sdot4(as_type(a), as_type(b), c, false); -#endif -} - -__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - - amd_assembly_inner_product(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_assembly_inner_product(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); -} - -__device__ 
void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) -{ - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - amd_assembly_inner_product(vector_type{a}.AsType()[I0], - vector_type{b}.AsType()[I0], - c); - - amd_assembly_inner_product(vector_type{a}.AsType()[I1], - vector_type{b}.AsType()[I1], - c); - - amd_assembly_inner_product(vector_type{a}.AsType()[I2], - vector_type{b}.AsType()[I2], - c); - - amd_assembly_inner_product(vector_type{a}.AsType()[I3], - vector_type{b}.AsType()[I3], - c); -} - // c0 += inner_product(a, b0) // c1 += inner_product(a, b1) __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) { -#if CK_USE_AMD_V_FMAC_F32 asm volatile("\n \ v_fmac_f32 %0, %2, %3 \n \ v_fmac_f32 %1, %2, %4 \n \ " : "=v"(c0), "=v"(c1) : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); -#else - asm volatile("\n \ - v_mac_f32 %0, %2, %3 \n \ - v_mac_f32 %1, %2, %4 \n \ - " - : "=v"(c0), "=v"(c1) - : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); -#endif } // c0 += inner_product(a, b0) @@ -102,7 +24,6 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa __device__ void amd_assembly_outer_product_1x4( float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) { -#if CK_USE_AMD_V_FMAC_F32 asm volatile("\n \ v_fmac_f32 %0, %4, %5 \n \ v_fmac_f32 %1, %4, %6 \n \ @@ -111,16 +32,6 @@ __device__ void amd_assembly_outer_product_1x4( " : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); -#else - asm volatile("\n \ - v_mac_f32 %0, %4, %5 \n \ - v_mac_f32 %1, %4, %6 \n \ - v_mac_f32 %2, %4, %7 \n \ - v_mac_f32 %3, %4, %8 \n \ - " - : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) - : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); -#endif } // c0 += inner_product(a, b0) diff --git a/composable_kernel/include/utility/common_header.hpp b/composable_kernel/include/utility/common_header.hpp index 309571b061..9517c8e8bd 100644 --- a/composable_kernel/include/utility/common_header.hpp +++ b/composable_kernel/include/utility/common_header.hpp @@ -28,10 +28,15 @@ #include "static_buffer.hpp" #include "dynamic_buffer.hpp" +// TODO: remove this #if CK_USE_AMD_INLINE_ASM #include "amd_inline_asm.hpp" #endif +#if CK_USE_AMD_DLOP +#include "amd_dlop.hpp" +#endif + #if CK_USE_AMD_XDLOPS #include "amd_xdlops.hpp" #include "amd_xdlops_inline_asm.hpp" diff --git a/composable_kernel/include/utility/config.amd.hpp.in b/composable_kernel/include/utility/config.amd.hpp.in index fd602e7f96..dadaa53b1b 100644 --- a/composable_kernel/include/utility/config.amd.hpp.in +++ b/composable_kernel/include/utility/config.amd.hpp.in @@ -54,8 +54,13 @@ #define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1 #endif -#ifndef CK_USE_AMD_V_FMAC_F32 -#define CK_USE_AMD_V_FMAC_F32 1 +// AMD DLOPS +#ifndef CK_USE_AMD_DLOP +#define CK_USE_AMD_DLOP 1 +#endif + +#ifndef CK_USE_AMD_DLOP_INLINE_ASM +#define CK_USE_AMD_DLOP_INLINE_ASM 1 #endif // AMD buffer addressing @@ -116,7 +121,7 @@ #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 // merge transformation use magic number division -#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 // hack: have underlying assumption that need to be satsified, otherwise it's a bug // hack for forcing register to keep idx_diff_low_const in SGPR. 
idx_diff_low_const must be diff --git a/composable_kernel/include/utility/container_helper.hpp b/composable_kernel/include/utility/container_helper.hpp index 2ff0c46e6d..a7ed8ec059 100644 --- a/composable_kernel/include/utility/container_helper.hpp +++ b/composable_kernel/include/utility/container_helper.hpp @@ -94,7 +94,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence>::type{}; - return container_reorder_give_new2old(old_seq, new2old); + return container_reorder_given_new2old(old_seq, new2old); } #if !CK_WORKAROUND_SWDEV_275126 @@ -223,6 +223,13 @@ container_reverse_exclusive_scan(const Array& x, Reduce f, TData i return y; } +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Sequence& seq, Reduce f, Number) +{ + return reverse_exclusive_scan_sequence(seq, f, Number{}); +} + #if !CK_WORKAROUND_SWDEV_275126 // rocm4.1 compiler would crash with recursive lambda template @@ -366,6 +373,19 @@ set_container_subset(Tuple& y, Sequence picks, const Tuple& static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); } +template +__host__ __device__ constexpr auto to_tuple_of_number(const Container&) +{ + static_assert(is_known_at_compile_time::value, "wrong!"); + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Container::At(i); + return Number{}; + }, + Container::Size()); +} + template __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) { diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index addcaff22a..073577b93f 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -100,39 +100,72 @@ struct DynamicBuffer *reinterpret_cast(&p_data_[i]) = x; #else // HACK: compiler would lower IR "store address_space(3)" into inefficient - // ISA, so I try to let compiler emit use IR "store" which would be lower to + // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix if constexpr(is_same>>::type, int8_t>::value) { static_assert( - (is_same>, int8x4_t>::value && - is_same>, int8x4_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x2_t>::value) || + (is_same>, int8_t>::value && + is_same>, int8x4_t>::value) || + (is_same>, int8x4_t>::value && + is_same>, int8x4_t>::value) || (is_same>, int8x8_t>::value && is_same>, int8x8_t>::value) || (is_same>, int8x16_t>::value && is_same>, int8x16_t>::value), "wrong! 
not implemented for this combination, please add implementation"); - if constexpr(is_same>, int8x4_t>::value && - is_same>, int8x4_t>::value) + if constexpr(is_same>, int8_t>::value && + is_same>, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, int8_t>::value && + is_same>, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *reinterpret_cast(&p_data_[i]) = *reinterpret_cast(&x); } - if constexpr(is_same>, int8x8_t>::value && - is_same>, int8x8_t>::value) + else if constexpr(is_same>, + int8x4_t>::value && + is_same>, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *reinterpret_cast(&p_data_[i]) = + *reinterpret_cast(&x); + } + else if constexpr(is_same>, + int8x8_t>::value && + is_same>, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *reinterpret_cast(&p_data_[i]) = *reinterpret_cast(&x); } - if constexpr(is_same>, int8x16_t>::value && - is_same>, int8x16_t>::value) + else if constexpr(is_same>, + int8x16_t>::value && + is_same>, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix diff --git a/driver/conv_driver_v2.cpp b/driver/conv_driver_v2.cpp index 530b779c2e..38b93395f9 100644 --- a/driver/conv_driver_v2.cpp +++ b/driver/conv_driver_v2.cpp @@ -14,7 +14,9 @@ #include "device_tensor.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp" +#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp" @@ -24,23 +26,27 @@ #define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4_NHWC 0 +#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V4R5_NCHW 0 +#define USE_CONV_FWD_V4R5R2_NCHW 1 #define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 -#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 enum ConvForwardAlgo { V4R4NCHW, // 0 V4R4NHWC, // 1 - V4R5NCHW, // 2 - V5R1NCHW, // 3 - V4R4XDLNCHW, // 4 - V4R4R2XDLNHWC, // 5 - V4R4R3XDLNHWC, // 6 - V4R4R4XDLNHWC // 7 + V4R4R2NHWC, // 2 + V4R5NCHW, // 3 + V4R5R2NCHW, // 4 + V5R1NCHW, // 5 + V4R4XDLNCHW, // 6 + V4R4R2XDLNHWC, // 7 + V4R4R3XDLNHWC, // 8 + V4R4R4XDLNHWC // 9 }; int main(int argc, char* argv[]) @@ -132,21 +138,18 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #endif -#if 0 - constexpr index_t in_vector_size = 1; - using in_data_t = float; - using acc_data_t = 
float; - using out_data_t = float; +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; #elif 1 - constexpr index_t in_vector_size = 1; - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 - constexpr index_t in_vector_size = 16; - using in_data_t = int8_t; - using acc_data_t = int32_t; - using out_data_t = int8_t; + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); @@ -348,6 +351,33 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_FWD_V4R4R2_NHWC + if(algo == ConvForwardAlgo::V4R4R2NHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + #if USE_CONV_FWD_V4R5_NCHW if(algo == ConvForwardAlgo::V4R5NCHW) { @@ -374,6 +404,33 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_FWD_V4R5R2_NCHW + if(algo == ConvForwardAlgo::V4R5R2NCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei, + out_device, + nrepeat); + } +#endif + #if USE_CONV_FWD_V5R1_NCHW if(algo == ConvForwardAlgo::V5R1NCHW) { @@ -385,7 +442,7 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(tmp[I0], tmp[I1], @@ -525,10 +582,10 @@ int main(int argc, char* argv[]) #if 0 if(do_log) { - LogRange(std::cout << "in : ", in.mData, ",") << std::endl; - LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei: ", wei.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_host : ", out_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "out_device: ", out_device.mData, ",") << std::endl; } #endif } diff --git a/driver/include/device.hpp b/driver/include/device.hpp index b68e07c85f..7dac30a45a 100644 --- a/driver/include/device.hpp +++ b/driver/include/device.hpp @@ -55,7 +55,7 @@ float launch_and_time_kernel(F kernel, { KernelTimer timer; - printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n", + printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n", __func__, grid_dim.x, grid_dim.y, diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..ea0ada9f88 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,292 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_dynamic_gemm_v1r3.hpp" 
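+// Device-side wrapper: lowers NHWC forward convolution into an implicit GEMM and
+// dispatches driver_dynamic_gemm_v1r3. The transform below produces
+//   A[GemmK0, GemmM, GemmK1] from the input tensor,
+//   B[GemmK0, GemmN, GemmK1] from the weight tensor,
+//   C[GemmM, GemmN]          for the output tensor.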
+ +template +void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + constexpr auto I7 = Number<7>{}; + constexpr auto I8 = Number<8>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 128, 8, 1] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 1; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 2] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 2; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + 
constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 8, 4] for i8 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlockM1 = 128; + constexpr index_t GemmNPerBlockN1 = 128; + constexpr index_t GemmKPerBlock = 8; + constexpr index_t GemmK1 = 4; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + + using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>; + using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>; + + using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>; + using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4; +#endif + + const auto descs = + transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto out_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GemmM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GemmM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks = + 
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmN1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3+: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4+: GemmN10 + Sequence<0, 0, 0, 0, 0>{}), // 5+: GemmN11 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmM0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM10 + Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM11 + Sequence<0, 0, 0, 0, 0>{}, // 3-: GemmN0 + Sequence<0, 0, 0, 0, 0>{}, // 4-: GemmN10 + Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11 + + constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_gemm_v1r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(out_gemmm_gemmn_grid_desc), + GemmMPerBlockM1, + GemmNPerBlockN1, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1, + GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder + GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder + GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1, + GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder + GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder + GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_N11, + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks), + decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks), + decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks), + decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>( + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + 
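            // compile-time iterator hacks (defined above) that steer index calculation
            // while the kernel iterates over the A, B and C descriptors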
in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks, + out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks, + in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks, + wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Hi = in_n_hi_wi_c_lengths[I1]; + const auto Wi = in_n_hi_wi_c_lengths[I2]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); +} diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 0890bf2e7d..6b04f07b2f 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 - constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = - Sequence<0, 0, 0, 0, 0>{}; - constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0>{}; + for(index_t i = 0; i < 5; ++i) { float ave_time = driver_dynamic_gemm_xdlops_v2r3< diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..b6799e1072 --- /dev/null +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,249 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp" +#include "driver_dynamic_contraction_v1r2.hpp" + +template +void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + const Tensor& wei_k_c_y_x, + Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem 
out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = + make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); + +#if 1 + // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 1; + + constexpr index_t GemmGM1PerBlockGM11 = 128; + constexpr index_t GemmGN1PerBlockGN11 = 32; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + + using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>; + + using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>; + using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; +#elif 1 + // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16 + // cdata = 64, BlockSize = 256 + constexpr index_t BlockSize = 256; + + constexpr index_t GN0 = 4; + constexpr index_t GK1 = 2; + + constexpr index_t GemmGM1PerBlockGM11 = 128; + constexpr index_t GemmGN1PerBlockGN11 = 32; + constexpr index_t GemmKPerBlock = 8; + + constexpr index_t GemmM1PerThreadM111 = 4; + constexpr index_t GemmN1PerThreadN111 = 4; + constexpr index_t GemmKPerThread = 1; + + constexpr index_t GemmM11N11ThreadClusterM1101 = 2; + constexpr index_t GemmM11N11ThreadClusterN1101 = 2; + constexpr index_t GemmM11N11ThreadClusterM1100 = 8; + constexpr index_t GemmM11N11ThreadClusterN1100 = 8; + + using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>; + using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>; + + using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>; + using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>; + + using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>; + using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>; + + using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>; + using 
GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>; + + constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1; +#endif + + const auto descs = + transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + Number{}); + + const auto wei_gk0_gm0_gm1_gk1_grid_desc = descs[I0]; + const auto in_gk0_gn0_gn1_gk1_grid_desc = descs[I1]; + const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto wei_grid_iterator_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1+: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3+: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 1-: GM0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 3-: GM11 + Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto in_grid_iterator_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 3+: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 4+: GK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: GN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 3-: GN11 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1 + + constexpr auto out_grid_iterator_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 1+: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 2+: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 4+: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: GM10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 1-: BM0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 2-: BM1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: GN10 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 4-: BN0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1 + + constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{}; + + constexpr auto in_grid_move_slice_window_iterator_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_dynamic_contraction_v1r2< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperation::Set, 
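+            // A/B/C grid descriptor types produced by
+            // transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad above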
+ decltype(wei_gk0_gm0_gm1_gk1_grid_desc), + decltype(in_gk0_gn0_gn1_gk1_grid_desc), + decltype(out_gm0_gm1_gn0_gn1_grid_desc), + GemmGM1PerBlockGM11, + GemmGN1PerBlockGN11, + GemmKPerBlock, + GemmM1PerThreadM111, + GemmN1PerThreadN111, + GemmKPerThread, + GemmM11N11ThreadClusterM1100, + GemmM11N11ThreadClusterN1100, + GemmM11N11ThreadClusterM1101, + GemmM11N11ThreadClusterN1101, + GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1, + GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder + Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder + GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, + Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder + GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1, + GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder + Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder + GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, + Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder + Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder + 5, // CThreadTransferSrcDstVectorDim + GemmCThreadTransferDstScalarPerVector_BN1, + decltype(wei_grid_iterator_hacks), + decltype(in_grid_iterator_hacks), + decltype(out_grid_iterator_hacks), + decltype(wei_grid_move_slice_window_iterator_hacks), + decltype(in_grid_move_slice_window_iterator_hacks)>( + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + wei_gk0_gm0_gm1_gk1_grid_desc, + in_gk0_gn0_gn1_gk1_grid_desc, + out_gm0_gm1_gn0_gn1_grid_desc, + wei_grid_iterator_hacks, + in_grid_iterator_hacks, + out_grid_iterator_hacks, + wei_grid_move_slice_window_iterator_hacks, + in_grid_move_slice_window_iterator_hacks, + nrepeat); + + float perf = (float)calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); +}
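
For reference, the contraction computed by the new ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1 and by driver_dynamic_contraction_v1r2 is C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1], summed over K0 and K1. The sketch below is a minimal host-side reference of that semantics only; the function name and the packed row-major layouts are assumptions made for illustration and are not part of this change.

#include <cstddef>
#include <vector>

// Naive reference: c[m0, m1, n0, n1] += sum over (k0, k1) of
// a[k0, m0, m1, k1] * b[k0, n0, n1, k1].
// All buffers are assumed packed row-major in the dimension order shown in the comments.
void contraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_ref(const std::vector<float>& a, // [K0][M0][M1][K1]
                                                const std::vector<float>& b, // [K0][N0][N1][K1]
                                                std::vector<float>& c,       // [M0][M1][N0][N1]
                                                std::size_t K0,
                                                std::size_t K1,
                                                std::size_t M0,
                                                std::size_t M1,
                                                std::size_t N0,
                                                std::size_t N1)
{
    for(std::size_t m0 = 0; m0 < M0; ++m0)
        for(std::size_t m1 = 0; m1 < M1; ++m1)
            for(std::size_t n0 = 0; n0 < N0; ++n0)
                for(std::size_t n1 = 0; n1 < N1; ++n1)
                {
                    float acc = 0;

                    for(std::size_t k0 = 0; k0 < K0; ++k0)
                        for(std::size_t k1 = 0; k1 < K1; ++k1)
                        {
                            acc += a[((k0 * M0 + m0) * M1 + m1) * K1 + k1] *
                                   b[((k0 * N0 + n0) * N1 + n1) * K1 + k1];
                        }

                    c[((m0 * M1 + m1) * N0 + n0) * N1 + n1] += acc;
                }
}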