From cb97ce68d8df6335d2a2128ee91a973548a1a4cc Mon Sep 17 00:00:00 2001 From: Jianfeng Yan Date: Wed, 30 Mar 2022 11:21:18 -0500 Subject: [PATCH] Batched gemm and reduction (#156) * adding batched_gemm_and_reduction * batched_gemm_reduce works with bactch_count=1 * fix a bug in grid_size; batched_gemm_reduce works for batch_count > 1 * adding profiler for batched_gemm_fp16 * fixed a bug in declaration of d1 and d0; both example and profiler work * clang-format * cleanup * batched_gemm_reduce: add test * minor change * fixed some typo in function names [ROCm/composable_kernel commit: 34c661e71cc8cf4753843a58786c8f6211ec5e22] --- example/01_gemm/gemm_xdl_bf16.cpp | 2 - example/01_gemm/gemm_xdl_fp16.cpp | 2 - example/01_gemm/gemm_xdl_int8.cpp | 2 - .../16_gemm_reduce/gemm_reduce_xdl_fp16.cpp | 3 - example/18_batched_gemm_reduce/CMakeLists.txt | 2 + .../batched_gemm_reduce_xdl_fp16.cpp | 281 ++++++ example/CMakeLists.txt | 1 + ...evice_batched_gemm_reduce_xdl_cshuffle.hpp | 940 ++++++++++++++++++ .../gpu/device/device_batched_gemm_xdl.hpp | 74 +- .../gpu/device/device_gemm_reduce.hpp | 3 +- .../device_gemm_reduce_xdl_cshuffle.hpp | 3 +- .../ck/library/host_tensor/host_tensor.hpp | 9 +- .../gpu/CMakeLists.txt | 1 + .../gpu/batched_gemm_reduce/CMakeLists.txt | 11 + ...6_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp | 70 ++ ...6_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp | 70 ++ ...6_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp | 70 ++ ...6_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp | 67 ++ profiler/CMakeLists.txt | 2 + .../include/profile_batched_gemm_impl.hpp | 1 + .../profile_batched_gemm_reduce_impl.hpp | 354 +++++++ profiler/src/profile_batched_gemm_reduce.cpp | 154 +++ profiler/src/profiler.cpp | 5 + test/CMakeLists.txt | 1 + test/batched_gemm_reduce/CMakeLists.txt | 9 + .../batched_gemm_reduce_fp16.cpp | 64 ++ test/gemm_reduce/gemm_reduce_fp16.cpp | 6 - 27 files changed, 2145 insertions(+), 62 deletions(-) create mode 100644 example/18_batched_gemm_reduce/CMakeLists.txt create mode 100644 example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp create mode 100644 profiler/include/profile_batched_gemm_reduce_impl.hpp create mode 100644 profiler/src/profile_batched_gemm_reduce.cpp create mode 100644 test/batched_gemm_reduce/CMakeLists.txt create mode 100644 test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 5a9091a236..9be781454b 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -5,11 +5,9 @@ #include #include #include "config.hpp" -#include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index 8db97a5b25..5be6deb850 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -5,11 +5,9 @@ #include #include #include "config.hpp" -#include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index dfe1eec77f..aaad1397f7 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -5,11 +5,9 @@ #include #include #include "config.hpp" -#include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_gemm.hpp" #include "device_tensor.hpp" #include "device_gemm_xdl.hpp" #include "device_gemm_xdl_c_shuffle.hpp" diff --git a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp b/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp index 6f173ae1de..673dce82db 100644 --- a/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp +++ b/example/16_gemm_reduce/gemm_reduce_xdl_fp16.cpp @@ -5,13 +5,10 @@ #include #include #include "config.hpp" -#include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" #include "host_tensor_generator.hpp" -#include "host_gemm.hpp" #include "device_tensor.hpp" -#include "device_gemm_xdl.hpp" #include "device_gemm_reduce_xdl_cshuffle.hpp" #include "element_wise_operation.hpp" #include "reference_gemm.hpp" diff --git a/example/18_batched_gemm_reduce/CMakeLists.txt b/example/18_batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000..99fc0043d2 --- /dev/null +++ b/example/18_batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,2 @@ +add_example_executable(example_batched_gemm_reduce_xdl_fp16 batched_gemm_reduce_xdl_fp16.cpp) + diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp new file mode 100644 index 0000000000..8e30ef0c79 --- /dev/null +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "device_tensor.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "reference_batched_gemm.hpp" +#include "gemm_specialization.hpp" +#include "element_wise_reduce_operation.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = F16; +using BDataType = F16; +using CDataType = F16; +using DDataType = F32; + +using ALayout = ck::tensor_layout::gemm::RowMajor; +using BLayout = ck::tensor_layout::gemm::ColumnMajor; +using CLayout = ck::tensor_layout::gemm::RowMajor; + +using AElementOp = ck::tensor_operation::element_wise::PassThrough; +using BElementOp = ck::tensor_operation::element_wise::PassThrough; +using CElementOp = ck::tensor_operation::element_wise::PassThrough; +using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; +using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmSpecialization = + ck::tensor_operation::device::GemmSpecialization_t::Default; + +// clang-format off +using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatchedGemmReduce_Xdl_CShuffle +//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| +//######| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, AElementOp, BElementOp, CElementOp, D0ReduceOp, D1ReduceOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; +// clang-format on + +using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: + ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = 1; + int init_method = 1; + int nrepeat = 5; + + // GEMM shape + ck::index_t M = 3840; + ck::index_t N = 4096; + ck::index_t K = 4096; + + ck::index_t StrideA = 4096; + ck::index_t StrideB = 4096; + ck::index_t StrideC = 4096; + + ck::index_t BatchCount = 4; + + if(argc == 1) + { + // do nothing + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + } + else if(argc == 11) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + nrepeat = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + StrideA = std::stoi(argv[7]); + StrideB = std::stoi(argv[8]); + StrideC = std::stoi(argv[9]); + + BatchCount = std::stoi(argv[9]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: run kernel # of times (>1)\n"); + printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC, BatchCount\n"); + exit(0); + } + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; + auto d0_reduce_op = D0ReduceOp{}; + auto d1_reduce_op = D1ReduceOp{}; + + // do GEMM + auto batched_gemm = DeviceBatchedGemmReduceInstance{}; + auto invoker = batched_gemm.MakeInvoker(); + auto argument = + batched_gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + BatchCount); + + if(!batched_gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + // warm up + invoker.Run(argument); + + // timing + float total_time = 0; + + for(int i = 0; i < nrepeat; ++i) + { + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + KernelTimer timer; + + timer.Start(); + + invoker.Run(argument); + + timer.End(); + + total_time += timer.GetElapsedTime(); + } + + float ave_time = total_time / nrepeat; + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << batched_gemm.GetTypeString() << std::endl; + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReduceZeroValue(); + float d1_acc = d1_reduce_op.GetReduceZeroValue(); + + for(int n = 0; n < N; ++n) + { + d0_reduce_op.Reduce(d0_acc, c_g_m_n_host_result(batch, m, n)); + d1_reduce_op.Reduce(d1_acc, c_g_m_n_host_result(batch, m, n)); + } + + d0_g_m_host_result(batch, m) = d0_acc; + d1_g_m_host_result(batch, m) = d1_acc; + } + } + + check_error(c_g_m_n_host_result, c_g_m_n_device_result); + check_error(d0_g_m_host_result, d0_g_m_device_result); + check_error(d1_g_m_host_result, d1_g_m_device_result); + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 0a8051c3e2..830d1189de 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -42,3 +42,4 @@ add_subdirectory(14_gemm_xdl_requant_relu_requant) add_subdirectory(17_convnd_bwd_data_xdl) add_subdirectory(15_grouped_gemm) add_subdirectory(16_gemm_reduce) +add_subdirectory(18_batched_gemm_reduce) diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp new file mode 100644 index 0000000000..46ae7ab2ac --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp @@ -0,0 +1,940 @@ +#pragma once +#include +#include +#include "device.hpp" +#include "device_gemm_reduce.hpp" +#include "common_header.hpp" +#include "tensor_layout.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatD* __restrict__ p_d0_grid, + FloatD* __restrict__ p_d1_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const D0ReduceOperation d0_reduce_op, + const D1ReduceOperation d1_reduce_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) +{ + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetD0BasePtr(g_idx))); + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetD1BasePtr(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_d0_grid + d0_batch_offset, + p_d1_grid + d1_batch_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + d_grid_desc_mblock_mperblock, + block_2_ctile_map); +} + +template +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce +{ + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || + GemmSpecialization == GemmSpecialization_t::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::NKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::KPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MNPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpecialization == GemmSpecialization_t::NPadding || + GemmSpecialization == GemmSpecialization_t::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeDGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpecialization == GemmSpecialization_t::MPadding || + GemmSpecialization == GemmSpecialization_t::MNPadding || + GemmSpecialization == GemmSpecialization_t::MKPadding || + GemmSpecialization == GemmSpecialization_t::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using DGridDesc_M = decltype(MakeDGridDescriptor_M(1)); + + static constexpr auto MakeBlock2CTileMap(index_t batch_count, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(batch_count), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto globalblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return globalblockid_to_m0_n0_block_cluster_adaptor; + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideD0, + index_t BatchStrideD1) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideD0_(BatchStrideD0), + BatchStrideD1_(BatchStrideD1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideD0_); + } + + __host__ __device__ constexpr long_index_t GetD1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideD1_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideD0_; + index_t BatchStrideD1_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ReduceOperation, + D1ReduceOperation, + InMemoryDataOperationEnum_t::Set, + InMemoryDataOperationEnum_t::AtomicAdd, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + DGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>; + + using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1)); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + DDataType* p_d0_grid, + DDataType* p_d1_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ReduceOperation d0_reduce_op, + D1ReduceOperation d1_reduce_op, + index_t BatchCount) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_d0_grid_{p_d0_grid}, + p_d1_grid_{p_d1_grid}, + BatchCount_(BatchCount), + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d_grid_desc_mblock_mperblock_{}, + compute_base_ptr_of_batch_{a_grid_desc_ak0_m_ak1_.GetElementSpaceSize(), + b_grid_desc_bk0_n_bk1_.GetElementSpaceSize(), + c_grid_desc_m_n_.GetElementSpaceSize(), + d_grid_desc_m_.GetElementSpaceSize(), + d_grid_desc_m_.GetElementSpaceSize()}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_reduce_op_{d0_reduce_op}, + d1_reduce_op_{d1_reduce_op} + { + if(GridwiseGemm::CheckValidity( + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + d_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_); + + block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, 1, 1); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + DDataType* p_d0_grid_; + DDataType* p_d1_grid_; + index_t BatchCount_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + DGridDesc_M d_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ReduceOperation d0_reduce_op_; + D1ReduceOperation d1_reduce_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, int /* nrepeat */ = 1) + { +#if 0 + { + std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl; + + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}" + << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_; + + const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ReduceOperation, + D1ReduceOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + remove_reference_t, + true>; + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_d0_grid_, + arg.p_d1_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_reduce_op_, + arg.d1_reduce_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + DDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ReduceOperation, + D1ReduceOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + remove_reference_t, + false>; + + launch_kernel(kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_d0_grid_, + arg.p_d1_grid_, + arg.BatchCount_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_reduce_op_, + arg.d1_reduce_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + + return 0; + } + + // polymorphic + float Run(const BaseArgument* p_arg, int nrepeat = 1) override + { + return Run(*dynamic_cast(p_arg), nrepeat); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + auto casted_p_arg = dynamic_cast(p_arg); + if(casted_p_arg == nullptr) + { + return false; + } + else + { + return IsSupportedArgument(*casted_p_arg); + } + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + DDataType* p_d0, + DDataType* p_d1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ReduceOperation d0_reduce_op, + D1ReduceOperation d1_reduce_op, + index_t BatchCount) + { + return Argument{p_a, + p_b, + p_c, + p_d0, + p_d1, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + BatchCount}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + void* p_d0, + void* p_d1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ReduceOperation d0_reduce_op, + D1ReduceOperation d1_reduce_op, + index_t BatchCount) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_d0), + static_cast(p_d1), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + BatchCount); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp index 6daa5af5f2..e21a5cb335 100644 --- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp @@ -36,7 +36,7 @@ __global__ void const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, - const index_t num_batches, + const index_t batch_count, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, @@ -47,7 +47,7 @@ __global__ void const Block2CTileMap block_2_ctile_map) { const index_t num_blocks_per_batch = - __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( @@ -203,49 +203,43 @@ struct DeviceBatchedGemmXdl using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); - struct Block2CTileMapMaker + static constexpr auto MakeBlock2CTileMap(index_t batch_count, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) { - Block2CTileMapMaker(index_t num_batches) : num_batches_(num_batches) {} + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); - __host__ __device__ constexpr auto - MakeBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) - { - const auto M = c_grid_desc_m_n.GetLength(I0); - const auto N = c_grid_desc_m_n.GetLength(I1); + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; - constexpr auto M1 = Number{}; - constexpr auto N1 = Number{}; + const auto M0 = M / M1; + const auto N0 = N / N1; - const auto M0 = M / M1; - const auto N0 = N / N1; + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; - const auto M00 = M0 / M01; - const auto N00 = N0 / N01; + const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(batch_count), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); - const auto g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_insert_transform(num_batches_), - make_unmerge_transform(make_tuple(M00, M01)), - make_unmerge_transform(make_tuple(N00, N01))), - make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(batch_count, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); - const auto globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(num_batches_, M00, N00, M01, N01))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); + const auto globalblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - const auto globalblockid_to_m0_n0_block_cluster_adaptor = - chain_tensor_adaptors(g_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, - globalblockid_to_m00_m01_n00_n01_block_cluster_adaptor); - - return globalblockid_to_m0_n0_block_cluster_adaptor; - } - - private: - index_t num_batches_; - }; + return globalblockid_to_m0_n0_block_cluster_adaptor; + } struct ComputeBasePtrOfStridedBatch { @@ -320,8 +314,7 @@ struct DeviceBatchedGemmXdl using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); - using Block2CTileMap = - decltype(Block2CTileMapMaker{1}.MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + using Block2CTileMap = decltype(MakeBlock2CTileMap(1, CGridDesc_M_N{}, 1, 1)); // Argument struct Argument : public BaseArgument @@ -367,8 +360,7 @@ struct DeviceBatchedGemmXdl c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); - block_2_ctile_map_ = - Block2CTileMapMaker{BatchCount}.MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + block_2_ctile_map_ = MakeBlock2CTileMap(BatchCount, c_grid_desc_m_n_, M01, N01); } } diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp index 76ea2fc864..eddc570088 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -28,7 +28,8 @@ struct DeviceGemmReduce : public BaseOperator BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, D0ReduceOperation d0_reduce_op, - D1ReduceOperation d1_reduce_op) = 0; + D1ReduceOperation d1_reduce_op, + ck::index_t BatchCount = 1) = 0; virtual std::unique_ptr MakeInvokerPointer() = 0; }; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp index 01ea388f33..7b31bf457d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp @@ -694,7 +694,8 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce(static_cast(p_a), static_cast(p_b), diff --git a/library/include/ck/library/host_tensor/host_tensor.hpp b/library/include/ck/library/host_tensor/host_tensor.hpp index c70c0e5532..443e0f9e4c 100644 --- a/library/include/ck/library/host_tensor/host_tensor.hpp +++ b/library/include/ck/library/host_tensor/host_tensor.hpp @@ -73,10 +73,10 @@ struct HostTensorDescriptor HostTensorDescriptor() = delete; template - HostTensorDescriptor(std::vector lens); + HostTensorDescriptor(const std::vector& lens); template - HostTensorDescriptor(std::vector lens, std::vector strides); + HostTensorDescriptor(const std::vector& lens, const std::vector& strides); void CalculateStrides(); @@ -285,13 +285,14 @@ struct Tensor }; template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens) : mLens(lens) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens) : mLens(lens) { this->CalculateStrides(); } template -HostTensorDescriptor::HostTensorDescriptor(std::vector lens, std::vector strides) +HostTensorDescriptor::HostTensorDescriptor(const std::vector& lens, + const std::vector& strides) : mLens(lens), mStrides(strides) { } diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index f8650c445b..f232c41b5c 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -39,3 +39,4 @@ add_subdirectory(conv2d_bwd_data) add_subdirectory(reduce) add_subdirectory(convnd_bwd_data) add_subdirectory(grouped_gemm) +add_subdirectory(batched_gemm_reduce) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000..59eb6cb1cc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,11 @@ +set(DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp +) + +add_instance_library(device_batched_gemm_reduce_instance ${DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE}) +install(TARGETS device_batched_gemm_reduce_instance LIBRARY DESTINATION lib) +clang_tidy_check(device_batched_gemm_reduce_instance) + diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..0144081160 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp @@ -0,0 +1,70 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +// d0[g, m] = reduce0(c[g, m, n]) +// d1[g, m] = reduce1(c[g, m, n]) +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 4, 4, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector< + DeviceGemmReducePtr>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..873bd1c847 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instance.cpp @@ -0,0 +1,70 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +// d0[g, m] = reduce0(c[g, m, n]) +// d1[g, m] = reduce1(c[g, m, n]) +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Col, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector< + DeviceGemmReducePtr>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp new file mode 100644 index 0000000000..ec94ed2ace --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp @@ -0,0 +1,70 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +// d0[g, m] = reduce0(c[g, m, n]) +// d1[g, m] = reduce1(c[g, m, n]) +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout| AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 2, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Row, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector< + DeviceGemmReducePtr>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp new file mode 100644 index 0000000000..ad7e70b31b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_reduce/device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp @@ -0,0 +1,67 @@ +#include +#include "config.hpp" +#include "device_batched_gemm_reduce_xdl_cshuffle.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ReduceSum = ck::tensor_operation::element_wise::ReduceSum; +using ReduceSquareSum = ck::tensor_operation::element_wise::ReduceSquareSum; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +// d0[g, m] = reduce0(c[g, m, n]) +// d1[g, m] = reduce1(c[g, m, n]) +using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances = + std::tuple< + // clang-format off + //##################################| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| ReduceAcc| DData| A| B| C| D0| D1| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| CReduce| CReduceThreadLds2VGprCopy| CReduceThreadVgpr2GlobalCopy| + //##################################| | | | Type| Type| Type| DataType| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Reduce| Reduce| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| + //##################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| + //##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 4>, 8, S<64, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, S<32, 4>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1>, + DeviceBatchedGemmReduce_Xdl_CShuffle< Row, Col, Row, F16, F16, F16, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ReduceSum, ReduceSquareSum, GemmDefault, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, S<32, 2>, 4, 1> + // clang-format on + >; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector< + DeviceGemmReducePtr>& + instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances{}); +} + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/CMakeLists.txt b/profiler/CMakeLists.txt index e3123e1ef6..ae1bcfa52f 100644 --- a/profiler/CMakeLists.txt +++ b/profiler/CMakeLists.txt @@ -35,6 +35,7 @@ set(PROFILER_SOURCE src/profile_convnd_bwd_data.cpp src/profile_reduce.cpp src/profile_grouped_gemm.cpp + src/profile_batched_gemm_reduce.cpp ) add_executable(ckProfiler ${PROFILER_SOURCE}) @@ -54,3 +55,4 @@ target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance) target_link_libraries(ckProfiler PRIVATE device_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_reduce_instance) target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance) +target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance) diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index ae17f32591..d57bfd7c09 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -1,4 +1,5 @@ #pragma once + #include #include "reference_batched_gemm.hpp" diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp new file mode 100644 index 0000000000..75befce848 --- /dev/null +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -0,0 +1,354 @@ +#pragma once + +#include "config.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "host_conv.hpp" +#include "tensor_layout.hpp" +#include "device_tensor.hpp" +#include "element_wise_operation.hpp" +#include "element_wise_reduce_operation.hpp" +#include "device_gemm_reduce.hpp" +#include "reference_batched_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace device_gemm_instance { + +using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr< + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::ReduceSum, + ck::tensor_operation::element_wise::ReduceSquareSum>; + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + std::vector&); + +void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + std::vector&); + +} // namespace device_gemm_instance +} // namespace device +} // namespace tensor_operation +} // namespace ck + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_reduce_impl(int do_verification, + int init_method, + bool do_log, + int nrepeat, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int BatchCount) +{ + bool pass = true; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({row * stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({col * stride, 1, stride})); + } + }; + + Tensor a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{})); + Tensor b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{})); + + Tensor c_g_m_n_host_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_host_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + Tensor c_g_m_n_device_result( + f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{})); + Tensor d0_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + Tensor d1_g_m_device_result(HostTensorDescriptor(std::vector( + {static_cast(BatchCount), static_cast(M)}))); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl; + std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl; + std::cout << "d0_g_m: " << d0_g_m_host_result.mDesc << std::endl; + std::cout << "d1_g_m: " << d1_g_m_host_result.mDesc << std::endl; + + std::size_t num_thread = std::thread::hardware_concurrency(); + switch(init_method) + { + case 0: break; + case 1: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + std::srand(0); + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b_g_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + using D0ReduceOp = ck::tensor_operation::element_wise::ReduceSum; + using D1ReduceOp = ck::tensor_operation::element_wise::ReduceSquareSum; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + const auto d0_reduce_op = D0ReduceOp{}; + const auto d1_reduce_op = D1ReduceOp{}; + + if(do_verification) + { + using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; + + auto ref_batched_gemm = ReferenceBatchedGemmInstance{}; + auto ref_invoker = ref_batched_gemm.MakeInvoker(); + + auto ref_argument = ref_batched_gemm.MakeArgument( + a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + + for(int batch = 0; batch < BatchCount; ++batch) + { + for(int m = 0; m < M; ++m) + { + float d0_acc = d0_reduce_op.GetReduceZeroValue(); + float d1_acc = d1_reduce_op.GetReduceZeroValue(); + + for(int n = 0; n < N; ++n) + { + d0_reduce_op.Reduce(d0_acc, c_g_m_n_host_result(batch, m, n)); + d1_reduce_op.Reduce(d1_acc, c_g_m_n_host_result(batch, m, n)); + } + + d0_g_m_host_result(batch, m) = d0_acc; + d1_g_m_host_result(batch, m) = d1_acc; + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace()); + DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace()); + DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace()); + DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace()); + DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace()); + + a_device_buf.ToDevice(a_g_m_k.mData.data()); + b_device_buf.ToDevice(b_g_k_n.mData.data()); + + // add device GEMM instances + std::vector + gemm_ptrs; + + if constexpr(is_same::value && is_same::value && + is_same::value) + { + if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances( + gemm_ptrs); + } + else if constexpr(is_same::value && + is_same::value && + is_same::value) + { + ck::tensor_operation::device::device_gemm_instance:: + add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances( + gemm_ptrs); + } + } + + if(gemm_ptrs.size() <= 0) + { + throw std::runtime_error("wrong! no device GEMM instance found"); + } + + std::string best_gemm_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device GEMM instances + for(auto& gemm_ptr : gemm_ptrs) + { + auto argument_ptr = + gemm_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + static_cast(d0_device_buf.GetDeviceBuffer()), + static_cast(d1_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + d0_reduce_op, + d1_reduce_op, + BatchCount); + + auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); + + if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) + { + // warm up + invoker_ptr->Run(argument_ptr.get()); + + // timing + float total_time = 0; + + for(int i = 0; i < nrepeat; ++i) + { + // init DO, D1 to 0 + d0_device_buf.SetZero(); + d1_device_buf.SetZero(); + + KernelTimer timer; + + timer.Start(); + + invoker_ptr->Run(argument_ptr.get()); + + timer.End(); + + total_time += timer.GetElapsedTime(); + } + + float ave_time = total_time / nrepeat; + + std::string gemm_name = gemm_ptr->GetTypeString(); + + std::size_t flop = std::size_t(2) * BatchCount * M * N * K; + std::size_t num_btype = sizeof(ADataType) * BatchCount * M * K + + sizeof(BDataType) * BatchCount * K * N + + sizeof(CDataType) * BatchCount * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm_name << std::endl; + + if(tflops > best_tflops) + { + best_gemm_name = gemm_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_device_buf.FromDevice(c_g_m_n_device_result.mData.data()); + d0_device_buf.FromDevice(d0_g_m_device_result.mData.data()); + d1_device_buf.FromDevice(d1_g_m_device_result.mData.data()); + + float c_error = check_error(c_g_m_n_host_result, c_g_m_n_device_result); + float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result); + float d1_error = check_error(d1_g_m_host_result, d1_g_m_device_result); + + pass = pass && (c_error < 1E-6); + pass = pass && (d0_error < 1E-6); + pass = pass && (d1_error < 1E-6); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_g_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_g_k_n.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host: ", c_g_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_g_m_n_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d0_host: ", d0_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d0_device: ", d0_g_m_device_result.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "d1_host: ", d1_g_m_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "d1_device: ", d1_g_m_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << "does not support this GEMM problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/profile_batched_gemm_reduce.cpp b/profiler/src/profile_batched_gemm_reduce.cpp new file mode 100644 index 0000000000..61f22ba003 --- /dev/null +++ b/profiler/src/profile_batched_gemm_reduce.cpp @@ -0,0 +1,154 @@ +#include +#include +#include +#include +#include +#include + +#include "profile_batched_gemm_reduce_impl.hpp" + +int profile_batched_gemm_reduce(int argc, char* argv[]) +{ + enum struct GemmMatrixLayout_t + { + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 + }; + + enum struct GemmReduceDataType_t + { + F32_F32_F32_F32_F32, // 0 + F16_F16_F16_F32_F32, // 1 + }; + + if(!(argc == 15 || argc == 16)) + { + printf("arg1: tensor operation (batched_gemm: BatchedGEMM+Reduce)\n"); + printf("arg2: data type (0: fp32; 1: fp16)\n"); + printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg8: print tensor value (0: no; 1: yes)\n"); + printf("arg7: run kernel # of times (>1)\n"); + printf("arg8 to 14: M, N, K, StrideA, StrideB, StrideC, BatchCount\n"); + printf("arg15: split k into mulitiple batch\n"); + exit(1); + } + + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const int nrepeat = std::stoi(argv[7]); + + const int M = std::stoi(argv[8]); + const int N = std::stoi(argv[9]); + const int K = std::stoi(argv[10]); + + const int StrideA = std::stoi(argv[11]); + const int StrideB = std::stoi(argv[12]); + const int StrideC = std::stoi(argv[13]); + + const int BatchCount = std::stoi(argv[14]); + + if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::MK_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::MK_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? K : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::KM_KN_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? N : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else if(data_type == GemmReduceDataType_t::F16_F16_F16_F32_F32 && + layout == GemmMatrixLayout_t::KM_NK_MN) + { + ck::profiler::profile_batched_gemm_reduce_impl( + do_verification, + init_method, + do_log, + nrepeat, + M, + N, + K, + (StrideA < 0) ? M : StrideA, + (StrideB < 0) ? K : StrideB, + (StrideC < 0) ? N : StrideC, + BatchCount); + } + else + { + throw std::runtime_error("wrong! this data_type & layout is not implemented"); + } + + return 1; +} diff --git a/profiler/src/profiler.cpp b/profiler/src/profiler.cpp index a4cd23ee22..24e5ae7e3e 100644 --- a/profiler/src/profiler.cpp +++ b/profiler/src/profiler.cpp @@ -17,6 +17,7 @@ int profile_conv_fwd_bias_relu_add(int, char*[]); int profile_conv_fwd_bias_relu_atomic_add(int, char*[]); int profile_convnd_bwd_data(int, char*[], int); int profile_reduce(int, char*[]); +int profile_batched_gemm_reduce(int, char*[]); int main(int argc, char* argv[]) { @@ -44,6 +45,10 @@ int main(int argc, char* argv[]) { return profile_batched_gemm(argc, argv); } + else if(strcmp(argv[1], "batched_gemm_reduce") == 0) + { + return profile_batched_gemm_reduce(argc, argv); + } else if(strcmp(argv[1], "grouped_gemm") == 0) { profile_grouped_gemm(argc, argv); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index c9fe83f040..b1a397122b 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -39,6 +39,7 @@ add_subdirectory(gemm) add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) +add_subdirectory(batched_gemm_reduce) add_subdirectory(grouped_gemm) add_subdirectory(convnd_fwd) add_subdirectory(reduce) diff --git a/test/batched_gemm_reduce/CMakeLists.txt b/test/batched_gemm_reduce/CMakeLists.txt new file mode 100644 index 0000000000..3ecf19491b --- /dev/null +++ b/test/batched_gemm_reduce/CMakeLists.txt @@ -0,0 +1,9 @@ +include_directories(BEFORE + ${PROJECT_SOURCE_DIR}/profiler/include + ${PROJECT_SOURCE_DIR}/test/include + ${PROJECT_SOURCE_DIR}/external/include/half +) + +add_test_executable(test_batched_gemm_reduce_fp16 batched_gemm_reduce_fp16.cpp) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE host_tensor) +target_link_libraries(test_batched_gemm_reduce_fp16 PRIVATE device_batched_gemm_reduce_instance) diff --git a/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp new file mode 100644 index 0000000000..ce061c644b --- /dev/null +++ b/test/batched_gemm_reduce/batched_gemm_reduce_fp16.cpp @@ -0,0 +1,64 @@ +#include + +#include "profile_batched_gemm_reduce_impl.hpp" + +int main() +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + int M = 512; + int N = 256; + int K = 128; + + int BatchCount = 3; + + bool pass = true; + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, 1, M, N, K, K, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, 1, M, N, K, K, K, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, 1, M, N, K, M, N, N, BatchCount); + + pass = pass && ck::profiler::profile_batched_gemm_reduce_impl( + true, 1, false, 1, M, N, K, M, K, N, BatchCount); + + if(pass) + { + std::cout << "test BatchedGEMM+Reduce fp16: Pass" << std::endl; + return 0; + } + else + { + std::cout << "test BatchedGEMM+Reduce fp16: Fail" << std::endl; + return -1; + } +} diff --git a/test/gemm_reduce/gemm_reduce_fp16.cpp b/test/gemm_reduce/gemm_reduce_fp16.cpp index 0b3421a667..8deb66b2b0 100644 --- a/test/gemm_reduce/gemm_reduce_fp16.cpp +++ b/test/gemm_reduce/gemm_reduce_fp16.cpp @@ -1,10 +1,4 @@ -#include -#include -#include #include -#include -#include -#include #include "profile_gemm_reduce_impl.hpp"