From e9f0cc83a8f3f94ad8462e50a9d9a92d8dca3388 Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Mon, 13 Oct 2025 12:30:28 +0200 Subject: [PATCH 01/75] [CK Tile] contraction multi d - kernel & example (#2901) * Initial commit. create batched_contraction_kernel file * initial problem definition * implement initial example to launch kernel * add universal gemm to contraction. initial phase * complete implementation for special case all Dims are 1 and no Ds * clean code * initial changes to support multi dimensional G * more progress in implementing multiple G * tmp commit * manage dynamic NumDimG in kernel * improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit * implement the example for general Multi dimension G M N K and test different reference calculation algorithms * 2 functions for reference using multi dimensional and flat indexing * clean the code for muti dimentional G, M, N, K contraction and add some logs * Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E * some cleaning on kernel * clean the code for calculating the offsets from flatten batch number * Start adding MultiD support to kernel and example * more changes to manage multi D in kernel and example * manage passing multi d to kernel and testing. * complete multi D support in kernel. 
modify example code to support it * Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning * Minor fix * Generalize example code for variable NumD tensors and apply cleanup based on review feedback * Refactored code and addressed review feedback * refactoring, cleaning, add documents, in kernel side and example codes * Optimize batch offset calculation in kernel * Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- CHANGELOG.md | 1 + .../41_batched_contraction/CMakeLists.txt | 7 + .../batched_contraction.cpp | 245 ++++++++ .../contraction_utils.hpp | 146 +++++ .../run_batched_contraction_example.inc | 405 ++++++++++++++ example/ck_tile/CMakeLists.txt | 1 + .../reference_batched_contraction.hpp | 265 +++++++++ include/ck_tile/ops/batched_contraction.hpp | 9 + .../kernel/batched_contraction_kernel.hpp | 522 ++++++++++++++++++ .../pipeline/batched_contraction_problem.hpp | 32 ++ .../utils/tensor_descriptor_utils.hpp | 169 ++++++ 11 files changed, 1802 insertions(+) create mode 100644 example/ck_tile/41_batched_contraction/CMakeLists.txt create mode 100644 example/ck_tile/41_batched_contraction/batched_contraction.cpp create mode 100644 example/ck_tile/41_batched_contraction/contraction_utils.hpp create mode 100644 example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc create mode 100644 include/ck_tile/host/reference/reference_batched_contraction.hpp create mode 100644 include/ck_tile/ops/batched_contraction.hpp create mode 100644 include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp create mode 100644 include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp create mode 100644 include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index a8fe7b4afb..9de78f3043 100644 --- a/CHANGELOG.md +++ 
b/CHANGELOG.md @@ -36,6 +36,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. * Added support for f32 to FMHA (fwd/bwd). * Added tensor-wise quantization for CK_TILE GEMM. +* Added support for batched contraction kernel. * Added pooling kernel in CK_TILE ### Optimized diff --git a/example/ck_tile/41_batched_contraction/CMakeLists.txt b/example/ck_tile/41_batched_contraction/CMakeLists.txt new file mode 100644 index 0000000000..10b2e48cbf --- /dev/null +++ b/example/ck_tile/41_batched_contraction/CMakeLists.txt @@ -0,0 +1,7 @@ +add_executable(tile_example_batched_contraction EXCLUDE_FROM_ALL batched_contraction.cpp) +set(EXAMPLE_CONTRACTION_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_CONTRACTION_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +target_compile_options(tile_example_batched_contraction PRIVATE ${EXAMPLE_CONTRACTION_COMPILE_OPTIONS}) diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp new file mode 100644 index 0000000000..ea78f09dff --- /dev/null +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" + +#include "ck_tile/ops/batched_contraction.hpp" +#include "contraction_utils.hpp" + +template + +float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs& args, + const ck_tile::stream_config& s) +{ + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using Problem = ck_tile::BatchedContractionProblem; + + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = UNIVERSAL_GEMM_PIPELINE; + + ck_tile::index_t K_total = 1; + for(ck_tile::index_t i = NumDimG + NumDimM; i < NumDimG + NumDimM + NumDimK; ++i) + { + K_total *= args.A_dims[i]; + } + + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_total); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + 
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GEMM_PIPELINE_SCHEDULER; + constexpr auto memory_operation = + ck_tile::memory_operation_enum::set; // Always set (no atomic_add) + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = GEMM_PIPELINE; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::GetBlockSize(); + + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); + + ave_time = ck_tile::launch_kernel(s, kernel); + + return ave_time; + }; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + + return ave_time; +} + +#define HANDLE_CASE(G, M, N, K) \ + if(num_g_dims == G && num_m_dims == M && num_n_dims == N && num_k_dims == K) \ + { \ + return batched_contraction_impl(args, s); \ + } + +template +float batched_contraction(const ck_tile::BatchedContractionHostArgs& args, + const ck_tile::stream_config& s, + ck_tile::index_t num_g_dims, + ck_tile::index_t num_m_dims, + ck_tile::index_t num_n_dims, + ck_tile::index_t num_k_dims) +{ + std::cout << 
"Dimensions: G=" << num_g_dims << ", M=" << num_m_dims << ", N=" << num_n_dims + << ", K=" << num_k_dims << std::endl; + + HANDLE_CASE(1, 1, 1, 1); + HANDLE_CASE(2, 1, 1, 1); + HANDLE_CASE(2, 2, 2, 1); + HANDLE_CASE(1, 2, 1, 1); + HANDLE_CASE(1, 1, 1, 2); + HANDLE_CASE(2, 2, 2, 2); + HANDLE_CASE(4, 4, 4, 4); + + throw std::runtime_error( + "Unsupported dimension combination: G=" + std::to_string(num_g_dims) + + ", M=" + std::to_string(num_m_dims) + ", N=" + std::to_string(num_n_dims) + + ", K=" + std::to_string(num_k_dims) + ". Please add this combination to the kernel."); +} + +#include "run_batched_contraction_example.inc" + +int main(int argc, char* argv[]) +{ + try + { + return !run_batched_contraction_example(argc, argv); + } + catch(const std::runtime_error& e) + { + std::cerr << "Runtime error: " << e.what() << '\n'; + return EXIT_FAILURE; + } +} diff --git a/example/ck_tile/41_batched_contraction/contraction_utils.hpp b/example/ck_tile/41_batched_contraction/contraction_utils.hpp new file mode 100644 index 0000000000..6a75f1c04e --- /dev/null +++ b/example/ck_tile/41_batched_contraction/contraction_utils.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" + +struct AddDs +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... 
ds) const -> void + { + const float x0_f = + ck_tile::type_convert(c) + (ck_tile::type_convert(ds) + ...); + + e = ck_tile::type_convert(x0_f); + } +}; + +#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3 +#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3 +#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave + +template +struct BatchedContractionTypeConfig +{ + using ADataType = DataType; + using BDataType = DataType; + using AccDataType = float; + using EDataType = DataType; + using DDataType = DataType; +}; + +using ContractionTypes = BatchedContractionTypeConfig; + +using ADataType = ContractionTypes::ADataType; +using BDataType = ContractionTypes::BDataType; +using AccDataType = ContractionTypes::AccDataType; +using EDataType = ContractionTypes::EDataType; +using DDataType = ContractionTypes::DDataType; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)") + .insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)") + .insert("k_dims", "64", "K dimensions separated by comma (e.g., '64,32' for 2D K)") + .insert( + "g_dims", "1,2", "G dimensions separated by comma (e.g., '4,2' for 2D, '2,3,4' for 3D)") + .insert("stride_a", "0", "Custom A tensor leading dimension stride (0 = auto)") + .insert("stride_b", "0", "Custom B tensor leading dimension stride (0 = auto)") + .insert("stride_e", "0", "Custom E tensor leading dimension stride (0 = auto)") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Col by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("v", "1", "0. No validation, 1. Validation on CPU") + .insert("prec", "fp16", "data type. 
fp32/fp16/bf16") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "10", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("log", "1", "log level for debugging"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// Helper function to parse G, M, N, K dimensions from string +std::vector parse_dimensions(const std::string& dims_str) +{ + std::vector dims; + std::stringstream ss(dims_str); + std::string token; + + while(std::getline(ss, token, ',')) + { + dims.push_back(std::stoi(token)); + } + + if(dims.empty()) + { + throw std::invalid_argument("Dimensions cannot be empty"); + } + + return dims; +} + +// Helper function to Calculate total elements from multi-dimensional vector +ck_tile::index_t calculate_total_elements(const std::vector& dims) +{ + ck_tile::index_t total = 1; + for(auto dim : dims) + { + total *= dim; + } + return total; +} + +/** + * @brief Flattens a list of tensor dimension components into a single dimension vector. + * + * This function takes a list of dimension vectors (e.g., representing different components + * such as G, M, N, or K dimensions) and concatenates them into a single vector. + * + * Example: + * Input: {{G0, G1}, {M0, M1}, {K0}} + * Output: {G0, G1, M0, M1, K0} + * + * @param dim_components A vector of vectors, where each inner vector represents a set of tensor + * dimensions. + * @return A single vector containing all dimensions concatenated in order. 
+ */ +std::vector +concatenate_dim_components(const std::vector>& dim_components) +{ + std::vector result; + + // Concatenate all dimension components into a single vector + for(const auto& component : dim_components) + { + result.insert(result.end(), component.begin(), component.end()); + } + + return result; +} + +// Helper function for printing dimensions +void print_dims(const std::string& name, + const std::vector& dims, + ck_tile::index_t total) +{ + std::cout << name << ": ["; + for(size_t i = 0; i < dims.size(); ++i) + { + std::cout << dims[i]; + if(i < dims.size() - 1) + std::cout << ","; + } + std::cout << "] "; + if(total != 0) + std::cout << "(total=" << total << ")"; + std::cout << std::endl; +} diff --git a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc new file mode 100644 index 0000000000..9bc09a6c9c --- /dev/null +++ b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include "contraction_utils.hpp" +#include "ck_tile/host/reference/reference_batched_contraction.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +float invoke_batched_contraction_kernel( + const void* a_full_dims_dev_buf, + const void* b_full_dims_dev_buf, + const std::array& ds_dev_buf, + void* e_full_dims_dev_buf, + const std::vector& G_dims, + const std::vector& M_dims, + const std::vector& N_dims, + const std::vector& K_dims, + const std::vector& A_dims, // [G0,G1,..,M0,M1,..,K0,K1,..] + const std::vector& B_dims, // [G0,G1,..,N0,N1,..,K0,K1,..] + const std::array, DsDataType::size()>& + Ds_dims, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor] + const std::vector& E_dims, // [G0,G1,..,M0,M1,..,N0,N1,..] + const std::vector& A_strides, // [G0,G1,..,M0,M1,..,K0,K1,..] + const std::vector& B_strides, // [G0,G1,..,N0,N1,..,K0,K1,..] + const std::array, DsDataType::size()>& Ds_strides, + const std::vector& E_strides, // [G0,G1,..,M0,M1,..,N0,N1,..] + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat) +{ + std::cout << "Creating BatchedContractionHostArgs..." 
<< std::endl; + + ck_tile::BatchedContractionHostArgs args(a_full_dims_dev_buf, // a_ptr + b_full_dims_dev_buf, // b_ptr + ds_dev_buf, // ds_ptr + e_full_dims_dev_buf, // e_ptr + kbatch, // k_batch + A_dims, // A_dims + B_dims, // B_dims + Ds_dims, // Ds_dims + E_dims, // E_dims + A_strides, // A_strides + B_strides, // B_strides + Ds_strides, // Ds_strides + E_strides // E_strides + ); + + std::cout << "Calling batched_contraction with dimensions: G=" << G_dims.size() + << ", M=" << M_dims.size() << ", N=" << N_dims.size() << ", K=" << K_dims.size() + << std::endl; + + float ave_time = batched_contraction( + args, + ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}, + G_dims.size(), // num_g_dims + M_dims.size(), // num_m_dims + N_dims.size(), // num_n_dims + K_dims.size() // num_k_dims + ); + + return ave_time; +} + +template +int run_batched_contraction_example_with_layouts( + int argc, + char* argv[], + [[maybe_unused]] const ALayout a_layout = ALayout{}, + [[maybe_unused]] const BLayout b_layout = BLayout{}, + [[maybe_unused]] const DLayout d_layout = DLayout{}, + [[maybe_unused]] const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + std::vector G_dims = parse_dimensions(arg_parser.get_str("g_dims")); + std::vector M_dims = parse_dimensions(arg_parser.get_str("m_dims")); + std::vector N_dims = parse_dimensions(arg_parser.get_str("n_dims")); + std::vector K_dims = parse_dimensions(arg_parser.get_str("k_dims")); + + constexpr ck_tile::index_t NumDTensor = 2; + + ck_tile::index_t G_total = calculate_total_elements(G_dims); + ck_tile::index_t M_total = calculate_total_elements(M_dims); + ck_tile::index_t N_total = calculate_total_elements(N_dims); + ck_tile::index_t K_total = calculate_total_elements(K_dims); + + std::vector A_dims = + concatenate_dim_components({G_dims, M_dims, K_dims}); // [G0,G1,..,M0,M1,..,K0,K1,..] 
+ std::vector B_dims = + concatenate_dim_components({G_dims, N_dims, K_dims}); // [G0,G1,..,N0,N1,..,K0,K1,..] + std::vector E_dims = + concatenate_dim_components({G_dims, M_dims, N_dims}); // [G0,G1,..,M0,M1,..,N0,N1,..] + + std::array, NumDTensor> Ds_dims; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + Ds_dims[d] = E_dims; + } + + auto convert_strides = [](const std::vector& strides) { + std::vector converted(strides.size()); + std::copy(strides.begin(), strides.end(), converted.begin()); + return converted; + }; + + ck_tile::HostTensorDescriptor a_desc(A_dims); + ck_tile::HostTensorDescriptor b_desc(B_dims); + ck_tile::HostTensorDescriptor e_desc(E_dims); + std::array ds_descs; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + ds_descs[d] = ck_tile::HostTensorDescriptor(Ds_dims[d], e_desc.get_strides()); + } + + std::vector A_strides = convert_strides(a_desc.get_strides()); + std::vector B_strides = convert_strides(b_desc.get_strides()); + std::vector E_strides = convert_strides(e_desc.get_strides()); + + std::array, NumDTensor> Ds_strides; + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + Ds_strides[d] = convert_strides(ds_descs[d].get_strides()); + } + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + + print_dims("G_dims", G_dims, G_total); + print_dims("M_dims", M_dims, M_total); + print_dims("N_dims", N_dims, N_total); + print_dims("K_dims", K_dims, K_total); + + std::cout << "NumDTensor: " << NumDTensor << std::endl; + std::cout << "\n=== Tensor Shapes for Kernel ===" << std::endl; + print_dims("A_dims", A_dims, 0); + print_dims("B_dims", B_dims, 0); + print_dims("E_dims", E_dims, 0); + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + print_dims("Ds[" + std::to_string(d) + "]_dims", Ds_dims[d], 0); + } + + std::cout << "\n=== Tensor Strides ===" << std::endl; + print_dims("A_strides", A_strides, 0); + 
print_dims("B_strides", B_strides, 0); + print_dims("E_strides", E_strides, 0); + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + print_dims("Ds[" + std::to_string(d) + "]_strides", Ds_strides[d], 0); + } + + std::cout << "===============================================\n" << std::endl; + + ck_tile::HostTensor<::ADataType> a_full_dims_host(a_desc); + ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc); + ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc); + + std::vector> ds_full_dims_host; + for(int d = 0; d < NumDTensor; ++d) + { + ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d])); + } + + ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host); + ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host); + + ck_tile::DeviceMem a_full_dims_dev_buf(a_full_dims_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_full_dims_dev_buf(b_full_dims_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem e_full_dims_dev_buf(e_full_dims_host.get_element_space_size_in_bytes()); + + a_full_dims_dev_buf.ToDevice(a_full_dims_host.data()); + b_full_dims_dev_buf.ToDevice(b_full_dims_host.data()); + + for(int d = 0; d < NumDTensor; ++d) + { + ck_tile::FillUniformDistribution<::DDataType>{-2.f, 2.f, std::nullopt}( + ds_full_dims_host[d]); + } + + std::vector> ds_full_dims_dev_buf; + for(int d = 0; d < NumDTensor; ++d) + { + ds_full_dims_dev_buf.push_back(std::make_unique( + ds_full_dims_host[d].get_element_space_size_in_bytes())); + ds_full_dims_dev_buf[d]->ToDevice(ds_full_dims_host[d].data()); + } + std::array ds_ptr_buf; + for(int d = 0; d < NumDTensor; ++d) + { + ds_ptr_buf[d] = ds_full_dims_dev_buf[d]->GetDeviceBuffer(); + } + + e_full_dims_dev_buf.SetZero(); + e_full_dims_host.SetZero(); + + std::cout << "\n=== Running GPU Kernel ===" << std::endl; + + using DsDataType = ck_tile::tuple_array<::DDataType, NumDTensor>; + using DsLayout = 
ck_tile::tuple_array; + using CDEElementWise = + std::conditional_t; + + float ave_time = + invoke_batched_contraction_kernel<::ADataType, + ::BDataType, + DsDataType, + ::AccDataType, + ::EDataType, + ALayout, + BLayout, + DsLayout, + ELayout, + CDEElementWise>(a_full_dims_dev_buf.GetDeviceBuffer(), + b_full_dims_dev_buf.GetDeviceBuffer(), + ds_ptr_buf, + e_full_dims_dev_buf.GetDeviceBuffer(), + G_dims, + M_dims, + N_dims, + K_dims, + A_dims, + B_dims, + Ds_dims, + E_dims, + A_strides, + B_strides, + Ds_strides, + E_strides, + kbatch, + n_warmup, + n_repeat); + + std::string op_name{ + "Multi-Dimensional Batched Contraction : G: " + std::to_string(G_dims.size()) + + "D, M: " + std::to_string(M_dims.size()) + "D, N: " + std::to_string(N_dims.size()) + + "D, K: " + std::to_string(K_dims.size()) + "D"}; + + std::size_t flop = std::size_t(2) * G_total * M_total * N_total * K_total + + NumDTensor * K_total * M_total * N_total; // Number of operations + std::size_t num_byte = + sizeof(::ADataType) * G_total * M_total * K_total + // A tensor size + sizeof(::BDataType) * G_total * N_total * K_total + // B tensor size + sizeof(::DDataType) * NumDTensor * G_total * M_total * N_total + // D tensors + sizeof(::EDataType) * G_total * M_total * N_total; // E tensor size + + float tflops = static_cast(flop) / 1.E9 / ave_time; // TFlops calculation + float gb_per_sec = num_byte / 1.E6 / ave_time; // GB/s calculation + print_dims("G_dims", G_dims, G_total); + print_dims("M_dims", M_dims, M_total); + print_dims("N_dims", N_dims, N_total); + print_dims("K_dims", K_dims, K_total); + + std::cout << " Performance: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s" << std::endl; + + std::cout << "===============================================" << std::endl; + + e_full_dims_dev_buf.FromDevice(e_full_dims_host.data()); + std::cout << "GPU results retrieved from device." 
<< std::endl; + + bool pass = true; + if(arg_parser.get_int("v") == 1) + { + + std::cout << "Computing CPU reference..." << std::endl; + + ck_tile::HostTensor<::EDataType> e_full_dims_host_ref( + ck_tile::HostTensorDescriptor(E_dims, E_strides)); + e_full_dims_host_ref.SetZero(); + + auto start_time = std::chrono::high_resolution_clock::now(); + + calculate_reference_flat_indexing(a_full_dims_host, + b_full_dims_host, + ds_full_dims_host, + e_full_dims_host_ref, + G_total, + M_total, + N_total, + K_total, + CDEElementWise{}); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = + std::chrono::duration_cast(end_time - start_time); + + std::cout << "CPU reference completed in " << duration.count() << "ms" << std::endl; + + const float max_accumulated_value = + *std::max_element(e_full_dims_host_ref.mData.begin(), e_full_dims_host_ref.mData.end()); + + const auto rtol_atol = + calculate_rtol_atol<::ADataType, ::BDataType, ::EDataType, ::AccDataType>( + K_total, kbatch, max_accumulated_value); + + pass = ck_tile::check_err(e_full_dims_host, + e_full_dims_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "The CPU verification result is: " << (pass ? 
"correct" : "fail") << std::endl; + + std::cout << "===============================================" << std::endl; + + std::cout << "\n=== Random Samples of Reference and Result ===" << std::endl; + + // Generate 10 random indices + std::vector random_indices; + std::size_t total_elements = e_full_dims_host_ref.mData.size(); + std::mt19937 rng(std::random_device{}()); + std::uniform_int_distribution dist(0, total_elements - 1); + + for(int i = 0; i < 10; ++i) + { + random_indices.push_back(dist(rng)); + } + + // Print the values at the random indices + for(std::size_t idx : random_indices) + { + std::cout << "Index " << idx << ": " + << "ref=" << static_cast(e_full_dims_host_ref.mData[idx]) << ", " + << "GPU=" << static_cast(e_full_dims_host.mData[idx]) << std::endl; + } + + std::cout << "===============================================" << std::endl; + } + + return pass; +} + +int run_batched_contraction_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(a_layout == "R" && b_layout == "C") + { + return run_batched_contraction_example_with_layouts(argc, argv, Row{}, Col{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and E tensors! 
" + "Only R-C-R supported for now."); + } +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7a8ae065db..5e178e3669 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -27,3 +27,4 @@ add_subdirectory(36_pooling) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) add_subdirectory(40_streamk_gemm) +add_subdirectory(41_batched_contraction) diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp new file mode 100644 index 0000000000..1ce071969c --- /dev/null +++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp @@ -0,0 +1,265 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template + +void calculate_reference_flat_indexing( + const ck_tile::HostTensor& a_full_dims, + const ck_tile::HostTensor& b_full_dims, + const std::vector>& ds_full_dims_host, + ck_tile::HostTensor& e_full_dims_host_ref, + ck_tile::index_t G_total, + ck_tile::index_t M_total, + ck_tile::index_t N_total, + ck_tile::index_t K_total, + const CDEElementWise& cde_elementwise) +{ + std::cout << "Calculating reference using optimized flat indexing with parallel processing..." 
+ << std::endl; + + // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp + auto f_gm = [&](auto g_flat, auto m_flat) { + for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat) + { + AccDataType sum = 0; + + // Compute dot product over K dimension + for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat) + { + auto a_val = + a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat]; + auto b_val = + b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat]; + sum += static_cast(a_val) * static_cast(b_val); + } + + // Apply elementwise operation with D tensors + EDataType result = static_cast(sum); + if(ds_full_dims_host.size() == 0) + { + ; + } + else if(ds_full_dims_host.size() == 1) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0].mData[g_flat * M_total * N_total + + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 2) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 3) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[2] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else if(ds_full_dims_host.size() == 4) + { + cde_elementwise( + result, + ck_tile::type_convert(sum), + ck_tile::type_convert( + ds_full_dims_host[0] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + 
ds_full_dims_host[1] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[2] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]), + ck_tile::type_convert( + ds_full_dims_host[3] + .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat])); + } + else + { + throw std::runtime_error("Unsupported NumDTensor for reference calculation"); + } + + // Store result + e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] = + static_cast(result); + } + }; + + // Execute parallel computation using hardware concurrency + // Parallelize over G_total and M_total dimensions for optimal CPU utilization + make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency()); +} + +template +void calculate_reference_multi_dimensional( + const HostTensor& a_full_dims, + const HostTensor& b_full_dims, + const std::vector>& ds_full_dims_host, + HostTensor& e_full_dims_host_ref, + const std::vector& G_dims, + const std::vector& M_dims, + const std::vector& N_dims, + const std::vector& K_dims, + const std::vector& A_dims, + const std::vector& B_dims, + const std::vector& E_dims, + const CDEElementWise& cde_elementwise) +{ + std::cout << "Calculating reference using multi-dimensional indexing..." 
<< std::endl; + + std::vector g_idx(G_dims.size()); + std::vector m_idx(M_dims.size()); + std::vector n_idx(N_dims.size()); + std::vector k_idx(K_dims.size()); + std::vector a_idx, b_idx, e_idx; + + a_idx.reserve(A_dims.size()); + b_idx.reserve(B_dims.size()); + e_idx.reserve(E_dims.size()); + + for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat) + { + ck_tile::index_t temp = g_flat; + for(int i = G_dims.size() - 1; i >= 0; --i) + { + g_idx[i] = temp % G_dims[i]; + temp /= G_dims[i]; + } + + for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat) + { + temp = m_flat; + for(int i = M_dims.size() - 1; i >= 0; --i) + { + m_idx[i] = temp % M_dims[i]; + temp /= M_dims[i]; + } + + for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat) + { + temp = n_flat; + for(int i = N_dims.size() - 1; i >= 0; --i) + { + n_idx[i] = temp % N_dims[i]; + temp /= N_dims[i]; + } + + AccDataType sum = 0; + + for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims); + ++k_flat) + { + temp = k_flat; + for(int i = K_dims.size() - 1; i >= 0; --i) + { + k_idx[i] = temp % K_dims[i]; + temp /= K_dims[i]; + } + + a_idx.clear(); + b_idx.clear(); + + a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end()); + a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end()); + a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end()); + + b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end()); + b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end()); + b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end()); + + auto a_val = a_full_dims(a_idx); + auto b_val = b_full_dims(b_idx); + + sum += static_cast(a_val) * static_cast(b_val); + } + + e_idx.clear(); + e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end()); + e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end()); + e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end()); + + EDataType result = static_cast(sum); + if(ds_full_dims_host.size() == 0) + { + ; + } + 
else if(ds_full_dims_host.size() == 1) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx))); + } + else if(ds_full_dims_host.size() == 2) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx))); + } + else if(ds_full_dims_host.size() == 3) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx)), + ck_tile::type_convert(ds_full_dims_host[2](e_idx))); + } + else if(ds_full_dims_host.size() == 4) + { + cde_elementwise(result, + ck_tile::type_convert(sum), + ck_tile::type_convert(ds_full_dims_host[0](e_idx)), + ck_tile::type_convert(ds_full_dims_host[1](e_idx)), + ck_tile::type_convert(ds_full_dims_host[2](e_idx)), + ck_tile::type_convert(ds_full_dims_host[3](e_idx))); + } + else + { + throw std::runtime_error("Unsupported NumDTensor for reference calculation"); + } + + e_full_dims_host_ref(e_idx) = static_cast(result); + } + } + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp new file mode 100644 index 0000000000..9162f421d1 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" +#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp new file mode 100644 index 0000000000..6d8f9f3f0e --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp @@ -0,0 +1,522 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" + +/** + * @file batched_contraction_kernel.hpp + * @brief Batched Tensor Contraction Operations + * + * @section batched_contraction_overview What is Batched Tensor Contraction with Multiple D? + * + * Tensor contraction is a fundamental operation that generalizes matrix multiplication to + * multi-dimensional tensors. It performs element-wise multiplication and summation over + * shared dimensions + * + * **Beyond pure contraction, this kernel supports multiple auxiliary input tensors (D tensors)** + * that are fused with the contraction result through configurable epilogue operations, enabling + * efficient computation of complex tensor expressions in a single kernel launch. + * + * @subsection mathematical_formulation Mathematical Formulation + * + * For tensors A and B with arbitrary dimensionalities, the complete operation computes: + * + * **E[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = epilogue_op(C, D₀, D₁, D₂, ...)** + * + * Where: + * **C[G₀,G₁,...,M₀,M₁,...,N₀,N₁,...] = Σ_{K₀,K₁,...} A[G₀,G₁,...,M₀,M₁,...,K₀,K₁,...] 
× + * B[G₀,G₁,...,N₀,N₁,...,K₀,K₁,...]** + * + * Where: + * - **G dimensions**: Batch dimensions (shared across A, B, and output E) + * - **M dimensions**: Row dimensions of the output matrix (from tensor A) + * - **N dimensions**: Column dimensions of the output matrix (from tensor B) + * - **K dimensions**: Contraction dimensions (summed over, present in both A and B) + * + * @subsection why_gemm_implementation Why Tensor Contraction Can Be Implemented Using GEMM + * + * **Mathematical Equivalence**: Tensor contraction is fundamentally equivalent to matrix + * multiplication when dimensions are appropriately flattened. The key insight is that the summation + * operation over shared dimensions (K dimensions) in tensor contraction is mathematically identical + * to the dot product computation in matrix multiplication. + * + * **Dimension Flattening Strategy**: + * - **M dimensions** (from tensor A) → Flattened into matrix rows (M_total) + * - **N dimensions** (from tensor B) → Flattened into matrix columns (N_total) + * - **K dimensions** (contraction dims) → Flattened into inner dimension (K_total) + * - **G dimensions** (batch dims) → Handled through batch processing + * + * **Mathematical Transformation**: + * ``` + * Original: E[g,m₀,m₁,n₀,n₁] = Σ_{k₀,k₁} A[g,m₀,m₁,k₀,k₁] × B[g,n₀,n₁,k₀,k₁] + * Flattened: E[g,M,N] = Σ_K A[g,M,K] × B[g,N,K] (where M=m₀×m₁, N=n₀×n₁, K=k₀×k₁) + * GEMM Form: E = A × Bᵀ + * + * **Why This Approach Is Optimal**: + * Rather than implementing tensor contraction from scratch, this kernel leverages the highly + * optimized `UniversalGemmKernel` as its computational backend. 
+ * + * @subsection current_limitations Current Kernel Limitations + * + * **Layout Restrictions:** + * - **Row-Major Only**: All tensors must use row-major memory layout + * - **Packed Tensors**: Only contiguous/packed tensor layouts supported + * - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total + * - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total) + * + * **Implementation Constraints:** + * - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized + * - **No Column-Major**: Column-major or custom stride patterns not supported + * - **No Strided Access**: Non-contiguous tensor slicing not supported + * + * **Future Enhancements:** + * - Support for arbitrary stride patterns + * - Column-major and mixed layout support + * - Non-contiguous tensor operation support + */ + +namespace ck_tile { + +/// @brief Host arguments for batched tensor contraction operations. +/// +/// @par Overview +/// This structure encapsulates all host-side arguments required for batched tensor contraction. +/// It supports arbitrary number of batch dimensions (G), M dimensions, N dimensions, and K +/// dimensions. +/// +/// @par Tensor Layout Assumptions +/// - A tensor: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +/// - B tensor: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +/// - D tensors: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (auxiliary input tensors) +/// - E tensor: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] (output tensor) +/// +/// @tparam NumDTensor Number of D (auxiliary input) tensors. Default is 0. +template +struct BatchedContractionHostArgs +{ + /// @brief Constructor for batched contraction host arguments. 
+ /// + /// @param a_ptr_ Pointer to input tensor A + /// @param b_ptr_ Pointer to input tensor B + /// @param ds_ptr_ Array of pointers to auxiliary input tensors D + /// @param e_ptr_ Pointer to output tensor E + /// @param k_batch_ Number of k-splits for split-K batching + /// @param A_dims_ Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + /// @param B_dims_ Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + /// @param Ds_dims_ Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + /// @param E_dims_ Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + /// @param A_strides_ Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + /// @param B_strides_ Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + /// @param Ds_strides_ Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + /// @param E_strides_ Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + CK_TILE_HOST + BatchedContractionHostArgs( + const void* a_ptr_, + const void* b_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + ck_tile::index_t k_batch_, + const std::vector& A_dims_, // [G0, G1, ..., M0, M1, ... , K0, K1, ...] + const std::vector& B_dims_, // [G0, G1, ..., N0, N1, ... , K0, K1, ...] + const std::array, NumDTensor>& + Ds_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...][NumDTensor] + const std::vector& E_dims_, // [G0, G1, ..., M0, M1, ... , N0, N1, ...] + + const std::vector& A_strides_, // [G0, G1, ..., M0, M1, ...,K0, K1, ...] + const std::vector& B_strides_, // [G0, G1, ..., N0, N1, ...,K0, K1, ...] + const std::array, NumDTensor>& + Ds_strides_, // [G0, G1, ..., M0, M1, ...,N0, N1, ...] 
+ const std::vector& + E_strides_) // [G0, G1, ..., M0, M1, ...,N0, N1, ...][NumDTensor] + + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + k_batch(k_batch_), + A_dims(A_dims_), + B_dims(B_dims_), + Ds_dims(Ds_dims_), + E_dims(E_dims_), + A_strides(A_strides_), + B_strides(B_strides_), + Ds_strides(Ds_strides_), + E_strides(E_strides_) + { + } + + const void* a_ptr; ///< Pointer to input tensor A + const void* b_ptr; ///< Pointer to input tensor B + std::array ds_ptr; ///< Array of pointers to auxiliary input tensors D + void* e_ptr; ///< Pointer to output tensor E + ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching + const std::vector + A_dims; ///< Dimension vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + const std::vector + B_dims; ///< Dimension vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + const std::array, NumDTensor> + Ds_dims; ///< Dimension vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + E_dims; ///< Dimension vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + A_strides; ///< Stride vector for tensor A: [G0, G1, ..., M0, M1, ..., K0, K1, ...] + const std::vector + B_strides; ///< Stride vector for tensor B: [G0, G1, ..., N0, N1, ..., K0, K1, ...] + const std::array, NumDTensor> + Ds_strides; ///< Stride vectors for D tensors: [G0, G1, ..., M0, M1, ..., N0, N1, ...] + const std::vector + E_strides; ///< Stride vector for tensor E: [G0, G1, ..., M0, M1, ..., N0, N1, ...] +}; + +/// @brief Kernel arguments for batched tensor contraction operations. +/// +/// @tparam NumDimG Number of batch dimensions +/// @tparam NumDimM Number of M (output row) dimensions +/// @tparam NumDimN Number of N (output column) dimensions +/// @tparam NumDimK Number of K (contraction) dimensions +/// @tparam NumDTensor Number of auxiliary input D tensors. Default is 0. 
+ +template +struct BatchedContractionKernelArgs +{ + const void* a_ptr; ///< Pointer to input tensor A + const void* b_ptr; ///< Pointer to input tensor B + std::array ds_ptr; ///< Array of pointers to auxiliary input tensors D + void* e_ptr; ///< Pointer to output tensor E + ck_tile::index_t k_batch; ///< Number of k-splits for split-K batching + + ck_tile::index_t M_dims[NumDimM]; ///< M dimension sizes: [M0, M1, M2, ..., M_{NumDimM-1}] + ck_tile::index_t N_dims[NumDimN]; ///< N dimension sizes: [N0, N1, N2, ..., N_{NumDimN-1}] + ck_tile::index_t K_dims[NumDimK]; ///< K dimension sizes: [K0, K1, K2, ..., K_{NumDimK-1}] + ck_tile::index_t + G_dims[NumDimG]; ///< G (batch) dimension sizes: [G0, G1, G2, ..., G_{NumDimG-1}] + + // Batch strides for efficient offset calculation + ck_tile::index_t batch_stride_A; ///< Batch stride for tensor A + ck_tile::index_t batch_stride_B; ///< Batch stride for tensor B + ck_tile::index_t batch_stride_E; ///< Batch stride for tensor E + std::array batch_stride_Ds; ///< Batch strides for D tensors + + ck_tile::index_t G_total; ///< Total batch size: G0 * G1 * ... * G_{NumDimG-1} + ck_tile::index_t M_total; ///< Total M dimension: M0 * M1 * ... * M_{NumDimM-1} + ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1} + ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1} + + ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total) + ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total) + std::array + stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total) + ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total) +}; + +/// @brief GPU kernel for batched tensor contraction operations. +/// +/// @par Overview +/// This kernel performs batched tensor contraction operations using the underlying +/// UniversalGemmKernel. 
It supports arbitrary tensor dimensionalities (G, M, N, K) and +/// processes multiple batch instances in parallel. Each batch performs: E = +/// epilogue_op(contraction(A, B), D0, D1, ...). +/// +/// @tparam Problem_ Tensor contraction problem specification defining data types and dimensions +/// @tparam TilePartitioner_ Tile partitioning strategy for workload distribution +/// @tparam GemmPipeline_ GEMM computation pipeline for core matrix operations +/// @tparam EpiloguePipeline_ Epilogue pipeline for post-GEMM operations and tensor fusion + +template +struct BatchedContractionKernel +{ + // Type aliases for cleaner code and better readability + using Problem = ck_tile::remove_cvref_t; ///< Tensor contraction problem specification + using ADataType = + ck_tile::remove_cvref_t; ///< Data type for input tensor A + using BDataType = + ck_tile::remove_cvref_t; ///< Data type for input tensor B + using DsDataType = + ck_tile::remove_cvref_t; ///< Data types for auxiliary input + ///< tensors D + using EDataType = + ck_tile::remove_cvref_t; ///< Data type for output tensor E + + // Compile-time dimension constants extracted from problem specification + static constexpr ck_tile::index_t NumDimG = Problem::NumDimG; ///< Number of batch dimensions + static constexpr ck_tile::index_t NumDimM = + Problem::NumDimM; ///< Number of M (output row) dimensions + static constexpr ck_tile::index_t NumDimN = + Problem::NumDimN; ///< Number of N (output column) dimensions + static constexpr ck_tile::index_t NumDimK = + Problem::NumDimK; ///< Number of K (contraction) dimensions + static constexpr ck_tile::index_t NumDTensor = + Problem::NumDTensor; ///< Number of auxiliary input D tensors + + // Pipeline and partitioning strategy types + using TilePartitioner = + ck_tile::remove_cvref_t; ///< Tile partitioning strategy for workload + ///< distribution + using GemmPipeline = ck_tile::remove_cvref_t; ///< GEMM computation pipeline + using EpiloguePipeline = + ck_tile::remove_cvref_t; 
///< Epilogue pipeline for post-GEMM operations + + // Underlying GEMM kernel that performs the actual computation + using UniversalGemmKernel = + ck_tile::UniversalGemmKernel; + + static constexpr ck_tile::index_t kBlockSize = + UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel + + using KernelArgs = + BatchedContractionKernelArgs; ///< Kernel + ///< argument + ///< structure + + /// @brief Returns the kernel name for debugging and profiling purposes. + /// @return Constant string identifier for this kernel + CK_TILE_HOST static constexpr auto GetKernelName() { return "batched_contraction_kernel"; } + + /// @brief Validates whether the given kernel arguments are supported. + /// @param kargs Kernel arguments to validate + /// @return True if arguments are supported, false otherwise + /// @details Checks underlying GEMM kernel support and ensures valid batch dimensions + CK_TILE_HOST static constexpr bool IsSupportedArguments(const KernelArgs& kargs) + { + typename UniversalGemmKernel::KernelArgs gemm_kargs{{kargs.a_ptr}, + {kargs.b_ptr}, + kargs.ds_ptr, + kargs.e_ptr, + kargs.M_total, + kargs.N_total, + kargs.K_total, + {kargs.stride_A}, + {kargs.stride_B}, + kargs.stride_Ds, + kargs.stride_E, + kargs.k_batch}; + + return UniversalGemmKernel::IsSupportedArgument(gemm_kargs) && kargs.G_total > 0; + } + + /// @brief Returns the shared memory size required by the kernel. + /// @return Shared memory size in bytes + /// @details Delegates to underlying GEMM kernel's shared memory requirements + CK_TILE_HOST static constexpr ck_tile::index_t GetSmemSize() + { + return UniversalGemmKernel::GetSmemSize(); + } + + /// @brief Returns the GPU block size for kernel launch. 
+ /// @return 3D block dimensions for GPU kernel execution + CK_TILE_HOST static constexpr auto GetBlockSize() + { + return dim3(UniversalGemmKernel::kBlockSize); + } + + CK_TILE_HOST static constexpr auto GridSize(const KernelArgs& kargs) + { + return dim3( + TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch); + } + + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const BatchedContractionHostArgs& host_args) + { + const auto expected_A_dims = NumDimG + NumDimM + NumDimK; + const auto expected_B_dims = NumDimG + NumDimN + NumDimK; + const auto expected_E_dims = NumDimG + NumDimM + NumDimN; + + if(host_args.A_dims.size() != expected_A_dims || + host_args.A_strides.size() != expected_A_dims) + { + throw std::invalid_argument("A dimension size mismatch"); + } + if(host_args.B_dims.size() != expected_B_dims || + host_args.B_strides.size() != expected_B_dims) + { + throw std::invalid_argument("B dimension size mismatch"); + } + if(host_args.E_dims.size() != expected_E_dims || + host_args.E_strides.size() != expected_E_dims) + { + throw std::invalid_argument("E dimension size mismatch"); + } + + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + if(host_args.Ds_dims[d].size() != expected_E_dims || + host_args.Ds_strides[d].size() != expected_E_dims) + { + throw std::invalid_argument("D dimension size mismatch"); + } + } + + KernelArgs kargs; + kargs.a_ptr = host_args.a_ptr; + kargs.b_ptr = host_args.b_ptr; + kargs.ds_ptr = host_args.ds_ptr; + kargs.e_ptr = host_args.e_ptr; + kargs.k_batch = host_args.k_batch; + + // Validate and set G dimensions (must be identical across all tensors) + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + // All tensors must have same G dimensions for valid contraction + if(host_args.A_dims[i] != host_args.B_dims[i] || + host_args.A_dims[i] != host_args.E_dims[i]) + { + throw std::invalid_argument( + "All tensors must have identical G dimensions for valid contraction"); + } + + // Store G 
dimensions (same for all tensors) + kargs.G_dims[i] = host_args.A_dims[i]; + } + + // Set batch strides from the stride of last G dimension + kargs.batch_stride_A = host_args.A_strides[NumDimG - 1]; + kargs.batch_stride_B = host_args.B_strides[NumDimG - 1]; + kargs.batch_stride_E = host_args.E_strides[NumDimG - 1]; + + for(ck_tile::index_t i = 0; i < NumDimM; ++i) + { + kargs.M_dims[i] = host_args.A_dims[NumDimG + i]; + if(kargs.M_dims[i] != host_args.E_dims[NumDimG + i]) + { + throw std::invalid_argument("M dimension mismatch between A and E tensors"); + } + } + for(ck_tile::index_t i = 0; i < NumDimN; ++i) + { + kargs.N_dims[i] = host_args.B_dims[NumDimG + i]; + if(kargs.N_dims[i] != host_args.E_dims[NumDimG + NumDimM + i]) + { + throw std::invalid_argument("N dimension mismatch between B and E tensors"); + } + } + for(ck_tile::index_t i = 0; i < NumDimK; ++i) + { + kargs.K_dims[i] = host_args.A_dims[NumDimG + NumDimM + i]; + if(kargs.K_dims[i] != host_args.B_dims[NumDimG + NumDimN + i]) + { + throw std::invalid_argument("K dimension mismatch between A and B tensors"); + } + } + + // Calculate total dimensions from individual dimension arrays + kargs.G_total = 1; + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + kargs.G_total *= kargs.G_dims[i]; + } + + kargs.M_total = 1; + for(ck_tile::index_t i = 0; i < NumDimM; ++i) + { + kargs.M_total *= kargs.M_dims[i]; + } + + kargs.N_total = 1; + for(ck_tile::index_t i = 0; i < NumDimN; ++i) + { + kargs.N_total *= kargs.N_dims[i]; + } + + kargs.K_total = 1; + for(ck_tile::index_t i = 0; i < NumDimK; ++i) + { + kargs.K_total *= kargs.K_dims[i]; + } + + kargs.stride_A = kargs.K_total; + kargs.stride_B = kargs.K_total; + kargs.stride_E = kargs.N_total; + + // Validate D tensors have same G dimensions and set their batch strides + for(ck_tile::index_t d = 0; d < NumDTensor; ++d) + { + for(ck_tile::index_t i = 0; i < NumDimG; ++i) + { + if(host_args.Ds_dims[d][i] != host_args.A_dims[i]) + { + throw std::invalid_argument( + 
"D tensor G dimensions must match A/B/E tensor G dimensions"); + } + } + // Set batch stride for D tensor + kargs.batch_stride_Ds[d] = host_args.Ds_strides[d][NumDimG - 1]; + kargs.stride_Ds[d] = kargs.N_total; // D tensors same shape as E + } + + return kargs; + } + + CK_TILE_DEVICE void operator()(const KernelArgs& kargs) const + { + + const auto [iM, iN] = + TilePartitioner{kargs.M_total, kargs.N_total}.GetOutputTileIndex(blockIdx.x); + const ck_tile::index_t i_m = + __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const ck_tile::index_t i_n = + __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const auto i_batch_flat = __builtin_amdgcn_readfirstlane(blockIdx.y); + const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z); + + // Calculate batch offsets for each tensor + const auto batch_offset_A = i_batch_flat * kargs.batch_stride_A; + const auto batch_offset_B = i_batch_flat * kargs.batch_stride_B; + const auto batch_offset_E = i_batch_flat * kargs.batch_stride_E; + + const ADataType* a_ptr = static_cast(kargs.a_ptr) + batch_offset_A; + const BDataType* b_ptr = static_cast(kargs.b_ptr) + batch_offset_B; + EDataType* e_ptr = static_cast(kargs.e_ptr) + batch_offset_E; + + std::array ds_batch_ptr; + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = typename std::tuple_element::type; + const auto batch_offset_D = i_batch_flat * kargs.batch_stride_Ds[i]; + ds_batch_ptr[i] = static_cast(kargs.ds_ptr[i]) + batch_offset_D; + }); + + typename UniversalGemmKernel::KernelArgs gemm_kargs{{a_ptr}, + {b_ptr}, + ds_batch_ptr, + e_ptr, + kargs.M_total, + kargs.N_total, + kargs.K_total, + {kargs.stride_A}, + {kargs.stride_B}, + kargs.stride_Ds, + kargs.stride_E, + kargs.k_batch}; + + const typename UniversalGemmKernel::SplitKBatchOffset splitk_batch_offset(gemm_kargs, + i_splitk); + + const ADataType* a_ptr_final = a_ptr + splitk_batch_offset.as_k_split_offset[0]; + const BDataType* b_ptr_final = b_ptr + 
splitk_batch_offset.bs_k_split_offset[0]; + __shared__ char smem_ptr[GetSmemSize()]; + + UniversalGemmKernel::RunGemm({a_ptr_final}, + {b_ptr_final}, + ds_batch_ptr, + e_ptr, + smem_ptr, + gemm_kargs, + splitk_batch_offset, + i_m, + i_n); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp new file mode 100644 index 0000000000..9ebaae3c97 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BatchedContractionProblem +{ + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using DsDataType = ck_tile::remove_cvref_t; + using EDataType = ck_tile::remove_cvref_t; + + static constexpr ck_tile::index_t NumDimG = NumDimG_; + static constexpr ck_tile::index_t NumDimM = NumDimM_; + static constexpr ck_tile::index_t NumDimN = NumDimN_; + static constexpr ck_tile::index_t NumDimK = NumDimK_; + static constexpr ck_tile::index_t NumDTensor = NumDTensor_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp new file mode 100644 index 0000000000..6d3286ce09 --- /dev/null +++ b/include/ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp @@ -0,0 +1,169 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +/** + * @file tensor_descriptor_utils.hpp + * @brief Utility functions for creating tensor descriptors in batched contraction operations + * + * @details This file contains utility functions for creating tensor descriptors with flattened + * dimensions for GEMM operations. These functions transform multi-dimensional tensors into + * 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions. + * + * These utilities are currently not used in the main batched contraction kernel but are preserved + * for future implementations that may require explicit tensor descriptor creation. + */ + +namespace ck_tile { + +/** + * @brief Utility class for creating tensor descriptors in batched contraction operations + * + * @tparam NumDimG Number of batch dimensions + * @tparam NumDimM Number of M (output row) dimensions + * @tparam NumDimN Number of N (output column) dimensions + * @tparam NumDimK Number of K (contraction) dimensions + */ +template +struct TensorDescriptorUtils +{ + /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed. + /// @param A_dims Dimension vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + /// @param A_strides Stride vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] 
+ /// @return Flattened tensor descriptor: [M_total, K_total] for GEMM computation + /// @details Removes batch dimensions and flattens M and K dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_A_GridDescriptor_M_K(const std::vector& A_dims = {}, + const std::vector& A_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G Dimensions + const auto A_dims_M_K = + to_tuple(A_dims, number{}, number{}); + const auto A_strides_M_K = + to_tuple(A_strides, number{}, number{}); + + // dimension Ids for M and K + constexpr auto A_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + constexpr auto A_dims_K_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for M [M0, M1, ...] and K [K0, K1, ...] + const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids); + const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor + const auto A_grid_desc_Ms_Ks = + ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K); + + // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total + // = K0 * K1 * K2 * ...] + const auto A_grid_desc_Mflat_Kflat = ck_tile::transform_tensor_descriptor( + A_grid_desc_Ms_Ks, + make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_K)), + make_tuple(A_dims_M_ids, A_dims_K_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return A_grid_desc_Mflat_Kflat; + } + + /// @brief Creates a tensor descriptor for input tensor B with batch dimensions removed. + /// @param B_dims Dimension vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + /// @param B_strides Stride vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] 
+ /// @return Flattened tensor descriptor: [N_total, K_total] for GEMM computation + /// @details Removes batch dimensions and flattens N and K dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_B_GridDescriptor_N_K(const std::vector& B_dims = {}, + const std::vector& B_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G Dimensions + const auto B_dims_N_K = + to_tuple(B_dims, number{}, number{}); + const auto B_strides_N_K = + to_tuple(B_strides, number{}, number{}); + + // dimension Ids for N and K + constexpr auto B_dims_N_ids = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + constexpr auto B_dims_K_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for N [N0, N1, ...] and K [K0, K1, ...] + const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids); + const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor + const auto B_grid_desc_Ns_Ks = + ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K); + + // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total + // = K0 * K1 * K2 * ...] + const auto B_grid_desc_Nflat_Kflat = ck_tile::transform_tensor_descriptor( + B_grid_desc_Ns_Ks, + make_tuple(make_merge_transform(dims_N), make_merge_transform(dims_K)), + make_tuple(B_dims_N_ids, B_dims_K_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return B_grid_desc_Nflat_Kflat; + } + + /// @brief Creates a tensor descriptor for output tensor E with batch dimensions removed. + /// @param E_dims Dimension vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + /// @param E_strides Stride vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] 
+ /// @return Flattened tensor descriptor: [M_total, N_total] for GEMM computation + /// @details Removes batch dimensions and flattens M and N dimensions for efficient GEMM + /// execution + CK_TILE_HOST static constexpr auto + Make_E_GridDescriptor_M_N(const std::vector& E_dims = {}, + const std::vector& E_strides = {}) + { + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, number{}); + }; + + // Remove G dimensions + const auto E_dims_M_N = + to_tuple(E_dims, number{}, number{}); + const auto E_strides_M_N = + to_tuple(E_strides, number{}, number{}); + + // dimension Ids for M and N + constexpr auto E_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + constexpr auto E_dims_N_ids = + typename arithmetic_sequence_gen::type{}; + + // Dimensions for M and N + const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids); + const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor + const auto E_grid_desc_Ms_Ns = + ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N); + + // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... , + // N_total = N0 * N1 * N2 * ...] 
+ const auto E_grid_desc_Mflat_Nflat = ck_tile::transform_tensor_descriptor( + E_grid_desc_Ms_Ns, + make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_N)), + make_tuple(E_dims_M_ids, E_dims_N_ids), + make_tuple(sequence<0>{}, sequence<1>{})); + + return E_grid_desc_Mflat_Nflat; + } +}; + +} // namespace ck_tile From 46c10c316db0b4e987ff69b2804d97a98bb01c1a Mon Sep 17 00:00:00 2001 From: damien-lejeune <31985270+damien-lejeune@users.noreply.github.com> Date: Mon, 13 Oct 2025 13:24:47 +0200 Subject: [PATCH 02/75] Update include path to break the remod's cyclic dep issue (#2978) * Update include path to break the cyclic dep issue * Use ck_tile::permute_vectors_i4x4_b in tile engine --------- Co-authored-by: Damien Lejeune Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- include/ck_tile/host.hpp | 1 + include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 2 + include/ck_tile/ops/batched_transpose.hpp | 2 + include/ck_tile/ops/common.hpp | 3 +- .../ops/common/load_interleaved_pk_type.hpp | 2 +- include/ck_tile/ops/elementwise.hpp | 2 + include/ck_tile/ops/epilogue.hpp | 2 + include/ck_tile/ops/flatmm.hpp | 2 + include/ck_tile/ops/fmha.hpp | 2 + include/ck_tile/ops/fused_moe.hpp | 2 + include/ck_tile/ops/gemm.hpp | 8 +-- include/ck_tile/ops/gemm_quant.hpp | 8 +-- include/ck_tile/ops/grouped_convolution.hpp | 2 + include/ck_tile/ops/image_to_column.hpp | 2 + include/ck_tile/ops/layernorm2d.hpp | 2 + include/ck_tile/ops/norm_reduce.hpp | 2 + include/ck_tile/ops/permute.hpp | 2 + include/ck_tile/ops/reduce.hpp | 2 + include/ck_tile/ops/rmsnorm2d.hpp | 2 + include/ck_tile/ops/smoothquant.hpp | 2 + include/ck_tile/ops/softmax.hpp | 2 + include/ck_tile/ops/topk.hpp | 2 + include/ck_tile/ops/topk_softmax.hpp | 2 + tile_engine/ops/gemm/gemm_common.hpp | 52 ------------------- tile_engine/ops/gemm/gemm_profiler.hpp | 2 +- 25 files changed, 51 insertions(+), 61 deletions(-) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 
86110d57ec..d815b1db40 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -16,6 +16,7 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/joinable_thread.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/host/ranges.hpp" #include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_dropout_randval.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 1768c802d5..6c0972e10a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -9,5 +9,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index ca0088c812..5822d7b91b 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -12,5 +12,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 7c6adc3ec2..eff2d625b3 100644 --- a/include/ck_tile/ops/common.hpp +++ 
b/include/ck_tile/ops/common.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" -#include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index f8432b9da0..fb7a05044f 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core/config.hpp" -#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 4858245ec4..7f2303932e 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 6cc0fa8540..ec5a8ef445 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include 
"ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 1714789e63..41463e6a2d 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -14,5 +14,7 @@ #include "ck_tile/ops/flatmm/pipeline/flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 31de21a726..6b25c089bd 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -60,5 +60,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index ddb64a2189..71721f3408 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -16,5 +16,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 5edde31cd9..204d67a0ff 100644 --- 
a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -30,18 +30,18 @@ #include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" -#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" @@ -72,5 +72,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp 
b/include/ck_tile/ops/gemm_quant.hpp index cde0b6833f..61cb96c8f4 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -3,8 +3,8 @@ #pragma once -#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" #include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" #include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" #include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" @@ -15,11 +15,13 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 09b50f26b0..1dd13b6246 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,5 +12,7 @@ 
#include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 93664ea138..2307b05190 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -7,5 +7,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index afbb817db1..9ce22137bf 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 7dc3e8b7e7..aa074b7f9f 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -7,5 +7,7 @@ #include 
"ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 1cc3d9cbc3..46512c57fe 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index a6721c9305..d628e9c945 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index 610541b2e4..00afcf4aed 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -11,5 +11,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include 
"ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index dc164dc1a0..1aa14c69e1 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -10,5 +10,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index b23e869d81..d559dc15e2 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 1dc563f757..040c6b8ddc 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -6,5 +6,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git 
a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index d0a810de4f..d9657a9764 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -8,5 +8,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/tile_engine/ops/gemm/gemm_common.hpp b/tile_engine/ops/gemm/gemm_common.hpp index 5188915f1a..179aeb7307 100644 --- a/tile_engine/ops/gemm/gemm_common.hpp +++ b/tile_engine/ops/gemm/gemm_common.hpp @@ -74,58 +74,6 @@ constexpr auto is_row_major(Layout) return ck_tile::bool_constant>{}; } -// Permutation function for pk_int4_t -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - tensor(j + 6, i) = i4x2; - } - } - } -} - // Structure to hold kernel traits for dispatcher struct KernelTraits { 
diff --git a/tile_engine/ops/gemm/gemm_profiler.hpp b/tile_engine/ops/gemm/gemm_profiler.hpp index bbf0c92e67..1298c78d18 100644 --- a/tile_engine/ops/gemm/gemm_profiler.hpp +++ b/tile_engine/ops/gemm/gemm_profiler.hpp @@ -96,7 +96,7 @@ class GemmProfiler // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; // permute_tensor_b(b_k_n_dev); - permute_vectors_i4x4_b(b_k_n_dev); + ck_tile::permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else From 3021604213750fc5acb02dad50e60ea8b0176b91 Mon Sep 17 00:00:00 2001 From: aledudek Date: Mon, 13 Oct 2025 13:55:23 +0200 Subject: [PATCH 03/75] [CK_TILE] Batched Gemm Kernel IsSupported function checks (#2860) * Add valid check batched gemm part1 * [CK_TILE] Add batched gemm kernel IsSupported func checks * revert broken pre-commit hook changes * revert broken pre-commit hook changes v2 * Clarify error messages --- .../ops/gemm/kernel/batched_gemm_kernel.hpp | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 6f9d53467f..806a471397 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -161,8 +161,43 @@ struct BatchedGemmKernel } CK_TILE_HOST static auto - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + IsSupportedArgument(const typename BatchedGemmKernel::KernelArgs& kargs) -> bool { + if(kargs.batch_count < 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met: batch_count must be at least 1 !"); + } + return false; + } + if(kargs.batch_stride_A < 0 || kargs.batch_stride_A < kargs.M * kargs.K) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_A must be non-negative and at 
least K * M!"); + } + return false; + } + if(kargs.batch_stride_B < 0 || kargs.batch_stride_B < kargs.K * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_B must be non-negative and at least K * N!"); + } + return false; + } + if(kargs.batch_stride_E < 0 || kargs.batch_stride_E < kargs.M * kargs.N) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Conditions not met: batch_stride_E must be non-negative and at least M * N!"); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } From 634634f5c09a3b42f5f838a5af9c948602e246db Mon Sep 17 00:00:00 2001 From: aledudek Date: Mon, 13 Oct 2025 13:57:37 +0200 Subject: [PATCH 04/75] [CK_TILE] Blockwise GEMM pipeline v6 - port of v5 from old CK (#2955) * First checkpoint * Second checkpoint - hot loop scheduler * Third checkpoint - init main operator * Fourth checkpoint - main loop ready * Fifth checkpoint - main loop fix * Sixth checkpoint - ReadWritecompFunc * Seventh checkpoint - Tail finished * [CK_TILE] Blockwise gemm pipeline v5 complete * Working * Working fixes 2 * Rename v5 to v77 temporarily * Data type adjustment * Data type adjustment 2 * [CK_TILE] Blockwise Gemm pipeline v5 add tests * [CK_TILE] Fix calculation error * TEMP: check pipeline * Fix name to V6 * naming and documentation changes * WIP dump * Try fixing v1 * Failing tests v5 * Debugging * Changes v2 * F16 tests working great * Working BlockwiseGemmPipelineV5 as V6 * Cleanup and format * Merging changes part1 * [CK_TILE] Blockwise Gemm Pipeline Comp V5/V6 * Remove commented code * Fix gfx950 build issues * Fix file formatting * Review changes, more concat info, add bf16 bf8 tests * Fix formatting * Add bf16 and bf8 tests --------- Co-authored-by: Adam Osewski --- example/ck_tile/03_gemm/gemm_utils.hpp | 40 +- include/ck_tile/ops/gemm.hpp | 2 + .../gemm_pipeline_ag_bg_cr_comp_v6.hpp | 770 ++++++++++++++++++ 
...peline_ag_bg_cr_comp_v6_default_policy.hpp | 56 ++ test/ck_tile/gemm/CMakeLists.txt | 10 +- .../gemm/test_gemm_pipeline_compv6.cpp | 17 + .../gemm/test_gemm_pipeline_kernel_types.hpp | 23 + test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 18 +- 8 files changed, 924 insertions(+), 12 deletions(-) create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp create mode 100644 include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 07b925d0eb..a831a4f26c 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -16,8 +16,9 @@ #define CK_TILE_PIPELINE_MEMORY 2 #define CK_TILE_PIPELINE_COMPUTE_V4 3 #define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 +#define CK_TILE_PIPELINE_COMPUTE_V6 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 6 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 7 template constexpr ck_tile::index_t get_k_warp_tile() @@ -251,9 +252,29 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; - static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmConfigComputeV6 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + 
static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::index_t NumWaveGroups = 1; }; template @@ -484,6 +505,15 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; }; +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; +}; + template <> struct PipelineTypeTraits { diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 204d67a0ff..2a4f9d21e3 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -44,6 +44,8 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp new file mode 100644 index 0000000000..2ae9001098 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, 
Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV6 +{ + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + // If execution reaches here, it's an invalid tail_number because it wasn't handled above. 
+#if defined(__HIP_DEVICE_COMPILE__) + __builtin_unreachable(); +#else + throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are " + "supported in this pipeline context."); +#endif + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 +template +struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6 +{ + using Base = BaseGemmPipelineAgBgCrCompV6; + using BasePImpl = GemmPipelineAgBgCrImplBase; + + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(!std::is_same_v, "Not implemented"); + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using BlockGemm = remove_cvref_t())>; + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + template + static constexpr index_t GetVectorSizeA() + { + return Policy::template GetVectorSizeA(); + } + template + static constexpr index_t GetVectorSizeB() + { + return Policy::template GetVectorSizeB(); + } + static constexpr index_t GetVectorSizeC() { return 
Policy::template GetVectorSizeC(); } + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA(); + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AgBgCrCompV6", BlockSize, + concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()), + concat('x', kPadM, kPadN, kPadK), + concat('x', TailNum), + concat('_', KRepeat), + concat('_', DoubleSmemBuffer), + concat('_', Preshuffle), + concat('_', HasHotLoop)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Policy::template IsTransposeC(); + } + + template + struct PipelineImpl : public BasePImpl + { + }; + + template <> + struct PipelineImpl : public BasePImpl + { + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1); + constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2); + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0); + constexpr 
index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1); + + constexpr index_t A_LDS_Read_Width = KPerXDL; + constexpr index_t B_LDS_Read_Width = KPerXDL; + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 
8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num - + num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num); + constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, 
num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // stage 2 + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 3 + static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { + ignore 
= i; + if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + // TODO: Add Multi A/B support + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + static_assert(std::tuple_size>::value == 1, + "Multi A/B is not yet supported for this pipeline."); + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( + std::is_same_v> && + std::is_same_v>, + "Data Type conflict on A and B matrix input data type."); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? 
(KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]), + "B block window has incorrect lengths for defined BLayout!"); + + ////////////// LDS desc, window & register ///////////////// + using ALdsType = + remove_cvref_t; + using BLdsType = + remove_cvref_t; + auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem); + ALdsType& a_lds_block = ABLdsTensorViews.at(I0); + BLdsType& b_lds_block = ABLdsTensorViews.at(I1); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + using acopy_dram_type = + remove_cvref_t; + using bcopy_dram_type = + remove_cvref_t; + + using a_copy_lds_window_type = + remove_cvref_t; + using b_copy_lds_window_type = + remove_cvref_t; + + using a_lds_load_tile_distr_type = + remove_cvref_t; + using b_lds_load_tile_distr_type = + remove_cvref_t; + + auto&& aWindows = + BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + auto&& bWindows = + BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + acopy_dram_type& a_copy_dram_window = aWindows.at(I0); + a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1); + a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + bcopy_dram_type& b_copy_dram_window = bWindows.at(I0); + b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1); + b_lds_load_tile_distr_type& b_lds_gemm_window = 
bWindows.at(I2); + + // Block GEMM + auto block_gemm = BlockGemm(); + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile[Base::GlobalBufferNum]; + BBlockTile b_block_tile[Base::GlobalBufferNum]; + + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeABlockDistributionEncode())){}; + constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution( + BlockGemm::MakeBBlockDistributionEncode())){}; + + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + + ALdsTile a_lds_tile; + BLdsTile b_lds_tile; + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Global prefetch 1 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // Local 
prefill 1 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]); + } + + // Global prefetch 2 + a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + // Global prefetch 3 + a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + + // Local prefetch 1 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + if(HasHotLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 2 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution< + Problem>()); + 
transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, + a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution< + Problem>()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, + b_block_tile[vmem_buf_idx]); + } + + // Global prefetch 4 + a_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + b_block_tile[vmem_buf_idx] = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + } + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 2 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += Base::HotloopUnroll; + } while(i < (num_loop - Base::PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto vmem_buf_idx) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + // Local prefill 3 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto 
b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]); + BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); + } + else + { + BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]); + } + + block_sync_lds(); + } + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&]() { + static_for<0, KRepeat - 1, 1>{}([&]() { + __syncthreads(); + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + // Local prefetch 4 + BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v); + BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v); + + __syncthreads(); + }); + + block_gemm(c_block_tile, a_lds_tile, b_lds_tile); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + + return c_block_tile; + } + }; + + public: + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& 
b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](auto& e, const ADataType& a) { e = a; }, + b_dram_block_window_tmp, + [](auto& e, const BDataType& b) { e = b; }, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* __restrict__ p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp new file mode 100644 index 0000000000..6ac702d38b --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp @@ -0,0 +1,56 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" + +namespace ck_tile { +// Default policy for GemmPipelineAGmemBGmemCregComputeV6, except the block gemm method, it shares +// the same vector size implementation, SmemSize, Global memory tile distribution as the +// UniversalGemm Pipeline Policy. +// Default policy class should not be templated, put template on +// member functions instead. 
+struct GemmPipelineAgBgCrCompV6DefaultPolicy + : public UniversalGemmBasePolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmDispatcher; + + using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; + + return BlockGemmARegBRegCRegV1{}; + } +}; +} // namespace ck_tile diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 1ca7f4fc7d..24cc1bc5ab 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -24,12 +24,13 @@ endif() if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp) - target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) + + 
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") @@ -55,10 +56,13 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx90a|gfx11|gfx12") add_gtest_executable(test_ck_tile_gemm_pipeline_compv3 test_gemm_pipeline_compv3.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_compv4 test_gemm_pipeline_compv4.cpp) add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) + add_gtest_executable(test_ck_tile_gemm_pipeline_compv6 test_gemm_pipeline_compv6.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + target_compile_options(test_ck_tile_gemm_pipeline_compv6 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() if(GPU_TARGETS MATCHES "gfx95") diff --git a/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp new file mode 100644 index 0000000000..a72ff98055 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_compv6.cpp @@ -0,0 +1,17 @@ +#include "test_gemm_pipeline_kernel_types.hpp" +#include "test_gemm_pipeline_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileGemmPipelineCompV6 + : public TestCkTileGemmPipeline> +{ +}; + +#define TEST_SUITE_NAME TestCkTileGemmPipelineCompV6 + 
+TYPED_TEST_SUITE(TestCkTileGemmPipelineCompV6, KernelTypesCompV6); + +#include "test_gemm_pipeline_ut_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index bba106174c..aa1f610022 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -29,6 +29,7 @@ using Interwave = ck_tile::integral_constant; using CompV3 = ck_tile::integral_constant; using CompV4 = ck_tile::integral_constant; +using CompV6 = ck_tile::integral_constant; using CompAsync = ck_tile::integral_constant; using Persistent = std::true_type; @@ -130,6 +131,28 @@ using KernelTypesCompV4 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompV4> >; +using KernelTypesCompV6 = ::testing::Types< + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, 
Intrawave, CompV6>, + std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6>, + std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV6> +>; using KernelTypesCompAsync = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, I32, I32, I16, Intrawave, CompAsync>, diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 01bc3d7522..994510c060 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -38,6 +38,7 @@ enum struct GemmPipelineType Mem, CompV3, CompV4, + CompV6, CompAsync }; @@ -71,6 +72,15 @@ struct GemmPipelineTypeSelector static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; } }; +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6; + using pipeline = ck_tile::GemmPipelineAgBgCrCompV6; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV6"; } +}; + template struct GemmPipelineTypeSelector { @@ -120,11 +130,13 @@ class 
TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || + constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 || PipelineType == GemmPipelineType::CompAsync); + constexpr bool TransposeC = false; + static constexpr bool StructuredSparsity = false; + static constexpr bool NumWaveGroup = 1; // TODO: For now - but this should also be a test parameter - constexpr bool TransposeC = false; constexpr int kBlockPerCu = 1; constexpr ck_tile::index_t TileParitionerGroupNum = 8; @@ -140,8 +152,6 @@ class TestCkTileGemmPipeline : public ::testing::Test GemmSpatiallyLocalTilePartitioner; using Traits = ck_tile::TileGemmTraits; - static constexpr bool StructuredSparsity = false; - static constexpr bool NumWaveGroup = 1; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits Date: Mon, 13 Oct 2025 13:27:02 +0100 Subject: [PATCH 05/75] [CK_TILE] Non-K Major from old CK to CK-Tile (#2442) * Enable the adapted LDS B layout for Row-Major * fix formatting * Implement specialized col-major A LDS block descriptor * Fix formatting * Use VecLoadSize for AK1/BK1 * Fix some thread access pattern values * Use GetVectorSizeA for A * Fix formatting * Add extra condition to avoid division by zero * disable layout for wave32 * remove extra else * fix formatting * Fix formatting * Rename one remaining TileDistributionEncodingPattern2D * Use integer ceil division * revert remod.py changes * also revert utility.hpp * use getA/BTileAccessPattern everywhere * use integer_divide_ceil for AK0 too --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> Co-authored-by: Adam Osewski --- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 520 +++++++++++------- 1 file changed, 318 insertions(+), 202 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 4030783ecc..89e0346961 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -73,10 +73,14 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; using ADataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + constexpr auto DataTypeSize = sizeof(ADataType); if constexpr(is_a_load_tr) { @@ -90,47 +94,168 @@ struct UniversalGemmBasePolicy } else { - constexpr index_t KPack = GetSmemPackA(); + // Only use this ColumnMajor layout for Wave64 mode (gfx9) + constexpr auto Wave64 = get_warp_size() == 64; + if constexpr(Wave64 && + std::is_same_v) + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg + // offset for compiler. + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeA(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + // AK1 + constexpr auto AK1 = number{}; + constexpr auto AK0 = number{}; + // How the M dimension is split across threads + constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim + constexpr auto M1 = number{}; - constexpr auto DataTypeSize = sizeof(ADataType); - constexpr auto MLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + // Get the warp tile size + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto MPerXdl = number{}; - constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + // How many elements we can write by single thread to LDS, + // the transposed / shuffled tile dstr has size: + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = integer_divide_ceil(AK0, KThreadWrite); + constexpr auto KThreadRead = get_warp_size() / MPerXdl; + constexpr auto K0PerThreadRead = integer_divide_ceil(AK0, KThreadRead); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto LdsBanksWidth = 128; + constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth) + ? 1 + : LdsBanksWidth / (AK1 * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && + (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; - constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + // 1<=mpair<=n0 + constexpr auto mpair = + (AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth) + ? 
1 + : ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))); - constexpr auto a_lds_block_desc = transform_tensor_descriptor( - a_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + AK1), + AK1); - return a_lds_block_desc; + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(AK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(AK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + 
make_tuple(number{}, + number{}, + number{}, + number{}, + AK1)), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // A is in RowMajor + { + constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 + ? 1 + : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple( + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod(make_tuple( + number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } } } @@ -143,12 +268,12 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { + using BLayout = remove_cvref_t; using BDataType = remove_cvref_t; constexpr index_t NPerBlock = 
Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; -#if 1 if constexpr(is_b_load_tr) { // TODO: better lds descriptor for performance @@ -160,178 +285,169 @@ struct UniversalGemmBasePolicy return b_lds_block_desc_0; } else - // else if constexpr(std::is_same_v) { - constexpr index_t KPack = GetSmemPackB(); - constexpr auto BK0 = number{}; - constexpr auto DataTypeSize = sizeof(BDataType); - constexpr auto NLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + // Only use this RowMajor layout for Wave64 mode (gfx9) + constexpr auto Wave64 = get_warp_size() == 64; + if constexpr(Wave64 && std::is_same_v) + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = + tile_distribution_encoding_pattern_2d; + // BK1 + constexpr auto BK1 = number{}; + constexpr auto BK0 = number{}; - constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - BK0 * number{}, number{}, number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + // How threads access data on N dim + constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim + constexpr auto N1 = number{}; - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - BK0 * number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + // Get NPerXdl, the warp tile size + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - 
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + // How many elements we can write by single thread to LDS, + // the transposed / shuffled tile dstr has size: + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = integer_divide_ceil(BK0, KThreadWrite); + constexpr auto KThreadRead = get_warp_size() / NPerXdl; + constexpr auto K0PerThreadRead = integer_divide_ceil(BK0, KThreadRead); - constexpr auto b_lds_block_desc = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), - make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - return b_lds_block_desc; + // check if we exceed all 32banks width - (32x4B) + constexpr auto LdsBanksWidth = 128; + constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth) + ? 1 + : LdsBanksWidth / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + ((kfold * K0PerThreadWrite / K0PerThreadRead) > 1 && + (kfold * K0PerThreadWrite / K0PerThreadRead) < KThreadRead) + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = + (BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth) + ? 1 + : ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1), + BK1); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2, 3>{}, + sequence<4>{}, + sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform( + number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple( + sequence<1>{}, // 0: K0PerThreadWrite + sequence<2>{}, // 1: KThreadReadPerm + sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1 + sequence<4, 5>{}, // 4: kfold, 5: N0 / npair + sequence<6>{}, // 6: npair + sequence<7>{})); // 7: BK1 + + constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return b_lds_block_desc_nk; + } + else // B is 
Column Major + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 + ? 1 + : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(BK0 * number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } } -#else - else // B is Row Major - { - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = - tile_distribution_encoding_pattern_2d; - - constexpr auto BK0 = number{}; - constexpr auto BK1 = number{}; - // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N0 = TileEncodingPattern::X0; - constexpr auto N1 = NPerBlock / N0; - - using 
WarpTile = typename Problem::BlockGemmShape::WarpTile; - constexpr auto NPerXdl = number{}; - - // constexpr auto KThreadWrite = - // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto KThreadWrite = TileEncodingPattern::Y2; - constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerXdl; - constexpr auto K0PerThreadRead = BK0 / KThreadRead; - - constexpr auto kfold = - (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 - ? N0 - : 128 / (BK1 * NPerXdl * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - BK1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(BK1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - 
sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - // b_lds_block_desc_unmerged, - // make_tuple(make_merge_transform_v3_division_mod( - // make_tuple(number{}, - // number{}, - // number{}, - // number{})), - // make_merge_transform_v3_division_mod( - // make_tuple(number{}, number{}, number{})), - // make_pass_through_transform(BK1)), - // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), - // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - - constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - BK1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - // return b_lds_block_desc_bk0_n_bk1; - return b_lds_block_desc_kn; - - // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( - // make_tuple(BK0, number{}, number{}), - // make_tuple(number{}, number{}, number<1>{}), - // number{}, - // number<1>{}); - - // constexpr auto b_lds_block_desc = transform_tensor_descriptor( - // b_lds_block_desc_bk0_n_bk1, - // make_tuple(make_pass_through_transform(number{}), - // make_merge_transform_v3_division_mod(make_tuple(BK0, - // number{}))), - // make_tuple(sequence<1>{}, sequence<0, 2>{}), - // make_tuple(sequence<0>{}, sequence<1>{})); - - // return b_lds_block_desc; - } -#endif } /** From fc2a121c4446b4ca939e977563528019b30e6114 Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 13 Oct 2025 08:11:51 -0700 Subject: [PATCH 06/75] Enable GMock and improve gtest configuration (#2976) Our current 
cmake/gtest.cmake file does not enable gmock. Gmock is needed for matchers that are needed for more readable unit tests. This PR enables gmock and does a little cleanup in gtest.cmake: * Enable BUILD_GMOCK by default (was previously disabled) * Patch gtest-src/googlemock/CMakeLists.txt for broken include path. * Add configuration to gmock if the target is used. No other changes in this PR, but I've verified I can use gmock matchers correctly once I include these changes in other code. --- cmake/gtest.cmake | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/cmake/gtest.cmake b/cmake/gtest.cmake index 6587f4c4be..41e2fa2cc0 100644 --- a/cmake/gtest.cmake +++ b/cmake/gtest.cmake @@ -12,6 +12,17 @@ FetchContent_Declare( GIT_TAG f8d7d77c06936315286eb55f8de22cd23c188571 ) +FetchContent_Populate(GTest) + +# Patch googlemock/CMakeLists.txt to fix invalid include path +set(GMOCK_CMAKE "${gtest_SOURCE_DIR}/googlemock/CMakeLists.txt") +file(READ "${GMOCK_CMAKE}" GMOCK_CMAKE_CONTENT) +string(REPLACE [[gtest_SOURCE_DIR}/include]] + [[gtest_SOURCE_DIR}/googletest/include]] + GMOCK_CMAKE_CONTENT + "${GMOCK_CMAKE_CONTENT}") +file(WRITE "${GMOCK_CMAKE}" "${GMOCK_CMAKE_CONTENT}") + # Suppress ROCMChecks WARNING on GoogleTests set(ROCM_DISABLE_CHECKS FALSE) macro(rocm_check_toolchain_var var access value list_file) @@ -24,7 +35,7 @@ if(WIN32) set(gtest_force_shared_crt ON CACHE_INTERNAL "") endif() -set(BUILD_GMOCK OFF CACHE INTERNAL "") +set(BUILD_GMOCK ON CACHE INTERNAL "") set(INSTALL_GTEST OFF CACHE INTERNAL "") # Store the current value of BUILD_SHARED_LIBS @@ -32,15 +43,12 @@ set(__build_shared_libs ${BUILD_SHARED_LIBS}) set(BUILD_SHARED_LIBS OFF CACHE INTERNAL "") set(ROCM_DISABLE_CHECKS TRUE) -FetchContent_MakeAvailable(GTest) +add_subdirectory(${gtest_SOURCE_DIR} ${gtest_BINARY_DIR}) set(ROCM_DISABLE_CHECKS FALSE) # Restore the old value of BUILD_SHARED_LIBS set(BUILD_SHARED_LIBS ${__build_shared_libs} CACHE BOOL "Type of libraries 
to build" FORCE) -set(BUILD_GMOCK OFF CACHE INTERNAL "") -set(INSTALL_GTEST OFF CACHE INTERNAL "") - set(GTEST_CXX_FLAGS -Wno-undef -Wno-reserved-identifier @@ -71,3 +79,12 @@ target_compile_options(gtest_main PRIVATE ${GTEST_CXX_FLAGS}) target_compile_definitions(gtest PRIVATE GTEST_HAS_SEH=0) target_compile_definitions(gtest_main PRIVATE GTEST_HAS_SEH=0) +if(TARGET gmock) + target_compile_options(gmock PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock PRIVATE GTEST_HAS_SEH=0) +endif() + +if(TARGET gmock_main) + target_compile_options(gmock_main PRIVATE ${GTEST_CXX_FLAGS}) + target_compile_definitions(gmock_main PRIVATE GTEST_HAS_SEH=0) +endif() From e1b0bdfbfa92f47006fdbced627c7470eacdea2b Mon Sep 17 00:00:00 2001 From: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com> Date: Tue, 14 Oct 2025 02:52:37 +0800 Subject: [PATCH 07/75] [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540) * [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang Co-authored-by: illsilin_amdeng Co-authored-by: ThomasNing --- example/ck_tile/02_layernorm2d/generate.py | 33 +++++ example/ck_tile/10_rmsnorm2d/generate.py | 39 +++++- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 47 +++++-- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 124 +++++++++++------- .../reference/reference_rmsnorm2d_fwd.hpp | 31 ++++- .../ops/reduce/block/block_reduce2d.hpp | 4 +- 
...rm2d_fwd_pipeline_model_sensitive_pass.hpp | 6 +- 7 files changed, 217 insertions(+), 67 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index b7512b2999..5f589db8d0 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 0e948322a2..75d7abd0ad 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_ using YScaleDataType = ck_tile::remove_cvref_t; using UnquantYDataType = 
ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; @@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] - } + } } - + total_blob = list() for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive current_trait_dict = h_trait_dicts[model_sensitive_flag] for hs_key in current_trait_dict: - hs = current_trait_dict[hs_key] + hs = current_trait_dict[hs_key] current_n = hs_key for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') diff --git 
a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 6e2664e9ba..8518b5ddc7 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -70,16 +70,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); + if(n > 8192) + { + use_model_sensitive_rmsnorm = 0; + } + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; @@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const int N = acc_.mDesc.get_lengths()[1]; for(int n_ = 0; n_ < N; ++n_) { - o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + o_unquant_(m_, n_) = 
ck_tile::type_convert(acc_(m_, n_)); } dquant_functor(m_, o_, acc_); @@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - default_and_dquant_functor); + default_and_dquant_functor, + use_model_sensitive_rmsnorm); } else { @@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - dquant_functor); + dquant_functor, + use_model_sensitive_rmsnorm); } } else @@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser) YDataType, InvRmsDataType, ck_tile::null_type>( - x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); + x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_null, + epsilon, + ck_tile::reference_rmsnorm2d_default_epilogue{}, + use_model_sensitive_rmsnorm); } y_buf.FromDevice(y_host_dev.data()); @@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } + if constexpr(SaveUnquant) + { + unquant_y_buf.FromDevice(unquant_y_host_dev.data()); + } + auto [rtol, atol] = get_elimit(); if(x_stride == n) { diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 1c79dafadd..3a0f7dbb66 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,49 +1,85 @@ -#!/bin/sh +#!/bin/bash + EXE="$(find . 
-name tile_rmsnorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\ - "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 -done -done -done +total=0 +valid=0 + +run_case() { + cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 
$4 -m=$5 -n=$6 $7" + echo "[CMD] $cmd" + output=$($cmd 2>&1) + echo "$output" + if echo "$output" | grep -q "valid:y"; then + valid=$((valid + 1)) + fi + total=$((total + 1)) +} + +fquant_list=( + "" + "-fquant=1 -prec_o=int8" + "-fquant=2 -prec_o=int8" + "-fquant=1 -prec_o=fp8" + "-fquant=2 -prec_o=fp8" + "-fquant=1 -prec_o=int8 -save_unquant=1" + "-fquant=2 -prec_o=int8 -save_unquant=1" + "-fquant=1 -prec_o=fp8 -save_unquant=1" + "-fquant=2 -prec_o=fp8 -save_unquant=1" +) + +m_n_list=( + "99 13" "17 16" "1 100" "4 128" "80 127" + "7 599" "19 512" "11 510" "91 636" + "31 1024" "8 1501" "3 1826" "5 2040" + "7 2734" "1 3182" "9 4096" "3 8192" +) + +### Add special stride test ### +m_n_stride_list=( + "22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256" + "33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000" + "171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818" + "12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800" + "100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812" + "64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004" +) + +for fquant in "${fquant_list[@]}"; do + for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + for pair in "${m_n_list[@]}"; do + m=$(echo $pair | cut -d ' ' -f1) + n=$(echo $pair | cut -d ' ' -f2) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "" + done + + ### Running tests with stride ### + for triple in "${m_n_stride_list[@]}"; do + m=$(echo $triple | cut -d ' ' -f1) + n=$(echo $triple | cut -d ' ' -f2) + stride_args=$(echo $triple | cut -d ' ' -f3-) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args" + done + done + done + done done -# The following cases uses two pass pipeline which doesn't support quant epilogue. 
-for fquant in "" -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 -done -done -done +# Special two-pass only +for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + run_case "$pr_i" "$fadd" "$s" "" "1" "10547" "" + done + done done + +# Summary +echo "==============================" +echo "Total cases: $total" +echo "Valid cases: $valid" +accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}") +echo "Accuracy: $accuracy%" +echo "==============================" diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 070168b51d..424fff4470 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" namespace ck_tile { @@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor& invRms_m, HostTensor& unquant_y_m_n, ComputeDataType epsilon, - Epilogue epilogue_functor = {}) + Epilogue epilogue_functor = {}, + const int use_model_sensitive_rmsnorm = + static_cast(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) { auto rmsnorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; @@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); - acc(m, n) = x * divisor * gamma; + if(use_model_sensitive_rmsnorm == + static_cast( + Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model + { + acc(m, n) = x * divisor * gamma; + } + else 
if(use_model_sensitive_rmsnorm == + static_cast(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model + { + if constexpr(std::is_same_v) + { + const auto tmp0 = float_to_bf16(x * divisor); + const auto tmp1 = float_to_bf16( + type_convert(tmp0) * gamma); + const auto rmsn_ = type_convert(tmp1); + acc(m, n) = rmsn_; + } + else + { + const auto tmp = type_convert(x * divisor); + const auto rmsn_ = type_convert(tmp) * gamma; + acc(m, n) = rmsn_; + } + } } if constexpr(!std::is_same_v) @@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( std::thread::hardware_concurrency()); } + } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index b72657b785..b97a66a3ec 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync block_sync_lds(); // We let each warp holds a duplication to do reduction. 
+ const index_t local_warp_id = warp_id / num_reduce_warps; + const index_t local_smem_os = local_warp_id * num_reduce_warps; static_for<0, thread_buf_size, 1>{}([&](auto i) { DataType v = 0; if(lane_id < num_reduce_warps) { - v = smem_ptr[lane_id + i * num_warps]; + v = smem_ptr[i * num_warps + local_smem_os + lane_id]; } // cross-lane reduce for replication diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index c5923ba10d..1d5467b459 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass // compute mean square each-thread->cross-lane->cross-warp auto square_sum = block_reduce2d.template MakeYBlockTile(); set_tile(square_sum, 0); - if constexpr(Problem::BlockShape::Vector_N % 2 == 0) + if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0) { sweep_tile( acc, @@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass const auto gamma_ = type_convert(gamma[j_idx]); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { const auto tmp0 = float_to_bf16(acc[idx] * inv_rms_[i_idx]); @@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass } else { - const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); + const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); const auto rmsn_ = type_convert(tmp) * gamma_; rmsn(idx) = rmsn_; } From 589e242eda730958b36c4f78bfad1991c499b0d2 Mon Sep 17 00:00:00 2001 From: msaffari-amd Date: Tue, 14 Oct 2025 13:20:25 +0200 Subject: [PATCH 08/75] Fix: Handle JSON boolean values (pad_m, pad_n, pad_k and persistent) in gemm_instance_builder (#3008) --- tile_engine/ops/gemm/gemm_instance_builder.py | 10 +++++----- 1 file changed, 5 
insertions(+), 5 deletions(-) diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index c2214da613..0dc9fffedb 100644 --- a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -450,11 +450,11 @@ struct SelectedKernel {{ static constexpr ck_tile::index_t WarpTileK = {tile_config["warp_tile_k"]}; // Traits - static constexpr bool kPadM = {"true" if pad_m == "true" else "false"}; - static constexpr bool kPadN = {"true" if pad_n == "true" else "false"}; - static constexpr bool kPadK = {"true" if pad_k == "true" else "false"}; + static constexpr bool kPadM = {"true" if pad_m in [True, "true"] else "false"}; + static constexpr bool kPadN = {"true" if pad_n in [True, "true"] else "false"}; + static constexpr bool kPadK = {"true" if pad_k in [True, "true"] else "false"}; static constexpr bool TransposeC = false; - static constexpr bool UsePersistentKernel = {"true" if persistent == "true" else "false"}; + static constexpr bool UsePersistentKernel = {"true" if persistent in [True, "true"] else "false"}; static constexpr bool DoubleSmemBuffer = {"true" if pipeline == "compv4" else "false"}; static constexpr bool UseStructuredSparsity = false; static constexpr bool Preshuffle = false; @@ -576,7 +576,7 @@ struct SelectedKernel {{ }} // Get grid and block sizes - const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent == "true" else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; + const dim3 grids = {"GemmKernel::MaxOccupancyGridSize(stream)" if persistent in [True, "true"] else "GemmKernel::GridSize(args.M, args.N, args.k_batch)"}; const dim3 blocks = GemmKernel::BlockSize(); if(stream.log_level_ > 0) {{ From 6deaaa92cc561f5bc29d956d6f6de903db19a079 Mon Sep 17 00:00:00 2001 From: jakpiase Date: Tue, 14 Oct 2025 16:09:16 +0200 Subject: [PATCH 09/75] [CK_TILE] Switch into universal gemms for conv bwds (#2981) * switch into universal gemms for conv 
bwds * some fixes and support universal gemm in conv fwd * add reviewer comments --- .../20_grouped_convolution/gemm_configs.hpp | 303 ++++++++++++++++++ .../grouped_convolution_backward_data.cpp | 10 +- ...uped_convolution_backward_data_invoker.hpp | 239 ++++++++------ .../grouped_convolution_backward_weight.cpp | 10 +- ...ed_convolution_backward_weight_invoker.hpp | 182 +++++++---- ..._convolution_backward_weight_two_stage.cpp | 11 +- ...tion_backward_weight_two_stage_invoker.hpp | 180 +++++++---- .../grouped_convolution_forward.cpp | 14 +- .../grouped_convolution_forward_invoker.hpp | 228 ++++++++----- .../grouped_convolution_utils.hpp | 6 +- ...n_grouped_convolution_bwd_data_example.inc | 16 +- ...grouped_convolution_bwd_weight_example.inc | 16 +- .../run_grouped_convolution_fwd_example.inc | 16 +- ...ouped_convolution_backward_data_kernel.hpp | 91 +++--- ...ped_convolution_backward_weight_kernel.hpp | 144 ++++----- .../grouped_convolution_forward_kernel.hpp | 84 ++--- .../utils/grouped_convolution_utils.hpp | 10 +- .../utils/transform_conv_bwd_data_to_gemm.hpp | 14 +- .../transform_conv_bwd_weight_to_gemm.hpp | 19 +- 19 files changed, 1043 insertions(+), 550 deletions(-) create mode 100644 example/ck_tile/20_grouped_convolution/gemm_configs.hpp diff --git a/example/ck_tile/20_grouped_convolution/gemm_configs.hpp b/example/ck_tile/20_grouped_convolution/gemm_configs.hpp new file mode 100644 index 0000000000..37a63cd65c --- /dev/null +++ b/example/ck_tile/20_grouped_convolution/gemm_configs.hpp @@ -0,0 +1,303 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/utility/json_dump.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 + +struct GemmConfigBase +{ + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; + static constexpr bool TiledMMAPermuteN = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template 
+struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr 
ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV3_WMMA : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static 
constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; +}; + +template +struct ConvTypeConfig; + +template <> +struct ConvTypeConfig +{ + using InDataType = ck_tile::half_t; + using WeiDataType = ck_tile::half_t; + using AccDataType = float; + using OutDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. 
+}; + +template <> +struct ConvTypeConfig +{ + using InDataType = ck_tile::bf16_t; + using WeiDataType = ck_tile::bf16_t; + using AccDataType = float; + using OutDataType = ck_tile::bf16_t; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp index fa914a7119..6f3bedc32a 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data.cpp @@ -14,7 +14,7 @@ #include "grouped_convolution_backward_data_invoker.hpp" #include "run_grouped_convolution_bwd_data_example.inc" -template +template