From 7c598d5d41abb1a8f3e541c5de1b7c0dc2e13c52 Mon Sep 17 00:00:00 2001
From: Anthony Chang
Date: Tue, 20 Sep 2022 05:09:44 +0800
Subject: [PATCH] Grouped batched attention + permute (#412)

* grouped attn without batch validates; now move toward grouped batched attn
* grouped batched attention
* working
* remove debug logging; clean up
* reintroduce g_ prefix back to host tensor variables
* format
* rename file
* restore old file
* rename
* consolidate padded/non-padded attention example
* harmonize padding specialization in attn examples

[ROCm/composable_kernel commit: 9287b7c6b3756f7aae37aeee3e772672e7add404]
---
 .../CMakeLists.txt                            |  10 +-
 ...mm_scale_softmax_gemm_permute_xdl_fp16.cpp |   6 +-
 ...tched_gemm_scale_softmax_gemm_xdl_fp16.cpp |   8 +-
 ...mm_scale_softmax_gemm_permute_xdl_fp16.cpp | 443 +++++++++
 ...tched_gemm_scale_softmax_gemm_xdl_fp16.cpp | 397 --------
 .../device_batched_gemm_gemm_xdl_cshuffle.hpp |   8 +-
 ...vice_grouped_gemm_softmax_gemm_permute.hpp |  69 ++
 ...gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 929 ++++++++++++++++++
 .../gpu/grid/block_to_ctile_map.hpp           |  44 +
 9 files changed, 1499 insertions(+), 415 deletions(-)
 create mode 100644 example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
 delete mode 100644 example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
 create mode 100644 include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp

diff --git a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
index 3eda09bf5c..df0566c214 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
+++ b/example/32_batched_gemm_scale_softmax_gemm/CMakeLists.txt
@@ -1,8 +1,8 @@
 add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
 add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
-add_example_executable(example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16 padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
+add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
 
-add_custom_target(example_batched_gemm_scale_softmax_gemm)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
-add_dependencies(example_batched_gemm_scale_softmax_gemm example_padded_batched_gemm_scale_softmax_gemm_xdl_fp16)
+add_custom_target(example_gemm_scale_softmax_gemm)
+add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_fp16)
+add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16)

diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
index 12f9bcb5d3..55a8820116 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -58,7 +58,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
 using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;
 
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNOPadding;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 
 using DeviceGemmInstance =
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
@@ -149,8 +149,8 @@ int main(int argc, char* argv[])
 
     // GEMM shape for A/B0/B1/C
     // C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
-    ck::index_t M = 128;
-    ck::index_t N = 1024;
+    ck::index_t M = 120;
+    ck::index_t N = 1000;
     ck::index_t K = 64;
     ck::index_t O = 128;
     ck::index_t StrideA = -1;
diff --git a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
index bb0af9caa9..de18f58ecd 100644
--- a/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
+++ b/example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_xdl_fp16.cpp
@@ -55,7 +55,7 @@ using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
 using B1ElementOp = PassThrough;
 using CElementOp = PassThrough;
 
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<
     ALayout,
@@ -73,7 +73,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftma
     Acc0ElementOp,
     B1ElementOp,
     CElementOp,
-    GemmDefault,
+    GemmSpec,
     1,
     256,
     128, // MPerBlock
@@ -144,8 +144,8 @@ int main(int argc, char* argv[])
     bool time_kernel = false;
 
     // GEMM shape
-    ck::index_t M = 1024;
-    ck::index_t N = 1024;
+    ck::index_t M = 1020;
+    ck::index_t N = 1020;
     ck::index_t K = 64;
     ck::index_t O = 128;
     ck::index_t BatchCount = 4;
diff --git a/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
new file mode 100644
index 0000000000..273afdad6a
--- /dev/null
+++ b/example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp
@@ -0,0 +1,443 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+/*
+Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o
+                                                                  |-----------------|
+                                                                          Gemm0
+                                                                  |-------------------------------------|
+                                                                                                    Gemm1
+*/
+
+#include <iostream>
+#include <numeric>
+#include <initializer_list>
+#include <cstdlib>
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
+#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+
+template <ck::index_t... Is>
+using S = ck::Sequence<Is...>;
+
+using F16 = ck::half_t;
+using F32 = float;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+
+using ADataType        = F16;
+using B0DataType       = F16;
+using B1DataType       = F16;
+using AccDataType      = F32;
+using CShuffleDataType = F32;
+using CDataType        = F16;
+
+using ALayout  = Row;
+using B0Layout = Col;
+using B1Layout = Row;
+
+using CPermuteNumDims_G_M_O =
+    S<1, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
+
+using AElementOp    = PassThrough;
+using B0ElementOp   = PassThrough;
+using Acc0ElementOp = ck::tensor_operation::element_wise::Scale;
+using B1ElementOp   = PassThrough;
+using CElementOp    = PassThrough;
+
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding;
+
+using DeviceGemmInstance =
+    ck::tensor_operation::device::DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle<
+        ALayout,
+        B0Layout,
+        B1Layout,
+        CPermuteNumDims_G_M_O,
+        ADataType,
+        B0DataType,
+        B1DataType,
+        CDataType,
+        AccDataType,
+        CShuffleDataType,
+        AElementOp,
+        B0ElementOp,
+        Acc0ElementOp,
+        B1ElementOp,
+        CElementOp,
+        GemmSpec,
+        1,
+        256,
+        128,         // MPerBlock
+        128,         // NPerBlock
+        32,          // KPerBlock
+        64,          // Gemm1NPerBlock
+        32,          // Gemm1KPerBlock
+        8,           // AK1
+        8,           // BK1
+        2,           // B1K1
+        32,          // MPerXDL
+        32,          // NPerXDL
+        1,           // MXdlPerWave
+        4,           // NXdlPerWave
+        2,           // Gemm1NXdlPerWave
+        S<4, 64, 1>, // ABlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<4, 64, 1>, // BBlockTransfer
+        S<1, 0, 2>,
+        S<1, 0, 2>,
+        2,
+        8,
+        8,
+        true,
+        S<16, 16, 1>, // B1BlockTransfer
+        S<0, 2, 1>,
+        S<0, 2, 1>,
+        1,
+        4,
+        2,
+        false,
+        1,              // CShuffleMXdlPerWavePerShuffle
+        2,              // CShuffleNXdlPerWavePerShuffle
+        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
+        8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock
+
+// Ref Gemm0: fp16 in, fp32 out
+using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                                                B0DataType,
+                                                                                AccDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                B0ElementOp,
+                                                                                Acc0ElementOp>;
+
+// Ref Softmax: fp32 in, fp16 out
+using ReferenceSoftmaxInstance =
+    ck::tensor_operation::host::ReferenceSoftmax<AccDataType, ADataType, AccDataType>;
+
+// Ref Gemm1: fp16 in, fp16 out
+using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                                                B1DataType,
+                                                                                CDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                B1ElementOp,
+                                                                                CElementOp>;
+
+int main(int argc, char* argv[])
+{
+    bool do_verification = true;
+    int init_method      = 1;
+    bool time_kernel     = false;
+
+    if(argc == 1)
+    {
+        // use default case
+    }
+    else if(argc == 4)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+    }
+    else
+    {
+        printf("arg1: verification (0=no, 1=yes)\n");
+        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
+        printf("arg3: time kernel (0=no, 1=yes)\n");
+        exit(0);
+    }
+
+    float alpha = 1; // scaling after 1st gemm
+
+    std::size_t group_count = 13;
+
+    // Problem descs
+    std::vector<DeviceGemmInstance::ProblemDesc> problem_descs;
+    std::vector<const void*> p_a;
+    std::vector<const void*> p_b0;
+    std::vector<const void*> p_b1;
+    std::vector<void*> p_c;
+
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        int M     = 128 * (rand() % 8 + 1);
+        int N     = 128 * (rand() % 8 + 1);
+        int K     = 64;
+        int O     = 64 * (rand() % 2 + 1);
+        int Batch = rand() % 8 + 1;
+
+        const int StrideA  = ck::is_same_v<ALayout, Row> ? K : M;
+        const int StrideB0 = ck::is_same_v<B0Layout, Row> ? N : K;
+        const int StrideB1 = ck::is_same_v<B1Layout, Row> ? O : N;
+
+        const int BatchStrideA  = (ck::is_same_v<ALayout, Col> ? K : M) * StrideA;
+        const int BatchStrideB0 = (ck::is_same_v<B0Layout, Col> ? N : K) * StrideB0;
+        const int BatchStrideB1 = (ck::is_same_v<B1Layout, Col> ? O : N) * StrideB1;
+
+        std::vector<ck::index_t> c_gs_ms_os_lengths{Batch, M, O};
+        std::vector<ck::index_t> c_gs_ms_os_strides{O, Batch * O, 1};
+
+        problem_descs.push_back({M,
+                                 N,
+                                 K,
+                                 O,
+                                 Batch,
+                                 StrideA,
+                                 StrideB0,
+                                 StrideB1,
+                                 BatchStrideA,
+                                 BatchStrideB0,
+                                 BatchStrideB1,
+                                 c_gs_ms_os_lengths,
+                                 c_gs_ms_os_strides});
+    }
+
+    auto f_host_tensor_descriptor = [](std::size_t batch_count,
+                                       std::size_t row,
+                                       std::size_t col,
+                                       std::size_t stride,
+                                       std::size_t batch_stride,
+                                       auto layout) {
+        if(std::is_same<decltype(layout), Row>::value)
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
+        }
+        else
+        {
+            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
+        }
+    };
+
+    std::vector<Tensor<ADataType>> a_tensors;
+    std::vector<Tensor<B0DataType>> b0_tensors;
+    std::vector<Tensor<B1DataType>> b1_tensors;
+    std::vector<Tensor<CDataType>> c_tensors;
+
+    using DeviceMemPtr = std::unique_ptr<DeviceMem>;
+
+    std::vector<DeviceMemPtr> a_tensors_device;
+    std::vector<DeviceMemPtr> b0_tensors_device;
+    std::vector<DeviceMemPtr> b1_tensors_device;
+    std::vector<DeviceMemPtr> c_tensors_device;
+
+    std::size_t flop = 0, num_byte = 0;
+
+    std::cout << "group count " << group_count << ". printing first 4 groups\n";
+    for(std::size_t i = 0; i < group_count; i++)
+    {
+        const auto& M             = problem_descs[i].M;
+        const auto& N             = problem_descs[i].N;
+        const auto& K             = problem_descs[i].K;
+        const auto& O             = problem_descs[i].O;
+        const auto& Batch         = problem_descs[i].Batch;
+        const auto& StrideA       = problem_descs[i].StrideA;
+        const auto& StrideB0      = problem_descs[i].StrideB0;
+        const auto& StrideB1      = problem_descs[i].StrideB1;
+        const auto& BatchStrideA  = problem_descs[i].BatchStrideA;
+        const auto& BatchStrideB0 = problem_descs[i].BatchStrideB0;
+        const auto& BatchStrideB1 = problem_descs[i].BatchStrideB1;
+        const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
+        const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
+
+        // C_m_o = A_m_k * B0_k_n * B1_n_o
+        Tensor<ADataType> a_g_m_k(
+            f_host_tensor_descriptor(Batch, M, K, StrideA, BatchStrideA, ALayout{}));
+        Tensor<B0DataType> b0_g_k_n(
+            f_host_tensor_descriptor(Batch, K, N, StrideB0, BatchStrideB0, B0Layout{}));
+        Tensor<B1DataType> b1_g_n_o(
+            f_host_tensor_descriptor(Batch, N, O, StrideB1, BatchStrideB1, B1Layout{}));
+        Tensor<CDataType> c_gs_ms_os_device_result(
+            std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
+            std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
+
+        flop += (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * Batch;
+        num_byte += (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
+                     sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) *
+                    Batch;
+
+        if(i < 4)
+        {
+            std::cout << "a_g_m_k[" << i << "]: " << a_g_m_k.mDesc << ", "
+                      << "b0_g_k_n[" << i << "]: " << b0_g_k_n.mDesc << ", "
+                      << "b1_g_n_o[" << i << "]: " << b1_g_n_o.mDesc << ", "
+                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc
+                      << std::endl;
+        }
+
+        switch(init_method)
+        {
+        case 0: break;
+        case 1:
+            a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+            b0_g_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+            b1_g_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
+            break;
+        case 2:
+            a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+            b0_g_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+            b1_g_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+            break;
+        case 3:
+            a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
+            b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
+            b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+            break;
+        default:
+            a_g_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
+            b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
+            b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
+        }
+
+        a_tensors.push_back(a_g_m_k);
+        b0_tensors.push_back(b0_g_k_n);
+        b1_tensors.push_back(b1_g_n_o);
+        c_tensors.push_back(c_gs_ms_os_device_result);
+
+        a_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()));
+        b0_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()));
+        b1_tensors_device.emplace_back(
+            std::make_unique<DeviceMem>(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()));
+        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
+            sizeof(CDataType) * c_gs_ms_os_device_result.mDesc.GetElementSpaceSize()));
+
+        a_tensors_device[i]->ToDevice(a_g_m_k.mData.data());
+        b0_tensors_device[i]->ToDevice(b0_g_k_n.mData.data());
+        b1_tensors_device[i]->ToDevice(b1_g_n_o.mData.data());
+
+        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
+        p_b0.push_back(b0_tensors_device[i]->GetDeviceBuffer());
+        p_b1.push_back(b1_tensors_device[i]->GetDeviceBuffer());
+        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
+    }
+
+    auto a_element_op    = AElementOp{};
+    auto b0_element_op   = B0ElementOp{};
+    auto acc0_element_op = Acc0ElementOp{alpha};
+    auto b1_element_op   = B1ElementOp{};
+    auto c_element_op    = CElementOp{};
+
+    // do GEMM
+    auto gemm     = DeviceGemmInstance{};
+    auto invoker  = gemm.MakeInvoker();
+    auto argument = gemm.MakeArgument(p_a,
+                                      p_b0,
+                                      p_b1,
+                                      p_c,
+                                      problem_descs,
+                                      a_element_op,
+                                      b0_element_op,
+                                      acc0_element_op,
+                                      b1_element_op,
+                                      c_element_op);
+
+    // specify workspace for problem_desc
+    DeviceMem problem_desc_workspace(gemm.GetWorkSpaceSize(&argument));
+
+    gemm.SetWorkSpacePointer(&argument, problem_desc_workspace.GetDeviceBuffer());
+
+    if(!gemm.IsSupportedArgument(argument))
+    {
+        std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl;
+
+        return 0;
+    }
+
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
+
+    float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
+
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+    std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
+              << gemm.GetTypeString() << std::endl;
+
+    bool pass = true;
+    if(do_verification)
+    {
+        for(std::size_t i = 0; i < group_count; i++)
+        {
+            const auto& M     = problem_descs[i].M;
+            const auto& N     = problem_descs[i].N;
+            const auto& O     = problem_descs[i].O;
+            const auto& Batch = problem_descs[i].Batch;
+            const auto& c_gs_ms_os_lengths = problem_descs[i].c_gs_ms_os_lengths;
+            const auto& c_gs_ms_os_strides = problem_descs[i].c_gs_ms_os_strides;
+
+            const auto& a_g_m_k  = a_tensors[i];
+            const auto& b0_g_k_n = b0_tensors[i];
+            const auto& b1_g_n_o = b1_tensors[i];
+            auto& c_gs_ms_os_device_result = c_tensors[i];
+            auto& c_gs_ms_os_device_buf    = *c_tensors_device[i];
+
+            Tensor<CDataType> c_gs_ms_os_host_result(
+                std::vector<std::size_t>(c_gs_ms_os_lengths.begin(), c_gs_ms_os_lengths.end()),
+                std::vector<std::size_t>(c_gs_ms_os_strides.begin(), c_gs_ms_os_strides.end()));
+
+            c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
+
+            // Output of Gemm0 is input A of Gemm1
+            Tensor<AccDataType> acc0_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));
+
+            Tensor<ADataType> a1_g_m_n(f_host_tensor_descriptor(Batch, M, N, N, M * N, Row{}));
+
+            Tensor<CDataType> c_g_m_o_host_result(std::vector<int>{Batch, M, O},
+                                                  std::vector<int>{M * O, O, 1});
+
+            auto ref_gemm0          = ReferenceGemm0Instance{};
+            auto ref_gemm0_invoker  = ref_gemm0.MakeInvoker();
+            auto ref_gemm0_argument = ref_gemm0.MakeArgument(
+                a_g_m_k, b0_g_k_n, acc0_m_n, a_element_op, b0_element_op, acc0_element_op);
+
+            ref_gemm0_invoker.Run(ref_gemm0_argument);
+
+            auto ref_softmax          = ReferenceSoftmaxInstance{};
+            auto ref_softmax_invoker  = ref_softmax.MakeInvoker();
+            auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_m_n, a1_g_m_n, 1, 0, {2});
+
+            ref_softmax_invoker.Run(ref_softmax_argument);
+
+            auto ref_gemm1          = ReferenceGemm1Instance{};
+            auto ref_gemm1_invoker  = ref_gemm1.MakeInvoker();
+            auto ref_gemm1_argument = ref_gemm1.MakeArgument(a1_g_m_n,
+                                                             b1_g_n_o,
+                                                             c_g_m_o_host_result,
+                                                             PassThrough{},
+                                                             b1_element_op,
+                                                             c_element_op);
+
+            ref_gemm1_invoker.Run(ref_gemm1_argument);
+
+            // Note: in this example, we merely permute the dimensions by changing underlying
+            // strides so we simply access data as-is
+            c_gs_ms_os_host_result.ForEach(
+                [&](auto& self, auto idx) { self(idx) = c_g_m_o_host_result(idx); });
+
+            bool pass_ =
+                ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
+            pass &= pass_;
+        }
+    }
+
+    return pass ?
0 : 1; +} diff --git a/example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp b/example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp deleted file mode 100644 index 70a22335ac..0000000000 --- a/example/32_batched_gemm_scale_softmax_gemm/padded_batched_gemm_scale_softmax_gemm_xdl_fp16.cpp +++ /dev/null @@ -1,397 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. - -/* -Gemm + Softmax + Gemm fused operation. Computes C_g_m_o = Softmax(A_g_m_k * B0_g_k_n) * B1_g_n_o - |-----------------| - Gemm0 - |-------------------------------------| - Gemm1 -*/ - -#include -#include -#include -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/host_tensor.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp" - -template -using S = ck::Sequence; - -using F16 = ck::half_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ADataType = F16; -using B0DataType = F16; -using B1DataType = F16; -using AccDataType = F32; -using CShuffleDataType = F32; -using CDataType = F16; - -using ALayout = Row; -using B0Layout = Col; -using B1Layout = Row; -using CLayout = Row; - -using AElementOp = PassThrough; -using B0ElementOp = PassThrough; -using Acc0ElementOp = ck::tensor_operation::element_wise::Scale; -using B1ElementOp = PassThrough; -using CElementOp = PassThrough; - -static constexpr auto MNPadding = ck::tensor_operation::device::GemmSpecialization::MNPadding; - -using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< - ALayout, - B0Layout, - B1Layout, - CLayout, - ADataType, - B0DataType, - B1DataType, - CDataType, - AccDataType, - CShuffleDataType, - AElementOp, - B0ElementOp, - Acc0ElementOp, - B1ElementOp, - CElementOp, - MNPadding, - 1, - 256, - 128, // MPerBlock - 128, // NPerBlock - 32, // KPerBlock - 64, // Gemm1NPerBlock - 32, // Gemm1KPerBlock - 8, // AK1 - 8, // BK1 - 2, // B1K1 - 32, // MPerXDL - 32, // NPerXDL - 1, // MXdlPerWave - 4, // NXdlPerWave - 2, // Gemm1NXdlPerWave - S<4, 64, 1>, // ABlockTransfer - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<4, 64, 1>, // BBlockTransfer - S<1, 0, 2>, - S<1, 0, 2>, - 2, - 8, - 8, - true, - S<16, 16, 1>, // B1BlockTransfer - S<0, 2, 1>, - S<0, 2, 1>, - 1, - 4, - 2, - false, - 1, // CShuffleMXdlPerWavePerShuffle - 2, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8>; // CShuffleBlockTransferScalarPerVector_NPerBlock - -// Ref Gemm0: fp16 in, fp32 out -using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceBatchedGemm; - -// Ref Softmax: fp32 in, fp16 out -using ReferenceSoftmaxInstance = - ck::tensor_operation::host::ReferenceSoftmax; - -// Ref Gemm1: fp16 in, fp16 out -using ReferenceGemm1Instance = 
ck::tensor_operation::host::ReferenceBatchedGemm; - -int main(int argc, char* argv[]) -{ - bool do_verification = true; - int init_method = 1; - bool time_kernel = false; - - // GEMM shape - ck::index_t M = 1020; - ck::index_t N = 1020; - ck::index_t K = 64; - ck::index_t O = 128; - ck::index_t BatchCount = 4; - ck::index_t StrideA = -1; - ck::index_t StrideB0 = -1; - ck::index_t StrideB1 = -1; - ck::index_t StrideC = -1; - ck::index_t BatchStrideA = -1; - ck::index_t BatchStrideB0 = -1; - ck::index_t BatchStrideB1 = -1; - ck::index_t BatchStrideC = -1; - float alpha = 1; - - if(argc == 1) - { - // use default case - } - else if(argc == 4) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - } - else if(argc == 9) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - O = std::stoi(argv[7]); - - BatchCount = std::stoi(argv[8]); - } - else if(argc == 18) - { - do_verification = std::stoi(argv[1]); - init_method = std::stoi(argv[2]); - time_kernel = std::stoi(argv[3]); - - M = std::stoi(argv[4]); - N = std::stoi(argv[5]); - K = std::stoi(argv[6]); - O = std::stoi(argv[7]); - - BatchCount = std::stoi(argv[8]); - - StrideA = std::stoi(argv[9]); - StrideB0 = std::stoi(argv[10]); - StrideB1 = std::stoi(argv[11]); - StrideC = std::stoi(argv[12]); - - BatchStrideA = std::stoi(argv[13]); - BatchStrideB0 = std::stoi(argv[14]); - BatchStrideB1 = std::stoi(argv[15]); - BatchStrideC = std::stoi(argv[16]); - - alpha = std::stof(argv[17]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=no, 1=yes)\n"); - printf("arg4 to 16: M, N, K, O, Batch, StrideA, StrideB0, StrideB1, StrideC, BatchStrideA, " - "BatchStrideB0, BatchStrideB1, BatchStrideC\n"); - printf("arg17: scale (alpha)\n"); - exit(0); - } - - const int DefaultStrideA = ck::is_same_v ? K : M; - const int DefaultStrideB0 = ck::is_same_v ? N : K; - const int DefaultStrideB1 = ck::is_same_v ? O : N; - const int DefaultStrideC = ck::is_same_v ? O : M; - - StrideA = (StrideA < 0) ? DefaultStrideA : StrideA; - StrideB0 = (StrideB0 < 0) ? DefaultStrideB0 : StrideB0; - StrideB1 = (StrideB1 < 0) ? DefaultStrideB1 : StrideB1; - StrideC = (StrideC < 0) ? DefaultStrideC : StrideC; - - const int DefaultBatchStrideA = (ck::is_same_v ? K : M) * StrideA; - const int DefaultBatchStrideB0 = (ck::is_same_v ? N : K) * StrideB0; - const int DefaultBatchStrideB1 = (ck::is_same_v ? O : N) * StrideB1; - const int DefaultBatchStrideC = (ck::is_same_v ? O : M) * StrideC; - - BatchStrideA = BatchStrideA < 0 ? DefaultBatchStrideA : BatchStrideA; - BatchStrideB0 = BatchStrideB0 < 0 ? DefaultBatchStrideB0 : BatchStrideB0; - BatchStrideB1 = BatchStrideB1 < 0 ? DefaultBatchStrideB1 : BatchStrideB1; - BatchStrideC = BatchStrideC < 0 ? 
DefaultBatchStrideC : BatchStrideC; - - auto f_host_tensor_descriptor = [](std::size_t batch_count, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - if(std::is_same::value) - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, stride, 1})); - } - else - { - return HostTensorDescriptor(std::vector({batch_count, row, col}), - std::vector({batch_stride, 1, stride})); - } - }; - - // C_m_o = A_m_k * B0_k_n * B1_n_o - Tensor a_g_m_k( - f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{})); - Tensor b0_g_k_n( - f_host_tensor_descriptor(BatchCount, K, N, StrideB0, BatchStrideB0, B0Layout{})); - Tensor b1_g_n_o( - f_host_tensor_descriptor(BatchCount, N, O, StrideB1, BatchStrideB1, B1Layout{})); - Tensor c_g_m_o_host_result( - f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); - Tensor c_g_m_o_device_result( - f_host_tensor_descriptor(BatchCount, M, O, StrideC, BatchStrideC, CLayout{})); - - std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl; - std::cout << "b0_g_k_n: " << b0_g_k_n.mDesc << std::endl; - std::cout << "b1_g_n_o: " << b1_g_n_o.mDesc << std::endl; - std::cout << "c_g_m_o: " << c_g_m_o_host_result.mDesc << std::endl; - - switch(init_method) - { - case 0: break; - case 1: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_2{-5, 5}); - break; - case 2: - a_g_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); - break; - case 3: - a_g_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Diagonal{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); - break; - default: - a_g_m_k.GenerateTensorValue(GeneratorTensor_1{1}); - b0_g_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{}); - b1_g_n_o.GenerateTensorValue(GeneratorTensor_Diagonal{}); - } - - DeviceMem a_g_m_k_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpaceSize()); - DeviceMem b0_g_k_n_device_buf(sizeof(B0DataType) * b0_g_k_n.mDesc.GetElementSpaceSize()); - DeviceMem b1_g_n_o_device_buf(sizeof(B1DataType) * b1_g_n_o.mDesc.GetElementSpaceSize()); - DeviceMem c_g_m_o_device_buf(sizeof(CDataType) * - c_g_m_o_device_result.mDesc.GetElementSpaceSize()); - - a_g_m_k_device_buf.ToDevice(a_g_m_k.mData.data()); - b0_g_k_n_device_buf.ToDevice(b0_g_k_n.mData.data()); - b1_g_n_o_device_buf.ToDevice(b1_g_n_o.mData.data()); - - auto a_element_op = AElementOp{}; - auto b0_element_op = B0ElementOp{}; - auto acc0_element_op = Acc0ElementOp{alpha}; - auto b1_element_op = B1ElementOp{}; - auto c_element_op = CElementOp{}; - - // do GEMM - auto gemm = DeviceGemmInstance{}; - auto invoker = gemm.MakeInvoker(); - auto argument = - gemm.MakeArgument(static_cast(a_g_m_k_device_buf.GetDeviceBuffer()), - static_cast(b0_g_k_n_device_buf.GetDeviceBuffer()), - static_cast(b1_g_n_o_device_buf.GetDeviceBuffer()), - static_cast(c_g_m_o_device_buf.GetDeviceBuffer()), - M, - N, - K, - O, - BatchCount, - StrideA, - StrideB0, - StrideB1, - StrideC, - BatchStrideA, - BatchStrideB0, - BatchStrideB1, - BatchStrideC, - a_element_op, - b0_element_op, - acc0_element_op, - b1_element_op, - c_element_op); - - if(!gemm.IsSupportedArgument(argument)) - { - std::cout << gemm.GetTypeString() << " does not support this 
problem" << std::endl; - - return 0; - } - - float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); - - std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * BatchCount; - std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N + - sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O) * - BatchCount; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << gemm.GetTypeString() << std::endl; - - c_g_m_o_device_buf.FromDevice(c_g_m_o_device_result.mData.data()); - - if(do_verification) - { - // Output of Gemm0 is input A of Gemm1 - Tensor acc0_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - - Tensor a1_g_m_n(f_host_tensor_descriptor(BatchCount, M, N, N, M * N, Row{})); - - auto ref_gemm0 = ReferenceGemm0Instance{}; - auto ref_gemm0_invoker = ref_gemm0.MakeInvoker(); - auto ref_gemm0_argument = ref_gemm0.MakeArgument( - a_g_m_k, b0_g_k_n, acc0_g_m_n, a_element_op, b0_element_op, acc0_element_op); - - ref_gemm0_invoker.Run(ref_gemm0_argument); - - auto ref_softmax = ReferenceSoftmaxInstance{}; - auto ref_softmax_invoker = ref_softmax.MakeInvoker(); - auto ref_softmax_argument = ref_softmax.MakeArgument(acc0_g_m_n, a1_g_m_n, 1, 0, {2}); - - ref_softmax_invoker.Run(ref_softmax_argument); - - auto ref_gemm1 = ReferenceGemm1Instance{}; - auto ref_gemm1_invoker = ref_gemm1.MakeInvoker(); - auto ref_gemm1_argument = ref_gemm1.MakeArgument( - a1_g_m_n, b1_g_n_o, c_g_m_o_host_result, PassThrough{}, b1_element_op, c_element_op); - - ref_gemm1_invoker.Run(ref_gemm1_argument); - - return ck::utils::check_err(c_g_m_o_device_result.mData, c_g_m_o_host_result.mData) ? 
0 : 1;
-    }
-
-    return 0;
-}
diff --git a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
index 2f245ccfd0..3b87e56337 100644
--- a/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
+++ b/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -503,13 +503,9 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm
diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
new file mode 100644
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp
@@ -0,0 +1,69 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "device_base.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+
+template <typename ALayout,
+          typename B0Layout,
+          typename B1Layout,
+          typename CPermuteNumDims_G_M_O,
+          typename ADataType,
+          typename B0DataType,
+          typename B1DataType,
+          typename CDataType,
+          typename AElementwiseOperation,
+          typename B0ElementwiseOperation,
+          typename Acc0ElementwiseOperation,
+          typename B1ElementwiseOperation,
+          typename CElementwiseOperation>
+struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator
+{
+    struct ProblemDesc
+    {
+        // Overall problem shape
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t O;
+        index_t Batch;
+
+        // Stride for A/B0/B1; layout determined by template args
+        index_t StrideA;
+        index_t StrideB0;
+        index_t StrideB1;
+        index_t BatchStrideA;
+        index_t BatchStrideB0;
+        index_t BatchStrideB1;
+
+        // Lengths and strides for output C
+        std::vector<index_t> c_gs_ms_os_lengths;
+        std::vector<index_t> c_gs_ms_os_strides;
+    };
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::vector<const void*> p_a_vec,
+                        std::vector<const void*> p_b0_vec,
+                        std::vector<const void*> p_b1_vec,
+                        std::vector<void*> p_c_vec,
+                        std::vector<ProblemDesc> problem_desc_vec,
+                        AElementwiseOperation a_element_op,
+                        B0ElementwiseOperation b0_element_op,
+                        Acc0ElementwiseOperation acc0_element_op,
+                        B1ElementwiseOperation b1_element_op,
+                        CElementwiseOperation c_element_op) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+};
+
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
new file mode 100644
index 0000000000..6aa6e3d8cf
--- /dev/null
+++ b/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
@@ -0,0 +1,929 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(group_kernel_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= arg_ptr[group_id].block_start_ && + block_id < arg_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // per-group batch offset + const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; + const index_t g_idx = __builtin_amdgcn_readfirstlane( + (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run( + arg_ptr[group_id].p_a_grid_ + a_batch_offset, + arg_ptr[group_id].p_b_grid_ + b_batch_offset, + arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, + arg_ptr[group_id].p_c_grid_ + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg_ptr[group_id].block_2_ctile_map_); +#else + ignore = group_kernel_args; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; +#endif // end of if 
(defined(__gfx908__) || defined(__gfx90a__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template + typename ADataType, + typename BDataType, + typename B1DataType, + typename CDataType, + typename GemmAccDataType, + typename CShuffleDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename AccElementwiseOperation, + typename B1ElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, // Gemm0NPerBlock + index_t KPerBlock, // Gemm0KPerBlock + index_t Gemm1NPerBlock, + index_t Gemm1KPerBlock, + index_t AK1, + index_t BK1, + index_t B1K1, + index_t MPerXDL, + index_t NPerXDL, + index_t MXdlPerWave, + index_t NXdlPerWave, + index_t Gemm1NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + typename B1BlockTransferThreadClusterLengths_BK0_N_BK1, + typename B1BlockTransferThreadClusterArrangeOrder, + typename B1BlockTransferSrcAccessOrder, + index_t B1BlockTransferSrcVectorDim, + index_t B1BlockTransferSrcScalarPerVector, + index_t B1BlockTransferDstScalarPerVector_BK1, + bool B1BlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = LoopScheduler::Default> +struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle + : public DeviceGroupedGemmSoftmaxGemmPermute +{ + using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; + using ProblemDesc = + typename DeviceGroupedGemmSoftmaxGemmPermute::ProblemDesc; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + + // FIXME: pad K + static_assert(!matrix_padder.PadK, "KPadding is currently not supported"); + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t 
KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_ns_lengths_vec, + const std::vector& c_gs_ms_ns_strides_vec) + { + constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); + constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); + constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N + + assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto c_ms_ns_lengths = to_tuple( + c_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + c_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(c_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(c_ms_ns_lengths, nDimIds); + + // naive tensor C[M0, M1, M2, ..., N0, N1, N2...] + const auto c_grid_desc_ms_ns = + make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] 
+ const auto c_grid_desc_mraw_nraw = transform_tensor_descriptor( + c_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_ns_lengths_vec, + const std::vector& c_gs_ms_ns_strides_vec) + { + constexpr index_t NumDimG = CPermuteNumDims_G_M_Gemm1N::At(I0); + constexpr index_t NumDimM = CPermuteNumDims_G_M_Gemm1N::At(I1); + constexpr index_t NumDimN = CPermuteNumDims_G_M_Gemm1N::At(I2); // NumDimGemm1N + + assert(c_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + c_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto c_gs_ms_ns_lengths = + to_tuple(c_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto c_gs_ms_ns_strides = + to_tuple(c_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(c_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(c_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(c_gs_ms_ns_lengths, nDimIds); + + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto c_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(c_gs_ms_ns_lengths, c_gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto c_grid_desc_g_mraw_nraw = + transform_tensor_descriptor(c_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // this desc is only for calculating batch offset so no padding needed + return c_grid_desc_g_mraw_nraw; + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N({}, {})); + using CGridDesc_G_M_N = decltype(MakeCGridDescriptor_G_M_N({}, {})); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + CGridDesc_G_M_N c_grid_desc_g_m_n) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN>; + + using Block2CTileMap = OffsettedBlockToCTileMap; + + struct GroupKernelArg + { + 
// pointers
+        const ADataType* p_a_grid_;
+        const BDataType* p_b_grid_;
+        const B1DataType* p_b1_grid_;
+        CDataType* p_c_grid_;
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
+        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock_;
+
+        // batch & stride
+        index_t num_blocks_per_batch_;
+        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+
+        // block-to-c-tile map
+        Block2CTileMap block_2_ctile_map_;
+
+        index_t block_start_, block_end_;
+    };
+
+    struct GroupDeviceArg
+    {
+        // problem definition
+        index_t M;
+        index_t N;
+        index_t K;
+        index_t O;
+
+        // Extent and stride of the lowest C dimension, for sanity check of vector load/store
+        index_t c_extent_lowest_;
+        index_t c_stride_lowest_;
+
+        CGridDesc_M_N c_grid_desc_m_n_;
+    };
+
+    // Argument
+    // FIXME: constness
+    struct Argument : public BaseArgument
+    {
+        Argument(std::vector<const void*> p_a_vec,
+                 std::vector<const void*> p_b_vec,
+                 std::vector<const void*> p_b1_vec,
+                 std::vector<void*> p_c_vec,
+                 std::vector<ProblemDesc> problem_desc_vec,
+                 AElementwiseOperation a_element_op,
+                 BElementwiseOperation b_element_op,
+                 AccElementwiseOperation acc_element_op,
+                 B1ElementwiseOperation b1_element_op,
+                 CElementwiseOperation c_element_op)
+            : a_element_op_{a_element_op},
+              b_element_op_{b_element_op},
+              acc_element_op_{acc_element_op},
+              b1_element_op_{b1_element_op},
+              c_element_op_{c_element_op}
+        {
+            group_count_ = problem_desc_vec.size();
+
+            if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
+                 group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size()))
+            {
+                throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size");
+            }
+
+            grid_size_ = 0;
+
+            for(std::size_t i = 0; i < group_count_; i++)
+            {
+                const auto p_a_grid  = static_cast<const ADataType*>(p_a_vec[i]);
+                const auto p_b_grid  = static_cast<const BDataType*>(p_b_vec[i]);
+                const auto p_b1_grid = static_cast<const B1DataType*>(p_b1_vec[i]);
+                const auto p_c_grid  = static_cast<CDataType*>(p_c_vec[i]);
+
+                const auto a_grid_desc_ak0_m_ak1 = DeviceOp::MakeAGridDescriptor_AK0_M_AK1(
+                    problem_desc_vec[i].M, problem_desc_vec[i].K, problem_desc_vec[i].StrideA);
+                const auto b_grid_desc_bk0_n_bk1 = DeviceOp::MakeBGridDescriptor_BK0_N_BK1(
+                    problem_desc_vec[i].K, problem_desc_vec[i].N, problem_desc_vec[i].StrideB0);
+                const auto b1_grid_desc_bk0_n_bk1 = DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(
+                    problem_desc_vec[i].N, problem_desc_vec[i].O, problem_desc_vec[i].StrideB1);
+                const auto c_grid_desc_m_n = DeviceOp::MakeCGridDescriptor_M_N(
+                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
+
+                const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
+                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                        c_grid_desc_m_n);
+
+                const index_t BlockStart     = grid_size_;
+                const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);
+                const index_t grid_size_grp =
+                    block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) *
+                    problem_desc_vec[i].Batch;
+                const index_t BlockEnd = grid_size_ + grid_size_grp;
+
+                // batch stride
+                // TODO ANT: only keep batch stride in tensor desc to reduce scalar cache pressure
+                const auto c_grid_desc_g_m_n = DeviceOp::MakeCGridDescriptor_G_M_N(
+                    problem_desc_vec[i].c_gs_ms_os_lengths, problem_desc_vec[i].c_gs_ms_os_strides);
+                const auto compute_base_ptr_of_batch =
+                    ComputeBasePtrOfStridedBatch(problem_desc_vec[i].BatchStrideA,
+                                                 problem_desc_vec[i].BatchStrideB0,
+
problem_desc_vec[i].BatchStrideB1, + c_grid_desc_g_m_n); + + grid_size_ += grid_size_grp; + + group_kernel_args_.push_back({p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), + compute_base_ptr_of_batch, + block_2_ctile_map, + BlockStart, + BlockEnd}); + + group_device_args_.push_back({problem_desc_vec[i].M, + problem_desc_vec[i].N, + problem_desc_vec[i].K, + problem_desc_vec[i].O, + problem_desc_vec[i].c_gs_ms_os_lengths.back(), + problem_desc_vec[i].c_gs_ms_os_strides.back(), + c_grid_desc_m_n}); + } + } + + std::vector group_kernel_args_; + std::vector group_device_args_; + + std::size_t group_count_; + index_t grid_size_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + } + + hipGetErrorString(hipMemcpy(arg.p_workspace_, + arg.group_kernel_args_.data(), + arg.group_kernel_args_.size() * sizeof(GroupKernelArg), + hipMemcpyHostToDevice)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(all_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else if(!some_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + throw std::runtime_error("wrong! 
all gemm problems have to simultaneously meet " + "has_main_k_block_loop or no_main_k_block_loop"); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + { + return false; + } + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto& kernel_arg = arg.group_kernel_args_[i]; + const auto& device_arg = arg.group_device_args_[i]; + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1); + const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + if(!(c_m == a_m && c_gemm1n == b1_gemm1n)) + { + return false; + } + + // Check if having main loop + const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * + kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + + // Note: we need raw lengths since threadwise copy can not handle vector load when + // part of vector is out of bounds + const auto MRaw = device_arg.M; + const auto NRaw = device_arg.N; + const auto KRaw = device_arg.K; + const auto Gemm1NRaw = device_arg.O; + + // Check scalar per vector requirement + const auto a_extent_lowest = + is_same_v ? KRaw : MRaw; + const auto b_extent_lowest = + is_same_v ? NRaw : KRaw; + const auto b1_extent_lowest = + is_same_v ? 
Gemm1NRaw : NRaw; + const auto c_extent_lowest = device_arg.c_extent_lowest_; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Check vector store requirement; assumes last dimension in N to be contiguous + if(device_arg.c_stride_lowest_ != 1) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_, + kernel_arg.b_grid_desc_bk0_n_bk1_, + kernel_arg.b1_grid_desc_bk0_n_bk1_, + device_arg.c_grid_desc_m_n_, + kernel_arg.block_2_ctile_map_)) + { + return false; + } + } + + // all gemm problems have to simultaneously meet has_main_k_block_loop or + // no_main_k_block_loop + if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GroupKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp index 498a88afe0..3591845095 100644 --- a/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp +++ b/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -486,4 +486,48 @@ __host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, return is_valid; } +// This wrapper class is 
for grouped gemm: it subtracts a fixed offset from blockIdx so that the
+// workgroups assigned to a given gemm problem see top indices in the range [0,
+// grid_size_per_gemm)
+template <typename UnderlyingBlockToCTileMap>
+struct OffsettedBlockToCTileMap
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start)
+    {
+        block_to_ctile_map_ = block_to_ctile_map;
+        block_start_        = block_start;
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] - block_start_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
+    }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t block_start_;
+};
+
 } // namespace ck
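
For reference, the per-group workgroup lookup that kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1
performs can be exercised on the host: each group owns the contiguous block-id range
[block_start_, block_end_), and a binary search over those ranges recovers the group index.
Below is a minimal host-side sketch, assuming a hypothetical GroupRange stand-in for the
per-group block_start_/block_end_ fields; it mirrors the kernel's search so the mapping can be
sanity-checked off-device, and is not part of the library API.

#include <cassert>
#include <vector>

// Hypothetical stand-in for the per-group fields the kernel reads.
struct GroupRange
{
    int block_start_; // first block id owned by this group
    int block_end_;   // one past the last block id owned by this group
};

// Mirrors the kernel's binary search: find the group whose
// [block_start_, block_end_) range contains block_id.
int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;

    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = (left + right) / 2;
    }

    return group_id;
}

int main()
{
    // Three groups covering block ids [0, 8), [8, 20), [20, 26).
    std::vector<GroupRange> groups{{0, 8}, {8, 20}, {20, 26}};
    assert(find_group(groups, 0) == 0);
    assert(find_group(groups, 12) == 1);
    assert(find_group(groups, 25) == 2);
    return 0;
}

Once the group is found, the kernel divides (block_id - block_start_) by num_blocks_per_batch_
to obtain the batch index within the group, and OffsettedBlockToCTileMap subtracts the same
block_start_ so the underlying map only ever sees block ids in [0, grid_size_per_gemm).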