From 72b77167447e008c02647812a5474ba222b1d4e0 Mon Sep 17 00:00:00 2001 From: Anthony Chang Date: Sat, 13 Aug 2022 13:16:14 +0800 Subject: [PATCH] Fused attention (#345) * initial stub for gemm_gemm_xdl_cshuffle * set up example code * compiles * prevent integer overflow * harmonize interface between ref_gemm and ref_batched_gemm * batched_gemm_gemm * fix example * host tensor gen: diagonal pattern in lowest two-dimensions only * make c descriptors containing only integral constants * clean up * add BlockwiseGemmXdlops_v2 while exploring an unified approach * implement proper interface * tidy up example * fix compilation warnings * coarsely controlled 2nd gemm padding * remove rocm-cmake's hard requirement for certain revision * clang-format * resolve merge conflict * fix compilation error on gfx10 * adds acc0 elementwise op to interface * attention host validation * add blockwsie softmax v1 * iteratively update softmax+gemm * transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum * add init method for easier debugging * do away with manual thread cluster calculation * generalize blockwise softmax interface * row-wise softmax sum & max * format * rename to DeviceBatchedGemmSoftmaxGemm * add gemm_softmax_gemm instances and tests * comment Co-authored-by: ltqin Co-authored-by: Chao Liu [ROCm/composable_kernel commit: cac014f17355d6504b618f5945c6326a285db7e9] --- CMakeLists.txt | 2 +- .../batched_gemm_reduce_xdl_fp16.cpp | 10 +- .../batched_gemm_e_permute_xdl_fp16.cpp | 9 +- example/32_batched_gemm_gemm/CMakeLists.txt | 2 + .../batched_gemm_softmax_gemm_xdl_fp16.cpp | 392 +++++++ example/CMakeLists.txt | 1 + .../tensor_description/tensor_descriptor.hpp | 7 + .../gpu/block/blockwise_gemm_xdlops.hpp | 373 ++++++ .../gpu/block/blockwise_softmax.hpp | 96 ++ .../block/reduction_functions_blockwise.hpp | 72 ++ .../gpu/device/device_batched_gemm_gemm.hpp | 86 ++ .../device_batched_gemm_softmax_gemm.hpp | 87 ++ ...batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 916 +++++++++++++++ ...ched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp | 1021 +++++++++++++++++ .../threadwise_tensor_slice_transfer.hpp | 114 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 62 +- include/ck/utility/static_buffer.hpp | 32 +- .../statically_indexed_array_multi_index.hpp | 66 +- .../cpu/reference_batched_gemm.hpp | 11 +- .../gpu/batched_gemm_softmax_gemm.hpp | 93 ++ .../library/utility/host_tensor_generator.hpp | 19 + .../gpu/CMakeLists.txt | 1 + .../batched_gemm_softmax_gemm/CMakeLists.txt | 8 + ...6_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp | 68 ++ .../include/profile_batched_gemm_impl.hpp | 1 + .../profile_batched_gemm_reduce_impl.hpp | 1 + ...profile_batched_gemm_softmax_gemm_impl.hpp | 325 ++++++ test/CMakeLists.txt | 1 + test/batched_gemm_softmax_gemm/CMakeLists.txt | 5 + .../test_batched_gemm_softmax_gemm_fp16.cpp | 39 + .../test_batched_gemm_softmax_gemm_util.hpp | 68 ++ 31 files changed, 3957 insertions(+), 31 deletions(-) create mode 100644 example/32_batched_gemm_gemm/CMakeLists.txt create mode 100644 example/32_batched_gemm_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp create mode 100644 profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp create mode 100644 test/batched_gemm_softmax_gemm/CMakeLists.txt create mode 100644 test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp create mode 100644 test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f70620741..ef46d96f4d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,7 +8,7 @@ list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") enable_testing() set(ROCM_SYMLINK_LIBS OFF) -find_package(ROCM 0.8 REQUIRED PATHS /opt/rocm) +find_package(ROCM REQUIRED PATHS /opt/rocm) include(ROCMInstallTargets) include(ROCMPackageConfigHelpers) diff --git a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp index eaea725efa..fb019faa42 100644 --- a/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp +++ b/example/18_batched_gemm_reduce/batched_gemm_reduce_xdl_fp16.cpp @@ -66,8 +66,14 @@ using DeviceBatchedGemmReduceInstance = ck::tensor_operation::device::DeviceBatc < Row, Col, Row, F16, F16, F16, F32, F32, F32, ReducePtrsGlobal, AElementOp, BElementOp, CElementOp, ReduceOps, ReduceInElementOps, ReduceOutElementOps, ReduceGlobalMemOps, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; // clang-format on -using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: - ReferenceBatchedGemm; +using ReferenceBatchedGemmInstance = + ck::tensor_operation::host::ReferenceBatchedGemm; int main(int argc, char* argv[]) { diff --git a/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp b/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp index e377530584..5b7f988134 100644 --- a/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp +++ b/example/24_batched_gemm_e_permute/batched_gemm_e_permute_xdl_fp16.cpp @@ -51,8 +51,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmEPermu < ALayout, BLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>; // clang-format on -using ReferenceBatchedGemmInstance = ck::tensor_operation::host:: - ReferenceBatchedGemm; +using ReferenceBatchedGemmInstance = ck::tensor_operation::host::ReferenceBatchedGemm; int main(int argc, char* argv[]) { diff --git a/example/32_batched_gemm_gemm/CMakeLists.txt b/example/32_batched_gemm_gemm/CMakeLists.txt new file mode 100644 index 0000000000..ca4fb026cb --- /dev/null +++ b/example/32_batched_gemm_gemm/CMakeLists.txt @@ -0,0 +1,2 @@ +# TODO: add example batched_gemm_gemm_xdl_fp16 +add_example_executable(example_batched_gemm_softmax_gemm_xdl_fp16 batched_gemm_softmax_gemm_xdl_fp16.cpp) diff --git a/example/32_batched_gemm_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp new file mode 100644 index 0000000000..18b0ea79a6 --- /dev/null +++ b/example/32_batched_gemm_gemm/batched_gemm_softmax_gemm_xdl_fp16.cpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +/* +Gemm + Gemm fused operation. Computes C_m_o = A_m_k * B0_k_n * B1_n_o + |------------| + Gemm0 + |---------------------| + Gemm1 +*/ + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using B0DataType = F16; +using B1DataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F16; + +using ALayout = Row; +using B0Layout = Col; +using B1Layout = Row; +using CLayout = Row; + +using AElementOp = PassThrough; +using B0ElementOp = PassThrough; +using Acc0ElementOp = PassThrough; +using B1ElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ALayout, + B0Layout, + B1Layout, + CLayout, + ADataType, + B0DataType, + B1DataType, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + B0ElementOp, + Acc0ElementOp, + B1ElementOp, + CElementOp, + GemmDefault, + 1, + 256, + 128, // MPerBlock + 128, // NPerBlock + 32, // KPerBlock + 64, // Gemm1NPerBlock + 32, // Gemm1KPerBlock + 8, // AK1 + 8, // BK1 + 2, // B1K1 + 32, // MPerXDL + 32, // NPerXDL + 1, // MXdlPerWave + 4, // NXdlPerWave + 2, // Gemm1NXdlPerWave + S<4, 64, 1>, // ABlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<4, 64, 1>, // BBlockTransfer + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 8, + 8, + true, + S<16, 16, 1>, // B1BlockTransfer + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 4, + 2, + false, + 1, // CShuffleMXdlPerWavePerShuffle + 2, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock + +// Ref Gemm0: fp16 in, fp32 out +using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +// Ref Softmax: fp32 in, fp16 out +using ReferenceSoftmaxInstance = + ck::tensor_operation::host::ReferenceSoftmax; + +// Ref Gemm1: fp16 in, fp16 out +using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm; + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 64; + ck::index_t O = 128; + ck::index_t BatchCount = 4; + ck::index_t StrideA = -1; + ck::index_t StrideB0 = -1; + ck::index_t StrideB1 = -1; + ck::index_t StrideC = -1; + ck::index_t BatchStrideA = -1; + ck::index_t BatchStrideB0 = -1; + ck::index_t BatchStrideB1 = -1; + ck::index_t BatchStrideC = -1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + } + else if(argc == 17) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + O = std::stoi(argv[7]); + + BatchCount = std::stoi(argv[8]); + + StrideA = std::stoi(argv[9]); + StrideB0 = std::stoi(argv[10]); + StrideB1 = std::stoi(argv[11]); + StrideC = std::stoi(argv[12]); + + BatchStrideA = std::stoi(argv[13]); + BatchStrideB0 = std::stoi(argv[14]); + BatchStrideB1 = std::stoi(argv[15]); + BatchStrideC = std::stoi(argv[16]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 17: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " + "BatchStrideB0, BatchStrideB1, BatchStrideC\n"); + exit(0); + } + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, 1, stride})); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + auto argument = + gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + if(do_verification) + { + // Output of Gemm0 is input A of Gemm1 + Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + + return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 0 : 1; + } + + return 0; +} diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 08e52be54a..0838d5a19b 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -46,3 +46,4 @@ add_subdirectory(28_grouped_gemm_bias_e_permute) add_subdirectory(29_batched_gemm_bias_e_permute) add_subdirectory(30_grouped_convnd_fwd_bias_relu) add_subdirectory(31_grouped_convnd_fwd_bias_relu_add) +add_subdirectory(32_batched_gemm_gemm) diff --git a/include/ck/tensor_description/tensor_descriptor.hpp b/include/ck/tensor_description/tensor_descriptor.hpp index 1e69736ecc..f07d5b1733 100644 --- a/include/ck/tensor_description/tensor_descriptor.hpp +++ b/include/ck/tensor_description/tensor_descriptor.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck/utility/common_header.hpp" +#include "ck/utility/sequence_helper.hpp" #include "ck/tensor_description/multi_index_transform.hpp" namespace ck { @@ -159,6 +160,12 @@ struct TensorDescriptor return transforms_[Number{}].GetUpperLengths()[Number{}]; } + __host__ __device__ constexpr auto GetLengths() const + { + // FIXME: use Tuple of reference instead + return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number{}); + } + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index 9720db4a95..69a00c8e54 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -25,6 +25,22 @@ constexpr LoopScheduler make_default_loop_scheduler() #endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING } +template +__host__ __device__ static constexpr auto +MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) +{ + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); +} + template {}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto tmp = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + const auto blk_idx = + TransposeC ? make_multi_index(tmp[I1], tmp[I0]) : make_multi_index(tmp[I0], tmp[I1]); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other) + : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin) + { + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp new file mode 100644 index 0000000000..505f3fa185 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" + +namespace ck { + +template +struct BlockwiseSoftmax +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); + static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); + + using ThreadSliceDesc_M = decltype( + make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); + + using ThreadwiseMaxReduce = ThreadwiseReduction; + + using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths()); + + using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2; + + using ThreadwiseSumReduce = ThreadwiseReduction; + + using BufferType = StaticBuffer; + + template + __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf) + { + // find max value + static_for<0, MRepeat, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); + block_sync_lds(); + }); + + // calculate exp for elements, P=exp(s-max) + static_for<0, MRepeat, 1>{}([&](auto iM) { + static_for<0, KRepeat, 1>{}([&](auto iK) { + auto offset = Number{}; + in_thread_buf(offset) = math::exp(in_thread_buf[offset] - max_value_buf(iM)); + }); + }); + + // sum data + static_for<0, MRepeat, 1>{}([&](auto I) { + sum_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I)); + block_sync_lds(); + }); + } + + BufferType max_value_buf; + BufferType sum_value_buf; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp index cce560367f..2163ad3238 100644 --- a/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp +++ b/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -82,6 +82,78 @@ struct PartitionedBlockwiseReduction }; }; +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize +// 3) in_out_value is the input data in vgpr from each thread +// 4) in_out_value is the over-written reduced output in vgpr for each thread +// clang-format on +template > +struct PartitionedBlockwiseReduction_v2 +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = ThreadClusterDesc{}; + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + // clang-format off // Assume: // 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp new file mode 100644 index 0000000000..08fc161eb5 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchedGemmGemmPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp new file mode 100644 index 0000000000..f75a61d9fd --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchedGemmSoftmaxGemmPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp new file mode 100644 index 0000000000..45edf196cf --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -0,0 +1,916 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle + : public DeviceBatchedGemmSoftmaxGemm +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + // TODO: implement finer-grained padding + if constexpr(GemmSpec == GemmSpecialization::Default) + { + const auto B1K0 = KRaw / B1K1; + + const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b1_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b1_grid_desc_bk0_n_bk1; + } + else + { + // pad both B1N and B1K + const auto B1K0 = K / B1K1; + + const auto b1_grid_desc_n_k = + transform_tensor_descriptor(b1_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b1_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + b1_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_b1, p_c, MRaw, + NRaw, KRaw, Gemm1NRaw, Batch, StrideA, + StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, + BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, + b1_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA, + StrideB, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideB1, + BatchStrideC, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp new file mode 100644 index 0000000000..7e0fbb7989 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1021 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + return math::max((SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB) + + SharedMemTrait::reduction_workspace * sizeof(FloatGemmAcc), + SharedMemTrait::c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock); + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // B1 can reuse B's LDS + static constexpr auto b_block_space_size_aligned = + math::max(b0_block_space_size_aligned.value, b1_block_space_size_aligned.value); + + // LDS allocation for reduction + static constexpr index_t reduction_workspace = BlockSize; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + decltype(acc_element_op), + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{acc_element_op}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr index_t Gemm1KPack = math::max( + math::lcm(MfmaSelector::selected_mfma.group_size, B1K1), + MfmaSelector::selected_mfma.k_per_blk); + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + true, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{}.K0PerXdlops>{ + make_tuple(0, 0, 0, 0)}; // TransposeC + + auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + // + // Blockwise softmax + // + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + + SharedMemTrait::a_block_space_size_aligned * sizeof(FloatAB) / 4 + + SharedMemTrait::b_block_space_size_aligned * sizeof(FloatAB) / 4, + SharedMemTrait::reduction_workspace); + + // get acc0 8D thread cluster + constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); + constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); + constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); + constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); + constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); + constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); + constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); + constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); + + // get acc0 thread map + constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_m_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); + constexpr auto thread_slice_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + StaticBuffer + c_thread_buf; + c_thread_buf.Clear(); + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + block_sync_lds(); + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // O_new + + c_thread_buf(I) = c_new; + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index a50bb851fe..1c49f270a1 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -1145,9 +1145,22 @@ struct ThreadwiseTensorSliceTransfer_v4 src_desc, src_data_coord); // copy data from src_buf into src_tmp_vector - src_tmp_vector.template AsType()(Number<0>{}) = - src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + // apply type convert + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // DstData) vector_type_maker_t dst_tmp_vector; @@ -1184,4 +1197,101 @@ struct ThreadwiseTensorSliceTransfer_v4 SrcCoord src_ref_coord_; }; +// Do NOT involve any tensor coordinates with StaticBuffer +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic( + const ElementwiseOperation& element_op) + : element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_buf(Number{}) = type_convert(v); + }); + }); + } + + ElementwiseOperation element_op_; +}; + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index eaf0f13275..b4885ad3fc 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -579,7 +579,11 @@ struct MfmaSelector static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } }; -template +template struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -612,6 +616,8 @@ struct XdlopsGemm static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); } + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 template __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) @@ -627,10 +633,10 @@ struct XdlopsGemm make_pass_through_transform(N0), make_pass_through_transform(M1), make_pass_through_transform(N1), - make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, - mfma_instr.num_input_blks, - mfma_instr.group_size)), - make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, @@ -645,6 +651,41 @@ struct XdlopsGemm Sequence<7>{})); } + // transposed XDL output supporting C' = B' * A' + // M2_N2 -> M2_N2_N3_N4 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{})); + } + template __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) @@ -698,7 +739,16 @@ struct XdlopsGemm "base base_type must be double, float, half, bfloat16, and int8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { - mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], p_a_wave[k], p_c_thread); + } }); } diff --git a/include/ck/utility/static_buffer.hpp b/include/ck/utility/static_buffer.hpp index 638eefa374..5428f4c6c3 100644 --- a/include/ck/utility/static_buffer.hpp +++ b/include/ck/utility/static_buffer.hpp @@ -20,6 +20,29 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ constexpr StaticBuffer() : base{} {} + __host__ __device__ constexpr StaticBuffer& operator=(StaticBuffer& y) + { + StaticBuffer& x = *this; + static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; }); + return x; + } + + template + __host__ __device__ constexpr StaticBuffer& operator=(const Tuple& y) + { + static_assert(base::Size() == sizeof...(Ys), "wrong! size not the same"); + StaticBuffer& x = *this; + static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y[i]; }); + return x; + } + + __host__ __device__ constexpr StaticBuffer& operator=(const T& y) + { + StaticBuffer& x = *this; + static_for<0, base::Size(), 1>{}([&](auto i) { x(i) = y; }); + return x; + } + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() { return AddressSpace; } __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } @@ -40,10 +63,12 @@ struct StaticBuffer : public StaticallyIndexedArray return base::operator()(i); } - __host__ __device__ void Clear() + __host__ __device__ void Set(T x) { - static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{0}; }); + static_for<0, N, 1>{}([&](auto i) { operator()(i) = T{x}; }); } + + __host__ __device__ void Clear() { Set(T{0}); } }; // static buffer for vector @@ -61,6 +86,7 @@ struct StaticBufferTupleOfVector static constexpr auto s_per_v = Number{}; static constexpr auto num_of_v_ = Number{}; + static constexpr auto s_per_buf = s_per_v * num_of_v_; __host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {} @@ -70,6 +96,8 @@ struct StaticBufferTupleOfVector __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } + __host__ __device__ static constexpr index_t Size() { return s_per_buf; }; + // Get S // i is offset of S template diff --git a/include/ck/utility/statically_indexed_array_multi_index.hpp b/include/ck/utility/statically_indexed_array_multi_index.hpp index bab5aebff7..21b2941b21 100644 --- a/include/ck/utility/statically_indexed_array_multi_index.hpp +++ b/include/ck/utility/statically_indexed_array_multi_index.hpp @@ -34,7 +34,10 @@ __host__ __device__ constexpr auto to_multi_index(const T& x) // is the alias of the latter. This is because compiler cannot infer the NSize if // using MultiIndex // TODO: how to fix this? -template +template < + typename... Ys, + typename X, + enable_if_t::value && !std::is_floating_point::value, bool> = false> __host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) { static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); @@ -43,7 +46,10 @@ __host__ __device__ constexpr auto operator+=(Tuple& y, const X& x) return y; } -template +template < + typename... Ys, + typename X, + enable_if_t::value && !std::is_floating_point::value, bool> = false> __host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) { static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same"); @@ -52,7 +58,10 @@ __host__ __device__ constexpr auto operator-=(Tuple& y, const X& x) return y; } -template +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> __host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) { static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); @@ -63,7 +72,10 @@ __host__ __device__ constexpr auto operator+(const Tuple& x, const Y& y) return r; } -template +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> __host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) { static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); @@ -74,7 +86,10 @@ __host__ __device__ constexpr auto operator-(const Tuple& x, const Y& y) return r; } -template +template < + typename... Xs, + typename Y, + enable_if_t::value && !std::is_floating_point::value, bool> = false> __host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) { static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); @@ -85,9 +100,11 @@ __host__ __device__ constexpr auto operator*(const Tuple& x, const Y& y) return r; } -// MultiIndex = index_t * MultiIndex -template -__host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) +// MultiIndex = scalar * MultiIndex +template ::value || std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator*(Y a, const Tuple& x) { constexpr index_t NSize = sizeof...(Xs); @@ -96,13 +113,40 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple& x) return r; } -// MultiIndex = MultiIndex * index_t -template -__host__ __device__ constexpr auto operator*(const Tuple& x, index_t a) +// MultiIndex = MultiIndex * scalar +template ::value || std::is_floating_point::value, bool> = false> +__host__ __device__ constexpr auto operator*(const Tuple& x, Y a) { return a * x; } +namespace mathext { + +template +__host__ __device__ constexpr auto exp(const Tuple& x) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); }); + return r; +} + +template +__host__ __device__ constexpr auto max(const Tuple& x, const Y& y) +{ + static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same"); + constexpr index_t NSize = sizeof...(Xs); + + Tuple r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); }); + return r; +} + +} // namespace mathext + template __host__ __device__ void print_multi_index(const Tuple& x) { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp index 97ce3dcacd..269126432b 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp @@ -16,6 +16,7 @@ namespace host { template @@ -58,7 +59,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator auto f_gmk_gkn_gmn = [&](auto g, auto m, auto n) { const int K = arg.a_g_m_k_.mDesc.GetLengths()[2]; - float v_acc = 0; + AccDataType v_acc = 0; for(int k = 0; k < K; ++k) { @@ -68,10 +69,11 @@ struct ReferenceBatchedGemm : public device::BaseOperator arg.a_element_op_(v_a, arg.a_g_m_k_(g, m, k)); arg.b_element_op_(v_b, arg.b_g_k_n_(g, k, n)); - v_acc += ck::type_convert(v_a) * ck::type_convert(v_b); + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); } - float v_c; + AccDataType v_c; arg.c_element_op_(v_c, v_acc); @@ -81,8 +83,7 @@ struct ReferenceBatchedGemm : public device::BaseOperator make_ParallelTensorFunctor(f_gmk_gkn_gmn, arg.c_g_m_n_.mDesc.GetLengths()[0], arg.c_g_m_n_.mDesc.GetLengths()[1], - arg.c_g_m_n_.mDesc.GetLengths()[2])( - std::thread::hardware_concurrency()); + arg.c_g_m_n_.mDesc.GetLengths()[2])(); return 0; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp new file mode 100644 index 0000000000..d553f981d1 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm> +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemm; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) + { + add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + op_ptrs); + } + } + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/utility/host_tensor_generator.hpp b/library/include/ck/library/utility/host_tensor_generator.hpp index e0bd4991ef..b2edaa0eb3 100644 --- a/library/include/ck/library/utility/host_tensor_generator.hpp +++ b/library/include/ck/library/utility/host_tensor_generator.hpp @@ -151,3 +151,22 @@ struct GeneratorTensor_Sequential return dims[Dim]; } }; + +template +struct GeneratorTensor_Diagonal +{ + T value{1}; + + template + T operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + size_t start_dim = dims.size() - NumEffectiveDim; + bool pred = true; + for(size_t i = start_dim + 1; i < dims.size(); i++) + { + pred &= (dims[start_dim] == dims[i]); + } + return pred ? value : T{0}; + } +}; diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 0a50d37c8a..115040eef7 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(gemm_reduce) add_subdirectory(gemm_bias_add_reduce) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) +add_subdirectory(batched_gemm_softmax_gemm) add_subdirectory(grouped_gemm) add_subdirectory(contraction_scale) add_subdirectory(contraction_bilinear) diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000..5e14c5ebb2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1,8 @@ +set(DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_INSTANCE_SOURCE + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp +) + +add_instance_library(device_batched_gemm_softmax_gemm_instance OBJECT ${DEVICE_BATCHED_GEMM_SOFTMAX_GEMM_INSTANCE_SOURCE}) +target_compile_features(device_batched_gemm_softmax_gemm_instance PUBLIC) +set_target_properties(device_batched_gemm_softmax_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ON) +clang_tidy_check(device_batched_gemm_softmax_gemm_instance) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp new file mode 100644 index 0000000000..4de2428775 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm/device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// c[g, m, n] = a[g, m, k] * b[g, n, k] +using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = + std::tuple< + // clang-format off + //#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>, + DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8> + // clang-format on + >; + +void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profile_batched_gemm_impl.hpp b/profiler/include/profile_batched_gemm_impl.hpp index d50710a3c1..3d9df4c81f 100644 --- a/profiler/include/profile_batched_gemm_impl.hpp +++ b/profiler/include/profile_batched_gemm_impl.hpp @@ -101,6 +101,7 @@ bool profile_batched_gemm_impl(int do_verification, ck::tensor_operation::host::ReferenceBatchedGemm; diff --git a/profiler/include/profile_batched_gemm_reduce_impl.hpp b/profiler/include/profile_batched_gemm_reduce_impl.hpp index 5f1aa0a980..9807e020f5 100644 --- a/profiler/include/profile_batched_gemm_reduce_impl.hpp +++ b/profiler/include/profile_batched_gemm_reduce_impl.hpp @@ -155,6 +155,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification, ck::tensor_operation::host::ReferenceBatchedGemm; diff --git a/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp b/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp new file mode 100644 index 0000000000..48f722830c --- /dev/null +++ b/profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp @@ -0,0 +1,325 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_batched_gemm_softmax_gemm_impl(bool do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int O, + int BatchCount = 1, + int StrideA = -1, + int StrideB0 = -1, + int StrideB1 = -1, + int StrideC = -1, + int BatchStrideA = -1, + int BatchStrideB0 = -1, + int BatchStrideB1 = -1, + int BatchStrideC = -1) + +{ + + using Row = tensor_layout::gemm::RowMajor; + using Col = tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + using AElementOp = PassThrough; + using B0ElementOp = PassThrough; + using Acc0ElementOp = PassThrough; + using B1ElementOp = PassThrough; + using CElementOp = PassThrough; + using AccDataType = float; + + // Ref Gemm0: various type in, fp32 out + using ReferenceGemm0Instance = tensor_operation::host::ReferenceBatchedGemm; + + // Ref Softmax: fp32 in, various type out + using ReferenceSoftmaxInstance = + tensor_operation::host::ReferenceSoftmax; + + // Ref Gemm1: various type in, various type out + using ReferenceGemm1Instance = tensor_operation::host::ReferenceBatchedGemm; + + bool pass = true; + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB0 = ck::is_same_v ? N : K; + const int DefaultStrideB1 = ck::is_same_v ? O : N; + const int DefaultStrideC = ck::is_same_v ? O : M; + + StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; + StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; + StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; + StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; + + const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; + const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; + const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; + const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; + + BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; + BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; + BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; + BatchStrideC = BatchStrideC < 0 ? DefaultBatchStrideC : BatchStrideC; + + auto f_host_tensor_descriptor = [](std::size_t batch_count, + std::size_t row, + std::size_t col, + std::size_t stride, + std::size_t batch_stride, + auto layout) { + if(std::is_same::value) + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({batch_count, row, col}), + std::vector({batch_stride, 1, stride})); + } + }; + + // C_m_o = A_m_k * B0_k_n * B1_n_o + Tensor a_g_m_k( + f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); + Tensor b0_g_k_n( + f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); + Tensor b1_g_n_o( + f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); + Tensor c_g_m_o_host_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + Tensor c_g_m_o_device_result( + f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); + // Host verification: Output of Gemm0 is input A of Gemm1 + Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); + + std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; + std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; + std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; + std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + case 3: + a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + break; + default: + a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); + b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); + } + + DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSize()); + DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSize()); + DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSize()); + DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * c_g_m_o_device_result.mDesc.GetElementSize()); + + a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); + b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); + b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); + + auto a_element_op = AElementOp{}; + auto b0_element_op = B0ElementOp{}; + auto acc0_element_op = Acc0ElementOp{}; + auto b1_element_op = B1ElementOp{}; + auto c_element_op = CElementOp{}; + + using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemm; + + // get device op instances + const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + if(do_verification) + { + auto ref_gemm0 = ReferenceGemm0Instance{}; + auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); + auto ref_gemm0_argument = ref_gemm0.MakeArgument( + a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, PassThrough{}); + + ref_gemm0_invoker.Run(ref_gemm0_argument); + + auto ref_softmax = ReferenceSoftmaxInstance{}; + auto ref_softmax_invoker = ref_softmax.MakeInvoker(); + auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); + + ref_softmax_invoker.Run(ref_softmax_argument); + + auto ref_gemm1 = ReferenceGemm1Instance{}; + auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); + auto ref_gemm1_argument = ref_gemm1.MakeArgument( + a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); + + ref_gemm1_invoker.Run(ref_gemm1_argument); + } + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + // profile device op instances + for(auto& op_ptr : op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer( + static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), + static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), + static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), + static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), + M, + N, + K, + O, + BatchCount, + StrideA, + StrideB0, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB0, + BatchStrideB1, + BatchStrideC, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + std::string op_name = op_ptr->GetTypeString(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; + std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + + sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * + BatchCount; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); + + pass = pass & + ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData); + + if(do_log) + { + LogRangeAsType(std::cout << "a_g_m_k: ", a_g_m_k.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b0_g_k_n : ", b0_g_k_n.mData, ",") + << std::endl; + LogRangeAsType(std::cout << "b1_g_n_o : ", b1_g_n_o.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_host_result : ", c_g_m_o_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_g_m_o_device_result : ", c_g_m_o_device_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index eca4df2c8f..172d1fa6e8 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -40,6 +40,7 @@ add_subdirectory(gemm_split_k) add_subdirectory(gemm_reduce) add_subdirectory(batched_gemm) add_subdirectory(batched_gemm_reduce) +add_subdirectory(batched_gemm_softmax_gemm) add_subdirectory(grouped_gemm) add_subdirectory(reduce) add_subdirectory(convnd_fwd) diff --git a/test/batched_gemm_softmax_gemm/CMakeLists.txt b/test/batched_gemm_softmax_gemm/CMakeLists.txt new file mode 100644 index 0000000000..1ceecefb5f --- /dev/null +++ b/test/batched_gemm_softmax_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +add_custom_target(test_batched_gemm_softmax_gemm) + +add_gtest_executable(test_batched_gemm_softmax_gemm_fp16 test_batched_gemm_softmax_gemm_fp16.cpp) +target_link_libraries(test_batched_gemm_softmax_gemm_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_instance) +add_dependencies(test_batched_gemm_softmax_gemm test_batched_gemm_softmax_gemm_fp16) \ No newline at end of file diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp new file mode 100644 index 0000000000..7b79c975db --- /dev/null +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_fp16.cpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "test_batched_gemm_softmax_gemm_util.hpp" + +template +class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + std::tuple + >; +// clang-format on + +TYPED_TEST_SUITE(TestBatchedGemmSoftmaxGemmFP16, KernelTypes); + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, Test_FP16) { this->Run(); } + +TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16) +{ + this->lengths_ = std::vector>{ + {256, 256, 64, 64, 768}, + {256, 256, 128, 128, 768}, + {512, 512, 64, 64, 768}, + {512, 512, 128, 128, 768}, + {1024, 1024, 64, 64, 768}, + {1024, 1024, 128, 128, 768}, + {2048, 2048, 64, 64, 768}, + {2048, 2048, 128, 128, 768}, + {4096, 4096, 64, 64, 768}, + {4096, 4096, 128, 128, 768}, + }; + this->bench_ = true; + this->verify_ = false; + this->Run(); +} diff --git a/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp new file mode 100644 index 0000000000..d51b4feda6 --- /dev/null +++ b/test/batched_gemm_softmax_gemm/test_batched_gemm_softmax_gemm_util.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include "profiler/include/profile_batched_gemm_softmax_gemm_impl.hpp" + +template +using I = ck::Number; + +using F16 = ck::half_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +struct TestBatchedGemmSoftmaxGemm : public ::testing::Test +{ + using ADataType = std::tuple_element_t<0, Tuple>; + using B0DataType = std::tuple_element_t<1, Tuple>; + using B1DataType = std::tuple_element_t<2, Tuple>; + using CDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using B0Layout = std::tuple_element_t<5, Tuple>; + using B1Layout = std::tuple_element_t<6, Tuple>; + using CLayout = std::tuple_element_t<7, Tuple>; + + std::vector> lengths_ = { + {256, 256, 64, 64, 4}, + {256, 256, 128, 128, 4}, + {512, 512, 64, 64, 2}, + {512, 512, 128, 128, 2}, + {1024, 1024, 64, 64, 1}, + {1024, 1024, 128, 128, 1}, + }; + bool bench_ = false; + bool verify_ = true; + + void RunSingle(int M, int N, int K, int O, int BatchCount) + { + bool pass = ck::profiler::profile_batched_gemm_softmax_gemm_impl( + verify_, 1, false, bench_, M, N, K, O, BatchCount); + + EXPECT_TRUE(pass); + } + + void Run() + { + for(auto lengths : this->lengths_) + { + int M = lengths[0]; + int N = lengths[1]; + int K = lengths[2]; + int O = lengths[3]; + int BatchCount = lengths[4]; + + this->RunSingle(M, N, K, O, BatchCount); + } + } +};