From b9f7381f95a1694d7e1b66a6e036930eaa74954d Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Mon, 13 Oct 2025 12:30:28 +0200 Subject: [PATCH] [CK Tile] contraction multi d - kernel & example (#2901) * Initial commit. create batched_contraction_kernel file * initial problem definition * implement initial example to launch kernel * add universal gemm to contraction. initial phase * complete implementation for special case all Dims are 1 and no Ds * clean code * initial changes to support multi dimensional G * more progress in implementing multiple G * tmp commit * manage dynamic NumDimG in kernel * improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit * implement the example for general Multi dimension G M N K and test different reference calculation algorithms * 2 functions for reference using multi dimensional and flat indexing * clean the code for muti dimentional G, M, N, K contraction and add some logs * Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E * some cleaning on kernel * clean the code for calculating the offsets from flatten batch number * Start adding MultiD support to kernel and example * more changes to manage multi D in kernel and example * manage passing multi d to kernel and testing. * complete multi D support in kernel. modify example code to support it * Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning * Minor fix * Generalize example code for variable NumD tensors and apply cleanup based on review feedback * Refactored code and addressed review feedback * refactoring, cleaning, add documents, in kernel side and example codes * Optimize batch offset calculation in kernel * Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> [ROCm/composable_kernel commit: e9f0cc83a8f3f94ad8462e50a9d9a92d8dca3388] --- CHANGELOG.md | 1 + .../41_batched_contraction/CMakeLists.txt | 7 + .../batched_contraction.cpp | 245 ++++++++ .../contraction_utils.hpp | 146 +++++ .../run_batched_contraction_example.inc | 405 ++++++++++++++ example/ck_tile/CMakeLists.txt | 1 + .../reference_batched_contraction.hpp | 265 +++++++++ include/ck_tile/ops/batched_contraction.hpp | 9 + .../kernel/batched_contraction_kernel.hpp | 522 ++++++++++++++++++ .../pipeline/batched_contraction_problem.hpp | 32 ++ .../utils/tensor_descriptor_utils.hpp | 169 ++++++ 11 files changed, 1802 insertions(+) create mode 100644 example/ck_tile/41_batched_contraction/CMakeLists.txt create mode 100644 example/ck_tile/41_batched_contraction/batched_contraction.cpp create mode 100644 example/ck_tile/41_batched_contraction/contraction_utils.hpp create mode 100644 example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc create mode 100644 include/ck_tile/host/reference/reference_batched_contraction.hpp create mode 100644 include/ck_tile/ops/batched_contraction.hpp create mode 100644 include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp create mode 100644 include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp create mode 100644 include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index a8fe7b4afb..9de78f3043 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the row-wise column-wise 
quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. * Added support for f32 to FMHA (fwd/bwd). * Added tensor-wise quantization for CK_TILE GEMM. +* Added support for batched contraction kernel. * Added pooling kernel in CK_TILE ### Optimized diff --git a/example/ck_tile/41_batched_contraction/CMakeLists.txt b/example/ck_tile/41_batched_contraction/CMakeLists.txt new file mode 100644 index 0000000000..10b2e48cbf --- /dev/null +++ b/example/ck_tile/41_batched_contraction/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_example_batched_contraction EXCLUDE_FROM_ALL batched_contraction.cpp) +set(EXAMPLE_CONTRACTION_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_CONTRACTION_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +target_compile_options(tile_example_batched_contraction PRIVATE ${EXAMPLE_CONTRACTION_COMPILE_OPTIONS}) diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp new file mode 100644 index 0000000000..ea78f09dff --- /dev/null +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" + +#include "ck_tile/ops/batched_contraction.hpp" +#include "contraction_utils.hpp" + +template + +float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using Problem = ck_tile::BatchedContractionProblem; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + ck_tile::index_t K_total = 1; + for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i) + { + K_total *= args.A_dims[i]; + } + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = + 
ck_tile::memory_operation_enum::set; // Always set (no atomic_add) + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::GetBlockSize(); + + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); + + ave_time = ck_tile::launch_kernel(s, kernel); + + return ave_time; + }; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + + return ave_time; +} + +#define HANDLE_CASE(G, M, N, K) \ + if(num_g_dims == G && num_m_dims == M && num_n_dims == N && num_k_dims == K) \ + { \ + return batched_contraction_impl(args, s); \ + } + +template +float batched_contraction(const ck_tile::BatchedContractionHostArgs& args, + const ck_tile::stream_config& s, + ck_tile::index_t num_g_dims, + ck_tile::index_t num_m_dims, + ck_tile::index_t num_n_dims, + ck_tile::index_t num_k_dims) +{ + std::cout << "Dimensions: G=" << num_g_dims << ", M=" << num_m_dims << ", N=" << num_n_dims + << ", K=" << num_k_dims << std::endl; + + HANDLE_CASE(1, 1, 1, 1); + HANDLE_CASE(2, 1, 1, 1); + HANDLE_CASE(2, 2, 2, 1); + HANDLE_CASE(1, 2, 1, 1); + HANDLE_CASE(1, 1, 1, 2); + HANDLE_CASE(2, 2, 2, 2); + HANDLE_CASE(4, 4, 4, 4); + + throw std::runtime_error( + "Unsupported dimension combination: G=" + std::to_string(num_g_dims) + + ", M=" + std::to_string(num_m_dims) + ", N=" + std::to_string(num_n_dims) + + ", K=" + std::to_string(num_k_dims) + ". Please add this combination to the kernel."); +} + +#include "run_batched_contraction_example.inc" + +int main(int argc, char* argv[]) +{ + try + { + return !run_batched_contraction_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/41_batched_contraction/contraction_utils.hpp b/example/ck_tile/41_batched_contraction/contraction_utils.hpp new file mode 100644 index 0000000000..6a75f1c04e --- /dev/null +++ b/example/ck_tile/41_batched_contraction/contraction_utils.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +struct AddDs +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... 
ds) const -> void + { + const float x0_f = + ck_tile::type_convert(c) + (ck_tile::type_convert(ds) + ...); + + e = ck_tile::type_convert(x0_f); + } +}; + +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave + +template +struct BatchedContractionTypeConfig +{ + using ADataType = DataType; + using BDataType = DataType; + using AccDataType = float; + using EDataType = DataType; + using DDataType = DataType; +}; + +using ContractionTypes = BatchedContractionTypeConfig; + +using ADataType = ContractionTypes::ADataType; +using BDataType = ContractionTypes::BDataType; +using AccDataType = ContractionTypes::AccDataType; +using EDataType = ContractionTypes::EDataType; +using DDataType = ContractionTypes::DDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)") + .insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)") + .insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)") + .insert( + "g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)") + .insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)") + .insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)") + .insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Col by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("v", "1", "0. No validation, 1. Validation on CPU") + .insert("prec", "fp16", "data type. fp32/fp16/bf16") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "10", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("log", "1", "log level for debugging"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// Helper function to parse G, M, N, K dimensions from string +std::vector parse_dimensions(const std::string& dims_str) +{ + std::vector dims; + std::stringstream ss(dims_str); + std::string token; + + while(std::getline(ss, token, ',')) + { + dims.push_back(std::stoi(token)); + } + + if(dims.empty()) + { + throw std::invalid_argument("Dimensions cannot be empty"); + } + + return dims; +} + +// Helper function to Calculate total elements from multi-dimensional vector +ck_tile::index_t calculate_total_elements(const std::vector& dims) +{ + ck_tile::index_t total = 1; + for(auto dim : dims) + { + total *= dim; + } + return total; +} + +/** + * @brief Flattens a list of tensor dimension components into a single dimension vector. + * + * This function takes a list of dimension vectors (e.g., representing different components + * such as G, M, N, or K dimensions) and concatenates them into a single vector. + * + * Example: + * Input: {{G0, G1}, {M0, M1}, {K0}} + * Output: {G0, G1, M0, M1, K0} + * + * @param dim_components A vector of vectors, where each inner vector represents a set of tensor + * dimensions. + * @return A single vector containing all dimensions concatenated in order. 
+ */ +std::vector +concatenate_dim_components(const std::vector>& dim_components) +{ + std::vector result; + + // Concatenate all dimension components into a single vector + for(const auto& component : dim_components) + { + result.insert(result.end(), component.begin(), component.end()); + } + + return result; +} + +// Helper function for printing dimensions +void print_dims(const std::string& name, + const std::vector& dims, + ck_tile::index_t total) +{ + std::cout << name << ": ["; + for(size_t i = 0; i < dims.size(); ++i) + { + std::cout << dims[i]; + if(i < dims.size() - 1) + std::cout << ","; + } + std::cout << "] "; + if(total != 0) + std::cout << "(total=" << total << ")"; + std::cout << std::endl; +} diff --git a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc new file mode 100644 index 0000000000..9bc09a6c9c --- /dev/null +++ b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include "contraction_utils.hpp" +#include "ck_tile/host/reference/reference_batched_contraction.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_batched_contraction_kernel( + const void* a_full_dims_dev_buf, + const void* b_full_dims_dev_buf, + const std::array& ds_dev_buf, + void* e_full_dims_dev_buf, + const std::vector& G_dims, + const std::vector& M_dims, + const std::vector& N_dims, + const std::vector& K_dims, + const std::vector& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..] + const std::vector& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..] + const std::array, DsDataType::size()>& + Ds_dims, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor] + const std::vector& E_dims, // [G0,G1,..,M0,M1,..,N0,N1,..] + const std::vector& A_strides, // [G0,G1,..,M0,M1,..,K0,K1,..] + const std::vector& B_strides, // [G0,G1,..,N0,N1,..,K0,K1,..] + const std::array, DsDataType::size()>& Ds_strides, + const std::vector& E_strides, // [G0,G1,..,M0,M1,..,N0,N1,..] + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + std::cout << "Creating BatchedContractionHostArgs..." 
<< std::endl; + + ck_tile::BatchedContractionHostArgs args(a_full_dims_dev_buf, // a_ptr + b_full_dims_dev_buf, // b_ptr + ds_dev_buf, // ds_ptr + e_full_dims_dev_buf, // e_ptr + kbatch, // k_batch + A_dims, // A_dims + B_dims, // B_dims + Ds_dims, // Ds_dims + E_dims, // E_dims + A_strides, // A_strides + B_strides, // B_strides + Ds_strides, // Ds_strides + E_strides // E_strides + ); + + std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size() + << ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size() + << std::endl; + + float ave_time = batched_contraction( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + G_dims.size(), // num_g_dims + M_dims.size(), // num_m_dims + N_dims.size(), // num_n_dims + K_dims.size() // num_k_dims + ); + + return ave_time; +} + +template +int run_batched_contraction_example_with_layouts( + int argc, + char* argv[], + [[maybe_unused]] const ALayout a_layout = ALayout{}, + [[maybe_unused]] const BLayout b_layout = BLayout{}, + [[maybe_unused]] const DLayout d_layout = DLayout{}, + [[maybe_unused]] const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::vector G_dims = parse_dimensions(arg_parser.get_str("g_dims")); + std::vector M_dims = parse_dimensions(arg_parser.get_str("m_dims")); + std::vector N_dims = parse_dimensions(arg_parser.get_str("n_dims")); + std::vector K_dims = parse_dimensions(arg_parser.get_str("k_dims")); + + constexpr ck_tile::index_t NumDTensor = 2; + + ck_tile::index_t G_total = calculate_total_elements(G_dims); + ck_tile::index_t M_total = calculate_total_elements(M_dims); + ck_tile::index_t N_total = calculate_total_elements(N_dims); + ck_tile::index_t K_total = calculate_total_elements(K_dims); + + std::vector A_dims = + concatenate_dim_components({G_dims, M_dims, K_dims}); // [G0,G1,..,M0,M1,..,K0,K1,..] + std::vector B_dims = + concatenate_dim_components({G_dims, N_dims, K_dims}); // [G0,G1,..,N0,N1,..,K0,K1,..] + std::vector E_dims = + concatenate_dim_components({G_dims, M_dims, N_dims}); // [G0,G1,..,M0,M1,..,N0,N1,..] 
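+
+    // With the default command-line dimensions (g_dims="1,2", m_dims="4,256",
+    // n_dims="16,128", k_dims="64") these concatenations yield:
+    //   A_dims = {1, 2, 4, 256, 64}       -> [G..., M..., K...]
+    //   B_dims = {1, 2, 16, 128, 64}      -> [G..., N..., K...]
+    //   E_dims = {1, 2, 4, 256, 16, 128}  -> [G..., M..., N...]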
+ + std::array, NumDTensor> Ds_dims; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + Ds_dims[d] = E_dims; + } + + auto convert_strides = [](const std::vector& strides) { + std::vector converted(strides.size()); + std::copy(strides.begin(), strides.end(), converted.begin()); + return converted; + }; + + ck_tile::HostTensorDescriptor a_desc(A_dims); + ck_tile::HostTensorDescriptor b_desc(B_dims); + ck_tile::HostTensorDescriptor e_desc(E_dims); + std::array ds_descs; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides()); + } + + std::vector A_strides = convert_strides(a_desc.get_strides()); + std::vector B_strides = convert_strides(b_desc.get_strides()); + std::vector E_strides = convert_strides(e_desc.get_strides()); + + std::array, NumDTensor> Ds_strides; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + Ds_strides[d] = convert_strides(ds_descs[d].get_strides()); + } + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + print_dims("G_dims", G_dims, G_total); + print_dims("M_dims", M_dims, M_total); + print_dims("N_dims", N_dims, N_total); + print_dims("K_dims", K_dims, K_total); + + std::cout << "NumDTensor: " << NumDTensor << std::endl; + std::cout << "\n=== Tensor Shapes for Kernel ===" << std::endl; + print_dims("A_dims", A_dims, 0); + print_dims("B_dims", B_dims, 0); + print_dims("E_dims", E_dims, 0); + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + print_dims("Ds[" + std::to_string(d) + "]_dims", Ds_dims[d], 0); + } + + std::cout << "\n=== Tensor Strides ===" << std::endl; + print_dims("A_strides", A_strides, 0); + print_dims("B_strides", B_strides, 0); + print_dims("E_strides", E_strides, 0); + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + print_dims("Ds[" + std::to_string(d) + "]_strides", Ds_strides[d], 0); + } + + std::cout << "===============================================\n" << std::endl; + + ck_tile::HostTensor<::ADataType> a_full_dims_host(a_desc); + ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc); + ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc); + + std::vector> ds_full_dims_host; + for(int d = 0; d < NumDTensor; ++d) + { + ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d])); + } + + ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host); + ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host); + + ck_tile::DeviceMem a_full_dims_dev_buf(a_full_dims_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_full_dims_dev_buf(b_full_dims_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_full_dims_dev_buf(e_full_dims_host.get_element_space_size_in_bytes()); + + a_full_dims_dev_buf.ToDevice(a_full_dims_host.data()); + b_full_dims_dev_buf.ToDevice(b_full_dims_host.data()); + + for(int d = 0; d < NumDTensor; ++d) + { + ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}( + ds_full_dims_host[d]); + } + + std::vector> ds_full_dims_dev_buf; + for(int d = 0; d < NumDTensor; ++d) + { + ds_full_dims_dev_buf.push_back(std::make_unique( + ds_full_dims_host[d].get_element_space_size_in_bytes())); + ds_full_dims_dev_buf[d]->ToDevice(ds_full_dims_host[d].data()); + } + std::array ds_ptr_buf; + for(int d = 0; d < NumDTensor; ++d) + { + ds_ptr_buf[d] = ds_full_dims_dev_buf[d]->GetDeviceBuffer(); + } + + 
e_full_dims_dev_buf.SetZero(); + e_full_dims_host.SetZero(); + + std::cout << "\n=== Running GPU Kernel ===" << std::endl; + + using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>; + using DsLayout = ck_tile::tuple_array; + using CDEElementWise = + std::conditional_t; + + float ave_time = + invoke_batched_contraction_kernel<::ADataType, + ::BDataType, + DsDataType, + ::AccDataType, + ::EDataType, + ALayout, + BLayout, + DsLayout, + ELayout, + CDEElementWise>(a_full_dims_dev_buf.GetDeviceBuffer(), + b_full_dims_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_full_dims_dev_buf.GetDeviceBuffer(), + G_dims, + M_dims, + N_dims, + K_dims, + A_dims, + B_dims, + Ds_dims, + E_dims, + A_strides, + B_strides, + Ds_strides, + E_strides, + kbatch, + n_warmup, + n_repeat); + + std::string op_name{ + "Multi-Dimensional Batched Contraction : G: " + std::to_string(G_dims.size()) + + "D, M: " + std::to_string(M_dims.size()) + "D, N: " + std::to_string(N_dims.size()) + + "D, K: " + std::to_string(K_dims.size()) + "D"}; + + std::size_t flop = std::size_t(2) * G_total * M_total * N_total * K_total + + NumDTensor * K_total * M_total * N_total; // Number of operations + std::size_t num_byte = + sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size + sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size + sizeof(::DDataType) * NumDTensor * G_total * M_total * N_total + // D tensors + sizeof(::EDataType) * G_total * M_total * N_total; // E tensor size + + float tflops = static_cast(flop) / 1.E9 / ave_time; // TFlops calculation + float gb_per_sec = num_byte / 1.E6 / ave_time; // GB/s calculation + print_dims("G_dims", G_dims, G_total); + print_dims("M_dims", M_dims, M_total); + print_dims("N_dims", N_dims, N_total); + print_dims("K_dims", K_dims, K_total); + + std::cout << " Performance: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + std::cout << "===============================================" << std::endl; + + e_full_dims_dev_buf.FromDevice(e_full_dims_host.data()); + std::cout << "GPU results retrieved from device." << std::endl; + + bool pass = true; + if(arg_parser.get_int("v") == 1) + { + + std::cout << "Computing CPU reference..." << std::endl; + + ck_tile::HostTensor<::EDataType> e_full_dims_host_ref( + ck_tile::HostTensorDescriptor(E_dims, E_strides)); + e_full_dims_host_ref.SetZero(); + + auto start_time = std::chrono::high_resolution_clock::now(); + + calculate_reference_flat_indexing(a_full_dims_host, + b_full_dims_host, + ds_full_dims_host, + e_full_dims_host_ref, + G_total, + M_total, + N_total, + K_total, + CDEElementWise{}); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end_time - start_time); + + std::cout << "CPU reference completed in " << duration.count() << "ms" << std::endl; + + const float max_accumulated_value = + *std::max_element(e_full_dims_host_ref.mData.begin(), e_full_dims_host_ref.mData.end()); + + const auto rtol_atol = + calculate_rtol_atol<::ADataType, ::BDataType, ::EDataType, ::AccDataType>( + K_total, kbatch, max_accumulated_value); + + pass = ck_tile::check_err(e_full_dims_host, + e_full_dims_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "The CPU verification result is: " << (pass ? 
"correct" : "fail") << std::endl; + + std::cout << "===============================================" << std::endl; + + std::cout << "\n=== Random Samples of Reference and Result ===" << std::endl; + + // Generate 10 random indices + std::vector random_indices; + std::size_t total_elements = e_full_dims_host_ref.mData.size(); + std::mt19937 rng(std::random_device{}()); + std::uniform_int_distribution dist(0, total_elements - 1); + + for(int i = 0; i < 10; ++i) + { + random_indices.push_back(dist(rng)); + } + + // Print the values at the random indices + for(std::size_t idx : random_indices) + { + std::cout << "Index " << idx << ": " + << "ref=" << static_cast(e_full_dims_host_ref.mData[idx]) << ", " + << "GPU=" << static_cast(e_full_dims_host.mData[idx]) << std::endl; + } + + std::cout << "===============================================" << std::endl; + } + + return pass; +} + +int run_batched_contraction_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and E tensors! " + "Only R-C-R supported for now."); + } +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7a8ae065db..5e178e3669 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -27,3 +27,4 @@ add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) +add_subdirectory(41_batched_contraction) diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp new file mode 100644 index 0000000000..1ce071969c --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template + +void calculate_reference_flat_indexing( + const ck_tile::HostTensor& a_full_dims, + const ck_tile::HostTensor& b_full_dims, + const std::vector>& ds_full_dims_host, + ck_tile::HostTensor& e_full_dims_host_ref, + ck_tile::index_t G_total, + ck_tile::index_t M_total, + ck_tile::index_t N_total, + ck_tile::index_t K_total, + const CDEElementWise& cde_elementwise) +{ + std::cout << "Calculating reference using optimized flat indexing with parallel processing..." 
+ << std::endl; + + // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp + auto f_gm = [&](auto g_flat, auto m_flat) { + for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat) + { + AccDataType sum = 0; + + // Compute dot product over K dimension + for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat) + { + auto a_val = + a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat]; + auto b_val = + b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat]; + sum += static_cast(a_val) * static_cast(b_val); + } + + // Apply elementwise operation with D tensors + EDataType result = static_cast(sum); + if(ds_full_dims_host.size() == 0) + { + ; + } + else if(ds_full_dims_host.size() == 1) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0].mData[g_flat * M_total * N_total + + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 2) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 3) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[2] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 4) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[2] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[3] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else + { + throw std::runtime_error("Unsupported NumDTensor for reference calculation"); + } + + // Store result + e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] = + static_cast(result); + } + }; + + // Execute parallel computation using hardware concurrency + // Parallelize over G_total and M_total dimensions for optimal CPU utilization + make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency()); +} + +template +void calculate_reference_multi_dimensional( + const HostTensor& a_full_dims, + const HostTensor& b_full_dims, + const std::vector>& ds_full_dims_host, + HostTensor& e_full_dims_host_ref, + const std::vector& G_dims, + const std::vector& M_dims, + const std::vector& N_dims, + const std::vector& K_dims, + const std::vector& A_dims, + const std::vector& B_dims, + const std::vector& E_dims, + const CDEElementWise& cde_elementwise) +{ + std::cout << "Calculating reference using multi-dimensional indexing..." 
<< std::endl; + + std::vector g_idx(G_dims.size()); + std::vector m_idx(M_dims.size()); + std::vector n_idx(N_dims.size()); + std::vector k_idx(K_dims.size()); + std::vector a_idx, b_idx, e_idx; + + a_idx.reserve(A_dims.size()); + b_idx.reserve(B_dims.size()); + e_idx.reserve(E_dims.size()); + + for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat) + { + ck_tile::index_t temp = g_flat; + for(int i = G_dims.size() - 1; i >= 0; --i) + { + g_idx[i] = temp % G_dims[i]; + temp /= G_dims[i]; + } + + for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat) + { + temp = m_flat; + for(int i = M_dims.size() - 1; i >= 0; --i) + { + m_idx[i] = temp % M_dims[i]; + temp /= M_dims[i]; + } + + for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat) + { + temp = n_flat; + for(int i = N_dims.size() - 1; i >= 0; --i) + { + n_idx[i] = temp % N_dims[i]; + temp /= N_dims[i]; + } + + AccDataType sum = 0; + + for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims); + ++k_flat) + { + temp = k_flat; + for(int i = K_dims.size() - 1; i >= 0; --i) + { + k_idx[i] = temp % K_dims[i]; + temp /= K_dims[i]; + } + + a_idx.clear(); + b_idx.clear(); + + a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end()); + a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end()); + a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end()); + + b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end()); + b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end()); + b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end()); + + auto a_val = a_full_dims(a_idx); + auto b_val = b_full_dims(b_idx); + + sum += static_cast(a_val) * static_cast(b_val); + } + + e_idx.clear(); + e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end()); + e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end()); + e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end()); + + EDataType result = static_cast(sum); + if(ds_full_dims_host.size() == 0) + { + ; + } + else if(ds_full_dims_host.size() == 1) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx))); + } + else if(ds_full_dims_host.size() == 2) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx))); + } + else if(ds_full_dims_host.size() == 3) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx)), + ck_tile::type_convert(ds_full_dims_host[2](e_idx))); + } + else if(ds_full_dims_host.size() == 4) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx)), + ck_tile::type_convert(ds_full_dims_host[2](e_idx)), + ck_tile::type_convert(ds_full_dims_host[3](e_idx))); + } + else + { + throw std::runtime_error("Unsupported NumDTensor for reference calculation"); + } + + e_full_dims_host_ref(e_idx) = static_cast(result); + } + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp new file mode 100644 index 0000000000..9162f421d1 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" +#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp new file mode 100644 index 0000000000..6d8f9f3f0e --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -0,0 +1,522 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" + +/** + * @file batched_contraction_kernel.hpp + * @brief Batched Tensor Contraction Operations + * + * @section batched_contraction_overview What is Batched Tensor Contraction with Multiple D? + * + * Tensor contraction is a fundamental operation that generalizes matrix multiplication to + * multi-dimensional tensors. It performs element-wise multiplication and summation over + * shared dimensions + * + * **Beyond pure contraction, this kernel supports multiple auxiliary input tensors (D tensors)** + * that are fused with the contraction result through configurable epilogue operations, enabling + * efficient computation of complex tensor expressions in a single kernel launch. + * + * @subsection mathematical_formulation Mathematical Formulation + * + * For tensors A and B with arbitrary dimensionalities, the complete operation computes: + * + * **E[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = epilogue_op(C, D₀, D₁, D₂, ...)** + * + * Where: + * **C[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = Σ_{K₀,K₁,...} A[G₀,G₁,...,M₀,M₁,...,K₀,K₁,...] × + * B[G₀,G₁,...,N₀,N₁,...,K₀,K₁,...]** + * + * Where: + * - **G dimensions**: Batch dimensions (shared across A, B, and output E) + * - **M dimensions**: Row dimensions of the output matrix (from tensor A) + * - **N dimensions**: Column dimensions of the output matrix (from tensor B) + * - **K dimensions**: Contraction dimensions (summed over, present in both A and B) + * + * @subsection why_gemm_implementation Why Tensor Contraction Can Be Implemented Using GEMM + * + * **Mathematical Equivalence**: Tensor contraction is fundamentally equivalent to matrix + * multiplication when dimensions are appropriately flattened. The key insight is that the summation + * operation over shared dimensions (K dimensions) in tensor contraction is mathematically identical + * to the dot product computation in matrix multiplication. 
+ * + * **Dimension Flattening Strategy**: + * - **M dimensions** (from tensor A) → Flattened into matrix rows (M_total) + * - **N dimensions** (from tensor B) → Flattened into matrix columns (N_total) + * - **K dimensions** (contraction dims) → Flattened into inner dimension (K_total) + * - **G dimensions** (batch dims) → Handled through batch processing + * + * **Mathematical Transformation**: + * ``` + * Original: E[g,m₀,m₁,n₀,n₁] = Σ_{k₀,k₁} A[g,m₀,m₁,k₀,k₁] × B[g,n₀,n₁,k₀,k₁] + * Flattened: E[g,M,N] = Σ_K A[g,M,K] × B[g,N,K] (where M=m₀×m₁, N=n₀×n₁, K=k₀×k₁) + * GEMM Form: E = A × Bᵀ + * + * **Why This Approach Is Optimal**: + * Rather than implementing tensor contraction from scratch, this kernel leverages the highly + * optimized `UniversalGemmKernel` as its computational backend. + * + * @subsection current_limitations Current Kernel Limitations + * + * **Layout Restrictions:** + * - **Row-Major Only**: All tensors must use row-major memory layout + * - **Packed Tensors**: Only contiguous/packed tensor layouts supported + * - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total + * - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total) + * + * **Implementation Constraints:** + * - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized + * - **No Column-Major**: Column-major or custom stride patterns not supported + * - **No Strided Access**: Non-contiguous tensor slicing not supported + * + * **Future Enhancements:** + * - Support for arbitrary stride patterns + * - Column-major and mixed layout support + * - Non-contiguous tensor operation support + */ + +namespace ck_tile { + +/// @brief Host arguments for batched tensor contraction operations. +/// +/// @par Overview +/// This structure encapsulates all host-side arguments required for batched tensor contraction. +/// It supports arbitrary number of batch dimensions (G), M dimensions, N dimensions, and K +/// dimensions. +/// +/// @par Tensor Layout Assumptions +/// - A tensor: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +/// - B tensor: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +/// - D tensors: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (auxiliary input tensors) +/// - E tensor: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (output tensor) +/// +/// @tparam NumDTensor Number of D (auxiliary input) tensors. Default is 0. +template +struct BatchedContractionHostArgs +{ + /// @brief Constructor for batched contraction host arguments. + /// + /// @param a_ptr_ Pointer to input tensor A + /// @param b_ptr_ Pointer to input tensor B + /// @param ds_ptr_ Array of pointers to auxiliary input tensors D + /// @param e_ptr_ Pointer to output tensor E + /// @param k_batch_ Number of k-splits for split-K batching + /// @param A_dims_ Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + /// @param B_dims_ Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + /// @param Ds_dims_ Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + /// @param E_dims_ Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + /// @param A_strides_ Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + /// @param B_strides_ Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + /// @param Ds_strides_ Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] 
+ /// @param E_strides_ Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + CK_TILE_HOST + BatchedContractionHostArgs( + const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + ck_tile::index_t k_batch_, + const std::vector& A_dims_, // [G0, G1, ..., M0, M1, ... , K0, K1, ...] + const std::vector& B_dims_, // [G0, G1, ..., N0, N1, ... , K0, K1, ...] + const std::array, NumDTensor>& + Ds_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor] + const std::vector& E_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...] + + const std::vector& A_strides_, // [G0, G1, ..., M0, M1, ...,K0, K1, ...] + const std::vector& B_strides_, // [G0, G1, ..., N0, N1, ...,K0, K1, ...] + const std::array, NumDTensor>& + Ds_strides_, // [G0, G1, ..., M0, M1, ...,N0, N1, ...] + const std::vector& + E_strides_) // [G0, G1, ..., M0, M1, ...,N0, N1, ...][NumDTensor] + + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + k_batch(k_batch_), + A_dims(A_dims_), + B_dims(B_dims_), + Ds_dims(Ds_dims_), + E_dims(E_dims_), + A_strides(A_strides_), + B_strides(B_strides_), + Ds_strides(Ds_strides_), + E_strides(E_strides_) + { + } + + const void* a_ptr; ///< Pointer to input tensor A + const void* b_ptr; ///< Pointer to input tensor B + std::array ds_ptr; ///< Array of pointers to auxiliary input tensors D + void* e_ptr; ///< Pointer to output tensor E + ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching + const std::vector + A_dims; ///< Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + const std::vector + B_dims; ///< Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + const std::array, NumDTensor> + Ds_dims; ///< Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + E_dims; ///< Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + A_strides; ///< Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + const std::vector + B_strides; ///< Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + const std::array, NumDTensor> + Ds_strides; ///< Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + E_strides; ///< Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] +}; + +/// @brief Kernel arguments for batched tensor contraction operations. +/// +/// @tparam NumDimG Number of batch dimensions +/// @tparam NumDimM Number of M (output row) dimensions +/// @tparam NumDimN Number of N (output column) dimensions +/// @tparam NumDimK Number of K (contraction) dimensions +/// @tparam NumDTensor Number of auxiliary input D tensors. Default is 0. 
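+///
+/// @par Example (illustrative only; the sizes below are hypothetical)
+/// For NumDimG=2, NumDimM=2, NumDimN=1, NumDimK=1 with packed row-major host tensors,
+/// G = {2, 3}, M = {4, 8}, N = {16}, K = {32}, MakeKernelArgs() produces:
+/// @code
+/// G_total = 6;   M_total = 32;   N_total = 16;   K_total = 32;
+/// stride_A = K_total = 32;   stride_B = K_total = 32;   stride_E = N_total = 16;
+/// // Batch strides are taken from the stride of the last G dimension of each packed tensor:
+/// batch_stride_A = M_total * K_total = 1024;   // A: [2, 3, 4, 8, 32]
+/// batch_stride_B = N_total * K_total = 512;    // B: [2, 3, 16, 32]
+/// batch_stride_E = M_total * N_total = 512;    // E: [2, 3, 4, 8, 16]
+/// @endcode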
+ +template +struct BatchedContractionKernelArgs +{ + const void* a_ptr; ///< Pointer to input tensor A + const void* b_ptr; ///< Pointer to input tensor B + std::array ds_ptr; ///< Array of pointers to auxiliary input tensors D + void* e_ptr; ///< Pointer to output tensor E + ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching + + ck_tile::index_t M_dims[NumDimM]; ///< M dimension sizes: [M0, M1, M2, ..., M_{NumDimM-1}] + ck_tile::index_t N_dims[NumDimN]; ///< N dimension sizes: [N0, N1, N2, ..., N_{NumDimN-1}] + ck_tile::index_t K_dims[NumDimK]; ///< K dimension sizes: [K0, K1, K2, ..., K_{NumDimK-1}] + ck_tile::index_t + G_dims[NumDimG]; ///< G (batch) dimension sizes: [G0, G1, G2, ..., G_{NumDimG-1}] + + // Batch strides for efficient offset calculation + ck_tile::index_t batch_stride_A; ///< Batch stride for tensor A + ck_tile::index_t batch_stride_B; ///< Batch stride for tensor B + ck_tile::index_t batch_stride_E; ///< Batch stride for tensor E + std::array batch_stride_Ds; ///< Batch strides for D tensors + + ck_tile::index_t G_total; ///< Total batch size: G0 * G1 * ... * G_{NumDimG-1} + ck_tile::index_t M_total; ///< Total M dimension: M0 * M1 * ... * M_{NumDimM-1} + ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1} + ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1} + + ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total) + ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total) + std::array + stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total) + ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total) +}; + +/// @brief GPU kernel for batched tensor contraction operations. +/// +/// @par Overview +/// This kernel performs batched tensor contraction operations using the underlying +/// UniversalGemmKernel. It supports arbitrary tensor dimensionalities (G, M, N, K) and +/// processes multiple batch instances in parallel. Each batch performs: E = +/// epilogue_op(contraction(A, B), D0, D1, ...). 
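+///
+/// @par Work distribution
+/// The launch grid is three-dimensional: blockIdx.x selects the (M, N) output tile through the
+/// TilePartitioner, blockIdx.y selects the flattened batch index in [0, G_total), and
+/// blockIdx.z selects the split-K slice in [0, k_batch). Each block offsets the A, B, D and E
+/// base pointers by the corresponding batch strides before delegating to UniversalGemmKernel.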
+/// +/// @tparam Problem_ Tensor contraction problem specification defining data types and dimensions +/// @tparam TilePartitioner_ Tile partitioning strategy for workload distribution +/// @tparam GemmPipeline_ GEMM computation pipeline for core matrix operations +/// @tparam EpiloguePipeline_ Epilogue pipeline for post-GEMM operations and tensor fusion + +template +struct BatchedContractionKernel +{ + // Type aliases for cleaner code and better readability + using Problem = ck_tile::remove_cvref_t; ///< Tensor contraction problem specification + using ADataType = + ck_tile::remove_cvref_t; ///< Data type for input tensor A + using BDataType = + ck_tile::remove_cvref_t; ///< Data type for input tensor B + using DsDataType = + ck_tile::remove_cvref_t; ///< Data types for auxiliary input + ///< tensors D + using EDataType = + ck_tile::remove_cvref_t; ///< Data type for output tensor E + + // Compile-time dimension constants extracted from problem specification + static constexpr ck_tile::index_t NumDimG = Problem::NumDimG; ///< Number of batch dimensions + static constexpr ck_tile::index_t NumDimM = + Problem::NumDimM; ///< Number of M (output row) dimensions + static constexpr ck_tile::index_t NumDimN = + Problem::NumDimN; ///< Number of N (output column) dimensions + static constexpr ck_tile::index_t NumDimK = + Problem::NumDimK; ///< Number of K (contraction) dimensions + static constexpr ck_tile::index_t NumDTensor = + Problem::NumDTensor; ///< Number of auxiliary input D tensors + + // Pipeline and partitioning strategy types + using TilePartitioner = + ck_tile::remove_cvref_t; ///< Tile partitioning strategy for workload + ///< distribution + using GemmPipeline = ck_tile::remove_cvref_t; ///< GEMM computation pipeline + using EpiloguePipeline = + ck_tile::remove_cvref_t; ///< Epilogue pipeline for post-GEMM operations + + // Underlying GEMM kernel that performs the actual computation + using UniversalGemmKernel = + ck_tile::UniversalGemmKernel; + + static constexpr ck_tile::index_t kBlockSize = + UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel + + using KernelArgs = + BatchedContractionKernelArgs; ///< Kernel + ///< argument + ///< structure + + /// @brief Returns the kernel name for debugging and profiling purposes. + /// @return Constant string identifier for this kernel + CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; } + + /// @brief Validates whether the given kernel arguments are supported. + /// @param kargs Kernel arguments to validate + /// @return True if arguments are supported, false otherwise + /// @details Checks underlying GEMM kernel support and ensures valid batch dimensions + CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs) + { + typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr}, + {kargs.b_ptr}, + kargs.ds_ptr, + kargs.e_ptr, + kargs.M_total, + kargs.N_total, + kargs.K_total, + {kargs.stride_A}, + {kargs.stride_B}, + kargs.stride_Ds, + kargs.stride_E, + kargs.k_batch}; + + return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0; + } + + /// @brief Returns the shared memory size required by the kernel. + /// @return Shared memory size in bytes + /// @details Delegates to underlying GEMM kernel's shared memory requirements + CK_TILE_HOST static constexpr ck_tile::index_t GetSmemSize() + { + return UniversalGemmKernel::GetSmemSize(); + } + + /// @brief Returns the GPU block size for kernel launch. 
+ /// @return 3D block dimensions for GPU kernel execution + CK_TILE_HOST static constexpr auto GetBlockSize() + { + return dim3(UniversalGemmKernel::kBlockSize); + } + + CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs) + { + return dim3( + TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch); + } + + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const BatchedContractionHostArgs& host_args) + { + const auto expected_A_dims = NumDimG + NumDimM + NumDimK; + const auto expected_B_dims = NumDimG + NumDimN + NumDimK; + const auto expected_E_dims = NumDimG + NumDimM + NumDimN; + + if(host_args.A_dims.size() != expected_A_dims || + host_args.A_strides.size() != expected_A_dims) + { + throw std::invalid_argument("A dimension size mismatch"); + } + if(host_args.B_dims.size() != expected_B_dims || + host_args.B_strides.size() != expected_B_dims) + { + throw std::invalid_argument("B dimension size mismatch"); + } + if(host_args.E_dims.size() != expected_E_dims || + host_args.E_strides.size() != expected_E_dims) + { + throw std::invalid_argument("E dimension size mismatch"); + } + + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + if(host_args.Ds_dims[d].size() != expected_E_dims || + host_args.Ds_strides[d].size() != expected_E_dims) + { + throw std::invalid_argument("D dimension size mismatch"); + } + } + + KernelArgs kargs; + kargs.a_ptr = host_args.a_ptr; + kargs.b_ptr = host_args.b_ptr; + kargs.ds_ptr = host_args.ds_ptr; + kargs.e_ptr = host_args.e_ptr; + kargs.k_batch = host_args.k_batch; + + // Validate and set G dimensions (must be identical across all tensors) + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + // All tensors must have same G dimensions for valid contraction + if(host_args.A_dims[i] != host_args.B_dims[i] || + host_args.A_dims[i] != host_args.E_dims[i]) + { + throw std::invalid_argument( + "All tensors must have identical G dimensions for valid contraction"); + } + + // Store G dimensions (same for all tensors) + kargs.G_dims[i] = host_args.A_dims[i]; + } + + // Set batch strides from the stride of last G dimension + kargs.batch_stride_A = host_args.A_strides[NumDimG - 1]; + kargs.batch_stride_B = host_args.B_strides[NumDimG - 1]; + kargs.batch_stride_E = host_args.E_strides[NumDimG - 1]; + + for(ck_tile::index_t i = 0; i < NumDimM; ++i) + { + kargs.M_dims[i] = host_args.A_dims[NumDimG + i]; + if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i]) + { + throw std::invalid_argument("M dimension mismatch between A and E tensors"); + } + } + for(ck_tile::index_t i = 0; i < NumDimN; ++i) + { + kargs.N_dims[i] = host_args.B_dims[NumDimG + i]; + if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i]) + { + throw std::invalid_argument("N dimension mismatch between B and E tensors"); + } + } + for(ck_tile::index_t i = 0; i < NumDimK; ++i) + { + kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i]; + if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i]) + { + throw std::invalid_argument("K dimension mismatch between A and B tensors"); + } + } + + // Calculate total dimensions from individual dimension arrays + kargs.G_total = 1; + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + kargs.G_total *= kargs.G_dims[i]; + } + + kargs.M_total = 1; + for(ck_tile::index_t i = 0; i < NumDimM; ++i) + { + kargs.M_total *= kargs.M_dims[i]; + } + + kargs.N_total = 1; + for(ck_tile::index_t i = 0; i < NumDimN; ++i) + { + kargs.N_total *= kargs.N_dims[i]; + } + + kargs.K_total = 1; + for(ck_tile::index_t i = 
0; i < NumDimK; ++i) + { + kargs.K_total *= kargs.K_dims[i]; + } + + kargs.stride_A = kargs.K_total; + kargs.stride_B = kargs.K_total; + kargs.stride_E = kargs.N_total; + + // Validate D tensors have same G dimensions and set their batch strides + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + if(host_args.Ds_dims[d][i] != host_args.A_dims[i]) + { + throw std::invalid_argument( + "D tensor G dimensions must match A/B/E tensor G dimensions"); + } + } + // Set batch stride for D tensor + kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1]; + kargs.stride_Ds[d] = kargs.N_total; // D tensors same shape as E + } + + return kargs; + } + + CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const + { + + const auto [iM, iN] = + TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x); + const ck_tile::index_t i_m = + __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const ck_tile::index_t i_n = + __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z); + + // Calculate batch offsets for each tensor + const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A; + const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B; + const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E; + + const ADataType* a_ptr = static_cast(kargs.a_ptr) + batch_offset_A; + const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B; + EDataType* e_ptr = static_cast(kargs.e_ptr) + batch_offset_E; + + std::array ds_batch_ptr; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = typename std::tuple_element::type; + const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i]; + ds_batch_ptr[i] = static_cast(kargs.ds_ptr[i]) + batch_offset_D; + }); + + typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr}, + {b_ptr}, + ds_batch_ptr, + e_ptr, + kargs.M_total, + kargs.N_total, + kargs.K_total, + {kargs.stride_A}, + {kargs.stride_B}, + kargs.stride_Ds, + kargs.stride_E, + kargs.k_batch}; + + const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs, + i_splitk); + + const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr_final = b_ptr + splitk_batch_offset.bs_k_split_offset[0]; + __shared__ char smem_ptr[GetSmemSize()]; + + UniversalGemmKernel::RunGemm({a_ptr_final}, + {b_ptr_final}, + ds_batch_ptr, + e_ptr, + smem_ptr, + gemm_kargs, + splitk_batch_offset, + i_m, + i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp new file mode 100644 index 0000000000..9ebaae3c97 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BatchedContractionProblem +{ + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using DsDataType = ck_tile::remove_cvref_t; + using EDataType = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t NumDimG = NumDimG_; + static constexpr ck_tile::index_t NumDimM = NumDimM_; + static constexpr ck_tile::index_t NumDimN = NumDimN_; + static constexpr ck_tile::index_t NumDimK = NumDimK_; + static constexpr ck_tile::index_t NumDTensor = NumDTensor_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp new file mode 100644 index 0000000000..6d3286ce09 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +/** + * @file tensor_descriptor_utils.hpp + * @brief Utility functions for creating tensor descriptors in batched contraction operations + * + * @details This file contains utility functions for creating tensor descriptors with flattened + * dimensions for GEMM operations. These functions transform multi-dimensional tensors into + * 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions. + * + * These utilities are currently not used in the main batched contraction kernel but are preserved + * for future implementations that may require explicit tensor descriptor creation. + */ + +namespace ck_tile { + +/** + * @brief Utility class for creating tensor descriptors in batched contraction operations + * + * @tparam NumDimG Number of batch dimensions + * @tparam NumDimM Number of M (output row) dimensions + * @tparam NumDimN Number of N (output column) dimensions + * @tparam NumDimK Number of K (contraction) dimensions + */ +template +struct TensorDescriptorUtils +{ + /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed. + /// @param A_dims Dimension vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + /// @param A_strides Stride vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + /// @return Flattened tensor descriptor: [M_total, K_total] for GEMM computation + /// @details Removes batch dimensions and flattens M and K dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_A_GridDescriptor_M_K(const std::vector& A_dims = {}, + const std::vector& A_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G Dimensions + const auto A_dims_M_K = + to_tuple(A_dims, number{}, number{}); + const auto A_strides_M_K = + to_tuple(A_strides, number{}, number{}); + + // dimension Ids for M and K + constexpr auto A_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + constexpr auto A_dims_K_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for M [M0, M1, ...] and K [K0, K1, ...] + const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids); + const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] 
Discriptor + const auto A_grid_desc_Ms_Ks = + ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K); + + // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total + // = K0 * K1 * K2 * ...] + const auto A_grid_desc_Mflat_Kflat = ck_tile::transform_tensor_descriptor( + A_grid_desc_Ms_Ks, + make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_K)), + make_tuple(A_dims_M_ids, A_dims_K_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return A_grid_desc_Mflat_Kflat; + } + + /// @brief Creates a tensor descriptor for input tensor B with batch dimensions removed. + /// @param B_dims Dimension vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + /// @param B_strides Stride vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + /// @return Flattened tensor descriptor: [N_total, K_total] for GEMM computation + /// @details Removes batch dimensions and flattens N and K dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_B_GridDescriptor_N_K(const std::vector& B_dims = {}, + const std::vector& B_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G Dimensions + const auto B_dims_N_K = + to_tuple(B_dims, number{}, number{}); + const auto B_strides_N_K = + to_tuple(B_strides, number{}, number{}); + + // dimension Ids for N and K + constexpr auto B_dims_N_ids = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + constexpr auto B_dims_K_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for N [N0, N1, ...] and K [K0, K1, ...] + const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids); + const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor + const auto B_grid_desc_Ns_Ks = + ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K); + + // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total + // = K0 * K1 * K2 * ...] + const auto B_grid_desc_Nflat_Kflat = ck_tile::transform_tensor_descriptor( + B_grid_desc_Ns_Ks, + make_tuple(make_merge_transform(dims_N), make_merge_transform(dims_K)), + make_tuple(B_dims_N_ids, B_dims_K_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return B_grid_desc_Nflat_Kflat; + } + + /// @brief Creates a tensor descriptor for output tensor E with batch dimensions removed. + /// @param E_dims Dimension vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + /// @param E_strides Stride vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] 
+ /// @return Flattened tensor descriptor: [M_total, N_total] for GEMM computation + /// @details Removes batch dimensions and flattens M and N dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_E_GridDescriptor_M_N(const std::vector& E_dims = {}, + const std::vector& E_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G dimensions + const auto E_dims_M_N = + to_tuple(E_dims, number{}, number{}); + const auto E_strides_M_N = + to_tuple(E_strides, number{}, number{}); + + // dimension Ids for M and N + constexpr auto E_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + constexpr auto E_dims_N_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for M and N + const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids); + const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor + const auto E_grid_desc_Ms_Ns = + ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N); + + // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... , + // N_total = N0 * N1 * N2 * ...] + const auto E_grid_desc_Mflat_Nflat = ck_tile::transform_tensor_descriptor( + E_grid_desc_Ms_Ns, + make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_N)), + make_tuple(E_dims_M_ids, E_dims_N_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return E_grid_desc_Mflat_Nflat; + } +}; + +} // namespace ck_tile
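A usage note on the descriptor utilities above: the file states they are not wired into the batched contraction kernel, but the intended call shape can be illustrated. The sketch below is an assumption based on the signatures in tensor_descriptor_utils.hpp (the element type of the dimension/stride vectors and the NumDimG/NumDimM/NumDimN/NumDimK template-parameter order follow the surrounding doxygen); the dimension values are made up.

#include <vector>
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"

void make_a_descriptor_example()
{
    // NumDimG = 2, NumDimM = 2, NumDimN = 1, NumDimK = 1 (hypothetical instantiation).
    using Utils = ck_tile::TensorDescriptorUtils<2, 2, 1, 1>;

    // A is stored as [G0, G1, M0, M1, K0] with packed row-major strides.
    const std::vector<ck_tile::index_t> A_dims{4, 8, 16, 32, 64};
    const std::vector<ck_tile::index_t> A_strides{262144, 32768, 2048, 64, 1};

    // The returned descriptor addresses the non-G part of A as a
    // [M_total, K_total] = [16 * 32, 64] = [512, 64] matrix; the per-batch G offset
    // is applied separately (e.g. via the batch_stride_* fields of the kernel args).
    const auto a_desc = Utils::Make_A_GridDescriptor_M_K(A_dims, A_strides);
    (void)a_desc;
}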
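MakeKernelArgs expects each tensor's dimension and stride vectors ordered with the G dimensions first: A as [G..., M..., K...], B as [G..., N..., K...], and E and every D as [G..., M..., N...]. The batch stride it records is the stride of the innermost G dimension, and the packed products G_total/M_total/N_total/K_total drive the launch. A minimal host-side sketch of building those vectors for NumDimG = 2, NumDimM = 2, NumDimN = 1, NumDimK = 1 follows; the field names match BatchedContractionHostArgs in this patch, but the packed_strides helper and the concrete extents are illustrative assumptions.

#include <vector>

using index_t = int; // stand-in for ck_tile::index_t in this sketch

// Packed row-major strides: the innermost (last) dimension has stride 1.
std::vector<index_t> packed_strides(const std::vector<index_t>& dims)
{
    std::vector<index_t> strides(dims.size(), 1);
    for(int i = static_cast<int>(dims.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * dims[i + 1];
    return strides;
}

int main()
{
    // NumDimG = 2, NumDimM = 2, NumDimN = 1, NumDimK = 1
    const std::vector<index_t> A_dims{4, 8, 16, 32, 64};  // [G0, G1, M0, M1, K0]
    const std::vector<index_t> B_dims{4, 8, 128, 64};     // [G0, G1, N0, K0]
    const std::vector<index_t> E_dims{4, 8, 16, 32, 128}; // [G0, G1, M0, M1, N0]

    const auto A_strides = packed_strides(A_dims);
    const auto B_strides = packed_strides(B_dims);
    const auto E_strides = packed_strides(E_dims);

    // MakeKernelArgs then derives:
    //   G_total = 4 * 8 = 32, M_total = 16 * 32 = 512, N_total = 128, K_total = 64,
    //   batch_stride_A = A_strides[1] (stride of the innermost G dim), likewise for B and E.
    // host_args.A_dims = A_dims; host_args.A_strides = A_strides; ...; host_args.k_batch = 1;
    return 0;
}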
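On the batch indexing: the kernel derives each tensor's batch offset as i_batch_flat * batch_stride_X, with batch_stride_X taken from the innermost G dimension. When the G dimensions are packed (each outer G stride equals the inner extent times the inner stride), that single multiplication is equivalent to decomposing the flat batch index into per-dimension G coordinates and dotting them with the strides. The sketch below shows that general decomposition on the host, purely to illustrate the equivalence; it is not the kernel code.

#include <array>
#include <cstdint>

// Decompose a flat batch index into per-G-dimension coordinates (innermost varies fastest)
// and accumulate the memory offset with arbitrary per-dimension strides.
template <int NumDimG>
std::int64_t batch_offset_from_flat(std::int64_t i_batch_flat,
                                    const std::array<std::int64_t, NumDimG>& G_dims,
                                    const std::array<std::int64_t, NumDimG>& G_strides)
{
    std::int64_t offset = 0;
    for(int i = NumDimG - 1; i >= 0; --i) // peel from the innermost G dimension outwards
    {
        const std::int64_t coord = i_batch_flat % G_dims[i];
        i_batch_flat /= G_dims[i];
        offset += coord * G_strides[i];
    }
    return offset;
}

// With packed G strides, e.g. G_dims = {4, 8} and G_strides = {8 * s, s},
// batch_offset_from_flat(b, ...) == b * s, which is the shortcut the kernel takes.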
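For orientation, the per-batch computation is E[g, m, n] = sum over k of A[g, m, k] * B[g, n, k] on the flattened M/N/K extents, with the D tensors folded in by the epilogue's elementwise stage. The naive scalar loop below assumes packed layouts consistent with stride_A = K_total, stride_B = K_total, stride_E = N_total set in MakeKernelArgs, and deliberately omits the D/elementwise stage; it is a simplified check of the accumulation only, not the reference implementation shipped with the example.

#include <cstdint>
#include <vector>

// Naive scalar reference over flattened indices, assuming packed layouts:
//   A: [G_total, M_total, K_total], B: [G_total, N_total, K_total], E: [G_total, M_total, N_total].
void reference_contraction_flat(const std::vector<float>& A,
                                const std::vector<float>& B,
                                std::vector<float>& E,
                                std::int64_t G_total,
                                std::int64_t M_total,
                                std::int64_t N_total,
                                std::int64_t K_total)
{
    for(std::int64_t g = 0; g < G_total; ++g)
        for(std::int64_t m = 0; m < M_total; ++m)
            for(std::int64_t n = 0; n < N_total; ++n)
            {
                double acc = 0.0; // accumulate in double to reduce rounding error
                for(std::int64_t k = 0; k < K_total; ++k)
                    acc += static_cast<double>(A[(g * M_total + m) * K_total + k]) *
                           static_cast<double>(B[(g * N_total + n) * K_total + k]);
                E[(g * M_total + m) * N_total + n] = static_cast<float>(acc);
            }
}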