// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <tuple>
#include <gtest/gtest.h>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/core/numeric/math.hpp"

template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
                         const ck_tile::index_t kbatch,
                         const float max_accumulated_value)
{
    // The narrower of the two input types bounds the compute precision.
    using ComputeType =
        std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
    // Calculate thresholds
    const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
        ck_tile::integer_divide_ceil(K, kbatch));
    const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
        max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
    // Calculate error due to split_k accumulation
    const auto rtol_split_k =
        ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
    const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
        max_accumulated_value, kbatch);
    // Use higher threshold
    return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}

enum struct GemmPipelineType
{
    Mem,
    CompV3,
    CompV4,
    CompV6,
    CompAsync
};

template <GemmPipelineType PType, typename PipelineProblem>
struct GemmPipelineTypeSelector;

template <typename PipelineProblem>
struct GemmPipelineTypeSelector<GemmPipelineType::Mem, PipelineProblem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
    static constexpr auto GetName() { return "GemmPipelineAgBgCrMem"; }
};

template <typename PipelineProblem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV3, PipelineProblem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV3"; }
};

template <typename PipelineProblem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV4, PipelineProblem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV4"; }
};

template <typename PipelineProblem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompV6, PipelineProblem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompV6<PipelineProblem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompV6"; }
};

template <typename PipelineProblem>
struct GemmPipelineTypeSelector<GemmPipelineType::CompAsync, PipelineProblem>
{
    using base_pipeline = ck_tile::BaseGemmPipelineAgBgCrCompAsync<PipelineProblem>;
    using pipeline      = ck_tile::GemmPipelineAgBgCrCompAsync<PipelineProblem>;
    static constexpr auto GetName() { return "GemmPipelineAgBgCrCompAsync"; }
};

template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test
{
    protected:
    using ALayout     = std::tuple_element_t<0, Tuple>;
    using BLayout     = std::tuple_element_t<1, Tuple>;
    using CLayout     = std::tuple_element_t<2, Tuple>;
    using ADataType   = std::tuple_element_t<3, Tuple>;
    using BDataType   = std::tuple_element_t<4, Tuple>;
    using AccDataType = std::tuple_element_t<5, Tuple>;
    using CDataType   = std::tuple_element_t<6, Tuple>;

    static constexpr auto Scheduler    = std::tuple_element_t<13, Tuple>::value;
    static constexpr auto PipelineType = std::tuple_element_t<14, Tuple>::value;

    static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>{};
    static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>{};
    static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>{};

    static constexpr ck_tile::index_t M_Warp_Tile = std::tuple_element_t<10, Tuple>{};
    static constexpr ck_tile::index_t N_Warp_Tile = std::tuple_element_t<11, Tuple>{};
    static constexpr ck_tile::index_t K_Warp_Tile = std::tuple_element_t<12, Tuple>{};

    using DsLayout   = ck_tile::tuple<>;
    using DsDataType = ck_tile::tuple<>;

    // Optional 16th tuple element; defaults to a non-persistent kernel when absent.
    static constexpr bool Persistent =
        ck_tile::tuple_element_or_default_t<15, Tuple, ck_tile::bool_constant<false>>::value;
    template <bool PadM, bool PadN, bool PadK, bool Preshuffle = false>
    void invoke_gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s)
    {
        constexpr ck_tile::index_t M_Warp = 2;
        constexpr ck_tile::index_t N_Warp = 2;
        constexpr ck_tile::index_t K_Warp = 1;

        constexpr bool kPadM = PadM;
        constexpr bool kPadN = PadN;
        constexpr bool kPadK = PadK;

        constexpr bool preshuffle = Preshuffle;

        constexpr bool DoubleSmemBuffer = (PipelineType == GemmPipelineType::CompV4 ||
                                           PipelineType == GemmPipelineType::CompAsync);
        constexpr bool TransposeC = false;

        static constexpr bool StructuredSparsity       = false;
        static constexpr ck_tile::index_t NumWaveGroup = 1;

        // TODO: For now - but this should also be a test parameter
        constexpr int kBlockPerCu = 1;

        constexpr ck_tile::index_t TilePartitionerGroupNum = 8;
        constexpr ck_tile::index_t TilePartitionerM01      = 4;

        // ===============================================

        using GemmShape =
            ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
                                   ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
                                   ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
        using TilePartitioner =
            ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
                                                       TilePartitionerGroupNum,
                                                       TilePartitionerM01>;

        using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
        using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
                                                                     kPadN,
                                                                     kPadK,
                                                                     DoubleSmemBuffer,
                                                                     ALayout,
                                                                     BLayout,
                                                                     CLayout,
                                                                     TransposeC,
                                                                     StructuredSparsity,
                                                                     Persistent,
                                                                     NumWaveGroup,
                                                                     preshuffle>;
        using GemmPipelineProblem =
            ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

        using BaseGemmPipeline =
            typename GemmPipelineTypeSelector<PipelineType, GemmPipelineProblem>::base_pipeline;

        // Per-split K, rounded up to a multiple of K_Tile
        // (e.g. K=1000, K_Tile=32, k_batch=2 -> k_grain=64, K_split=512).
        const ck_tile::index_t k_grain     = args.k_batch * K_Tile;
        const ck_tile::index_t K_split     = (args.K + k_grain - 1) / k_grain * K_Tile;
        const ck_tile::index_t num_loop    = TilePartitioner::GetLoopNum(K_split);
        const bool has_hot_loop            = BaseGemmPipeline::BlockHasHotloop(num_loop);
        const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);

        const auto Run = [&](const auto has_hot_loop_,
                             const auto tail_number_,
                             const auto memory_operation_) {
            constexpr bool has_hot_loop_v   = has_hot_loop_.value;
            constexpr auto tail_number_v    = tail_number_.value;
            constexpr auto memory_operation = memory_operation_.value;

            using UniversalGemmProblem =
                ck_tile::UniversalGemmPipelineProblem<ADataType,
                                                      BDataType,
                                                      AccDataType,
                                                      GemmShape,
                                                      GemmUniversalTraits,
                                                      Scheduler,
                                                      has_hot_loop_v,
                                                      tail_number_v>;

            using GemmPipeline =
                typename GemmPipelineTypeSelector<PipelineType, UniversalGemmProblem>::pipeline;
            using GemmEpilogue = ck_tile::CShuffleEpilogue<
                ck_tile::CShuffleEpilogueProblem<ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 AccDataType,
                                                 CDataType,
                                                 DsLayout,
                                                 CLayout,
                                                 ck_tile::element_wise::PassThrough,
                                                 TilePartitioner::MPerBlock,
                                                 TilePartitioner::NPerBlock,
                                                 M_Warp,
                                                 N_Warp,
                                                 M_Warp_Tile,
                                                 N_Warp_Tile,
                                                 K_Warp_Tile,
                                                 TransposeC,
                                                 memory_operation>>;
            using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
            auto kargs   = Kernel::MakeKernelArgs(args);

            dim3 grids;
            if constexpr(Persistent)
            {
                grids = Kernel::MaxOccupancyGridSize(s);
            }
            else
            {
                grids = Kernel::GridSize(args.M, args.N, args.k_batch);
            }
            dim3 blocks = Kernel::BlockSize();

            if(!Kernel::IsSupportedArgument(kargs))
            {
                throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
            }
            if(s.log_level_ > 0)
            {
                std::cout << "Launching kernel with args:"
                          << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
                          << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
                          << "}" << std::endl;
            }

            ck_tile::launch_kernel(
                s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
        };

        const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
            if(args.k_batch == 1)
            {
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::set>{});
            }
            else
            {
                // split-k partial results are accumulated into C atomically
                Run(has_hot_loop_,
                    tail_number_,
                    ck_tile::integral_constant<ck_tile::memory_operation_enum,
                                               ck_tile::memory_operation_enum::atomic_add>{});
            }
        };

        BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
    }

    // Hook for derived fixtures: a derived test can shadow check_data_type_impl()
    // to reject unsupported data-type combinations; the default accepts all.
    template <typename Derived = TestCkTileGemmPipeline>
    bool check_data_type()
    {
        return static_cast<Derived*>(this)
            ->template check_data_type_impl<ADataType, BDataType, CDataType>();
    }

    template <typename A, typename B, typename C>
    bool check_data_type_impl()
    {
        return true;
    }

    public:
    std::vector<int> k_batches_;

    void SetUp() override
    {
        if(!check_data_type())
        {
            GTEST_SKIP() << "Unsupported data type combination for gemm pipeline test.";
        }
        if constexpr(PipelineType == GemmPipelineType::CompV4)
        {
            // Only do k_batch = 1 when pipeline is CompV4
            k_batches_ = {1};
        }
        else
        {
            // Otherwise, use k_batch = 1 and 2
            k_batches_ = {1, 2};
        }
    }

    template <bool PadM = false, bool PadN = false, bool PadK = false>
    void Run(const int M,
             const int N,
             const int K,
             const int StrideA = 0,
             const int StrideB = 0,
             const int StrideC = 0)
    {
        for(auto kb : k_batches_)
        {
            RunSingle<PadM, PadN, PadK>(M, N, K, StrideA, StrideB, StrideC, kb);
        }
    }

    template <bool PadM = false, bool PadN = false, bool PadK = false>
    void RunSingle(const int M,
                   const int N,
                   const int K,
                   const int StrideA,
                   const int StrideB,
                   const int StrideC,
                   int kbatch = 1)
    {
        using namespace ck_tile::literals;

        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if constexpr(std::is_same_v<decltype(layout),
                                            ck_tile::tensor_layout::gemm::RowMajor>)
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
                }
                else
                {
                    return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
                }
            };
        auto f_get_default_stride =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(stride == 0)
                {
                    // give a chance if stride is zero, return a default packed stride
                    if constexpr(std::is_same_v<decltype(layout),
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        return col;
                    }
                    else
                    {
                        return row;
                    }
                }
                else
                    return stride;
            };

        ck_tile::index_t stride_A = f_get_default_stride(M, K, StrideA, ALayout{});
        ck_tile::index_t stride_B = f_get_default_stride(K, N, StrideB, BLayout{});
        ck_tile::index_t stride_C = f_get_default_stride(M, N, StrideC, CLayout{});

        ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
        ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
        ck_tile::HostTensor<CDataType> c_m_n_dev_result(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));

        ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, 11939}(a_m_k);
        ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, 11940}(b_k_n);

        ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
        ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
        ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());

        a_m_k_dev_buf.ToDevice(a_m_k.data());
        b_k_n_dev_buf.ToDevice(b_k_n.data());
        c_m_n_dev_buf.SetZero();
        c_m_n_dev_result.SetZero();

        ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(),
                                      b_k_n_dev_buf.GetDeviceBuffer(),
                                      c_m_n_dev_buf.GetDeviceBuffer(),
                                      kbatch,
                                      M,
                                      N,
                                      K,
                                      stride_A,
                                      stride_B,
                                      stride_C};

        invoke_gemm<PadM, PadN, PadK>(args, ck_tile::stream_config{nullptr, false});

        c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
        bool pass = true;

        ck_tile::HostTensor<CDataType> c_m_n_host_ref(
            f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
        c_m_n_host_ref.SetZero();

        ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
            a_m_k, b_k_n, c_m_n_host_ref);
        const float max_accumulated_value =
            *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
        const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
            K, kbatch, max_accumulated_value);
        pass = ck_tile::check_err(c_m_n_dev_result,
                                  c_m_n_host_ref,
                                  "Error: Incorrect results!",
                                  rtol_atol.at(ck_tile::number<0>{}),
                                  rtol_atol.at(ck_tile::number<1>{}));
        EXPECT_TRUE(pass);
    }
};
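// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, kept in a comment so this header defines no
// tests itself): a test translation unit instantiates the fixture with a type
// tuple whose element order matches the accessors above -- <0..2> A/B/C
// layouts, <3..6> A/B/Acc/C data types, <7..9> block tile, <10..12> warp tile,
// <13> scheduler, <14> pipeline type, optional <15> persistent flag. The
// concrete tile sizes and the `I<v>` shorthand below are assumptions made for
// this example, not definitions provided by this header.
//
//   using Row = ck_tile::tensor_layout::gemm::RowMajor;
//   using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
//   template <ck_tile::index_t v>
//   using I = ck_tile::integral_constant<ck_tile::index_t, v>;
//
//   using KernelTypes = ::testing::Types<
//       std::tuple<Row, Col, Row,                          // A/B/C layouts
//                  ck_tile::half_t, ck_tile::half_t,       // A/B data types
//                  float, ck_tile::half_t,                 // Acc/C data types
//                  I<128>, I<128>, I<32>,                  // M/N/K block tile
//                  I<32>, I<32>, I<8>,                     // M/N/K warp tile
//                  ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
//                                             ck_tile::GemmPipelineScheduler::Intrawave>,
//                  ck_tile::integral_constant<GemmPipelineType,
//                                             GemmPipelineType::CompV3>>>;
//
//   TYPED_TEST_SUITE(TestCkTileGemmPipeline, KernelTypes);
//
//   TYPED_TEST(TestCkTileGemmPipeline, SmallM) { this->Run(128, 256, 512); }
// ---------------------------------------------------------------------------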