diff --git a/example/ck_tile/03_gemm/CMakeLists.txt b/example/ck_tile/03_gemm/CMakeLists.txt
index b1aede42c7..825cd6e522 100644
--- a/example/ck_tile/03_gemm/CMakeLists.txt
+++ b/example/ck_tile/03_gemm/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
 add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
 add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
+add_executable(tile_example_gemm_reduce EXCLUDE_FROM_ALL gemm_splitk_two_stage_reduce.cpp)
 set(EXAMPLE_GEMM_COMPILE_OPTIONS)
 set(EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS)
 if(CK_USE_OCP_FP8)
@@ -14,3 +15,4 @@ list(APPEND EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS "SHELL: -mllvm -greedy-rev
 target_compile_options(tile_example_gemm_basic PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 target_compile_options(tile_example_gemm_universal PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
 target_compile_options(tile_example_gemm_weight_preshuffle PRIVATE ${EXAMPLE_WEIGHT_PRESHUFFLE_COMPILE_OPTIONS})
+target_compile_options(tile_example_gemm_reduce PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
new file mode 100644
index 0000000000..a4a8039288
--- /dev/null
+++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp
@@ -0,0 +1,1009 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+
+#include <hip/hip_runtime.h>
+
+#include <cstring>
+#include <iostream>
+#include <string>
+#include <tuple>
+
+#include "ck_tile/host.hpp"
+#include "ck_tile/ops/reduce.hpp"
+#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
+#include "gemm_utils.hpp"
+#include "run_gemm_example.inc"
+
+/**
+ * @brief Tile partitioner with output offset support.
+ *
+ * This partitioner extends the spatially local tile partitioner to support
+ * split-K reduction by providing workspace output offset calculation. Each K-split
+ * writes to a separate slice of the workspace: workspace[k_id * M * N].
+ */
+template <typename BlockGemmShape_, ck_tile::index_t GroupNum, ck_tile::index_t M01>
+struct GemmSplitKTilePartitioner
+    : public ck_tile::GemmSpatiallyLocalTilePartitioner<BlockGemmShape_, GroupNum, M01>
+{
+    using Base = ck_tile::GemmSpatiallyLocalTilePartitioner<BlockGemmShape_, GroupNum, M01>;
+
+    // Inherit constructors and methods
+    using Base::Base;
+    using Base::GetLoopNum;
+
+    /**
+     * @brief Calculate the output pointer offset for split-K reduction.
+     *
+     * @param kargs Kernel arguments.
+     * @param k_id Current K-split ID (from blockIdx.z or the calculated k_batch).
+     * @return ck_tile::index_t The offset for this K-split.
+     */
+    template <typename KernelArgs>
+    CK_TILE_HOST_DEVICE static ck_tile::index_t GetOutputOffset(const KernelArgs& kargs,
+                                                                ck_tile::index_t k_id) noexcept
+    {
+        // Each K-split gets its own M*N workspace slice
+        return (kargs.k_batch > 1) ? (k_id * kargs.M * kargs.N) : 0;
+    }
+};
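+
+// For illustration only (hypothetical sizes, not kernel defaults): with
+// M = N = 4096 and k_batch = 4, GetOutputOffset(kargs, k_id) returns
+// k_id * 4096 * 4096 = k_id * 16777216, so the four K-splits store their
+// partial C tiles into disjoint workspace slices and stage 1 needs no atomics.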
+
+/**
+ * @brief Extended GEMM host arguments for the two-stage split-K implementation.
+ *
+ * This structure supports the two-stage split-K approach where:
+ * 1. Stage 1: GEMM writes partial results to workspace memory.
+ * 2. Stage 2: A reduction kernel sums the workspace results into the final output.
+ *
+ * The base class e_ptr points to the workspace, while final_output_ptr points
+ * to the actual output.
+ */
+struct GemmSplitKHostArgs : public ck_tile::GemmHostArgs
+{
+    using BaseArgs = ck_tile::GemmHostArgs;
+
+    CK_TILE_HOST GemmSplitKHostArgs() = default;
+    CK_TILE_HOST GemmSplitKHostArgs(const void* a_ptr_,
+                                    const void* b_ptr_,
+                                    void* workspace_ptr_, // Workspace for partial results
+                                    void* e_ptr_,         // Final output destination
+                                    ck_tile::index_t k_batch_,
+                                    ck_tile::index_t M_,
+                                    ck_tile::index_t N_,
+                                    ck_tile::index_t K_,
+                                    ck_tile::index_t stride_A_,
+                                    ck_tile::index_t stride_B_,
+                                    ck_tile::index_t workspace_stride_,
+                                    ck_tile::index_t stride_E_)
+        : BaseArgs(a_ptr_,
+                   b_ptr_,
+                   workspace_ptr_, // Base e_ptr = workspace_ptr
+                   k_batch_,
+                   M_,
+                   N_,
+                   K_,
+                   stride_A_,
+                   stride_B_,
+                   workspace_stride_),
+          final_output_ptr(e_ptr_),
+          final_stride_E(stride_E_)
+    {
+    }
+
+    void* final_output_ptr;          // Pointer to the final output tensor
+    ck_tile::index_t final_stride_E; // Stride of the final output tensor
+};
+
+/**
+ * @brief Stage 1: GEMM kernel that writes partial split-K results to the workspace.
+ *
+ * This function performs the matrix multiplication with split-K, where each
+ * K-split writes its partial result to a separate section of the workspace.
+ *
+ * Workspace layout: [k_batch, M, N], where each [M, N] slice contains the
+ * partial results of one K-split.
+ *
+ * @param args Extended arguments containing workspace and final output pointers.
+ * @param s Stream configuration for kernel execution.
+ * @return Execution time in milliseconds.
+ */
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename DsLayout,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout,
+          bool Persistent,
+          typename CDEElementWise = ck_tile::element_wise::PassThrough>
+float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
+{
+    using GemmShape = ck_tile::TileGemmShape<
+        ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
+        ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
+        ck_tile::sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
+        GemmConfig::PermuteA,
+        GemmConfig::PermuteB>;
+
+    using TilePartitioner = GemmSplitKTilePartitioner<GemmShape,
+                                                      GemmConfig::TileParitionerGroupNum,
+                                                      GemmConfig::TileParitionerM01>;
+
+    using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
+                                           GemmConfig::kPadN,
+                                           GemmConfig::kPadK,
+                                           ALayout,
+                                           BLayout,
+                                           CLayout>;
+
+    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
+                                                                 GemmConfig::kPadN,
+                                                                 GemmConfig::kPadK,
+                                                                 GemmConfig::DoubleSmemBuffer,
+                                                                 ALayout,
+                                                                 BLayout,
+                                                                 CLayout,
+                                                                 GemmConfig::TransposeC,
+                                                                 GemmConfig::UseStructuredSparsity,
+                                                                 Persistent,
+                                                                 GemmConfig::NumWaveGroups,
+                                                                 GemmConfig::Preshuffle>;
+
+    using GemmPipelineProblem =
+        ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
+
+    using BaseGemmPipeline = typename PipelineTypeTraits<
+        GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
+
+    const ck_tile::index_t k_grain   = args.k_batch * GemmConfig::K_Tile;
+    const ck_tile::index_t K_split   = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
+    const ck_tile::index_t num_loop  = TilePartitioner::GetLoopNum(K_split);
+    const bool has_hot_loop          = BaseGemmPipeline::BlockHasHotloop(num_loop);
+    const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
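+
+    // Worked example of the K padding above (illustrative numbers only):
+    // K = 1000, k_batch = 4, K_Tile = 64  =>  k_grain = 256 and
+    // K_split = ceil(1000 / 256) * 64 = 4 * 64 = 256, i.e. each split loops
+    // over 256 K elements (num_loop = 4 tiles) and the 4 splits together
+    // cover 1024 >= 1000 of padded K.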
+    float ave_time{0};
+
+    // Create base GEMM arguments pointing to the workspace instead of the final
+    // output. The workspace stores the partial result of each K-split.
+    ck_tile::GemmHostArgs base_args(args.a_ptr,
+                                    args.b_ptr,
+                                    args.e_ptr,
+                                    args.k_batch,
+                                    args.M,
+                                    args.N,
+                                    args.K,
+                                    args.stride_A,
+                                    args.stride_B,
+                                    args.stride_E);
+
+    const auto Run = [&](const auto has_hot_loop_,
+                         const auto tail_number_,
+                         const auto memory_operation_) {
+        constexpr bool has_hot_loop_v   = has_hot_loop_.value;
+        constexpr auto tail_number_v    = tail_number_.value;
+        constexpr auto scheduler        = GemmConfig::Scheduler;
+        constexpr auto memory_operation = memory_operation_.value;
+
+        using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
+                                                                           BDataType,
+                                                                           AccDataType,
+                                                                           GemmShape,
+                                                                           GemmUniversalTraits,
+                                                                           scheduler,
+                                                                           has_hot_loop_v,
+                                                                           tail_number_v>;
+
+        using GemmPipeline = typename PipelineTypeTraits<
+            GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
+
+        using GemmEpilogue = ck_tile::CShuffleEpilogue<
+            ck_tile::CShuffleEpilogueProblem<ADataType,
+                                             BDataType,
+                                             DsDataType,
+                                             AccDataType,
+                                             CDataType,
+                                             DsLayout,
+                                             CLayout,
+                                             CDEElementWise,
+                                             TilePartitioner::MPerBlock,
+                                             TilePartitioner::NPerBlock,
+                                             GemmConfig::M_Warp,
+                                             GemmConfig::N_Warp,
+                                             GemmConfig::M_Warp_Tile,
+                                             GemmConfig::N_Warp_Tile,
+                                             GemmConfig::K_Warp_Tile,
+                                             UniversalGemmProblem::TransposeC,
+                                             memory_operation>>;
+
+        using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
+        auto kargs   = Kernel::MakeKernelArgs(base_args);
+
+        dim3 grids;
+        if constexpr(Persistent)
+        {
+            grids = Kernel::MaxOccupancyGridSize(s);
+        }
+        else
+        {
+            grids = Kernel::GridSize(args.M, args.N, args.k_batch);
+        }
+        constexpr dim3 blocks = Kernel::BlockSize();
+
+        if(!Kernel::IsSupportedArgument(kargs))
+        {
+            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+        }
+
+        if(s.log_level_ > 0)
+        {
+            std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n'
+                      << "shape: " << GemmShape::GetName() << '\n'
+                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
+                      << "pipeline: " << GemmPipeline::GetName() << '\n'
+                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                      << std::endl;
+        }
+
+        if(s.flush_cache_)
+        {
+            std::cout << "Flushing cache..." << std::endl;
+
+            ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
+                args.M, args.K, args.stride_A, is_row_major(ALayout{})));
+            ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
+                args.K, args.N, args.stride_B, is_row_major(BLayout{})));
+
+            auto size_a_buffer = a_m.get_element_space_size_in_bytes();
+            auto size_b_buffer = b_n.get_element_space_size_in_bytes();
+
+            ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
+                kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer);
+            rotating_mem.Print();
+
+            auto run_flush_cache = [&]() {
+                // flush icache
+                ck_tile::flush_icache();
+                // rotating mem
+                rotating_mem.Next();
+                // clear c mem
+                if(args.k_batch > 1)
+                    hipGetErrorString(hipMemsetAsync(
+                        args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
+            };
+            ave_time = ck_tile::launch_kernel_time_mask(
+                s,
+                run_flush_cache,
+                ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
+                    Kernel{}, grids, blocks, 0, kargs));
+        }
+        else
+        {
+            ave_time =
+                ck_tile::launch_kernel(s,
+                                       ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
+                                           Kernel{}, grids, blocks, 0, kargs));
+        }
+        return ave_time;
+    };
+
+    const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
+        // For workspace mode, always use SET since each K-split writes to separate memory
+        Run(has_hot_loop_,
+            tail_number_,
+            ck_tile::integral_constant<ck_tile::memory_operation_enum,
+                                       ck_tile::memory_operation_enum::set>{});
+    };
+
+    BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
+    return ave_time;
+}
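+
+// Note: a classic one-stage split-K lets every K-split accumulate directly
+// into C with atomics (ck_tile::memory_operation_enum::atomic_add). The
+// two-stage variant above instead issues plain stores
+// (ck_tile::memory_operation_enum::set) into disjoint workspace slices, which
+// is deterministic, at the cost of the explicit stage-2 reduction below.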
+
+/**
+ * @brief Stage 2: Reduction kernel that sums partial split-K results into the final output.
+ *
+ * This function reduces the partial results stored in workspace memory by stage 1,
+ * summing across the k_batch dimension to produce the final GEMM result.
+ *
+ * Workspace layout: [k_batch, M, N] -> Final output: [M, N]
+ *
+ * @tparam CDataType Output data type.
+ * @tparam ComputeDataType Computation precision for the reduction.
+ * @tparam ELayout Memory layout of the output tensor.
+ * @param args Extended arguments containing workspace and output information.
+ * @param s Stream configuration for kernel execution.
+ * @return Execution time in milliseconds.
+ */
+template <typename CDataType, typename ComputeDataType, typename ELayout>
+float reduce_stage2(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
+{
+    const ck_tile::index_t reduce_dim_size = args.k_batch; // Number of partial results to reduce
+    // Output size based on the final output tensor dimensions
+    const ck_tile::index_t output_size = args.M * args.N;
+
+    // Workspace layout: [k_batch, M, N] where each [M, N] slice has the same layout as the
+    // final output, so the workspace strides follow the final output tensor's layout.
+    auto workspace_shape = ck_tile::make_tuple(args.k_batch, args.M, args.N);
+    auto workspace_strides =
+        ck_tile::make_tuple(args.M * args.N,     // k_batch stride: jump to the next K-split
+                            args.final_stride_E, // same stride as the final output
+                            1);
+
+    // Define kept and reduced dimensions
+    constexpr auto kept_dim    = ck_tile::sequence<1, 2>{}; // Keep the M, N dimensions
+    constexpr auto reduce_dims = ck_tile::sequence<0>{};    // Reduce the k_batch dimension
+
+    using ReduceOp   = ck_tile::ReduceOp::Add;
+    using BlockWarps = ck_tile::sequence<4, 1>;
+    using BlockTile  = ck_tile::sequence<128, 128>;
+    using WarpTile   = ck_tile::sequence<32, 128>;
+    using ThreadTile = ck_tile::sequence<8, 8>;
+
+    constexpr ck_tile::index_t kBlockSize  = 256;
+    constexpr ck_tile::index_t kBlockPerCu = 1;
+
+    ck_tile::index_t kGridSize = (output_size + BlockTile::at(ck_tile::number<0>{}) - 1) /
+                                 BlockTile::at(ck_tile::number<0>{});
+
+    using Shape   = ck_tile::Reduce2dShape<BlockWarps, BlockTile, WarpTile, ThreadTile>;
+    using Problem = ck_tile::Reduce2dProblem<CDataType, ComputeDataType, CDataType, Shape, ReduceOp>;
+    using Kernel  = ck_tile::Reduce<Problem>;
+
+    if(!Kernel::IsSupportedArgument(reduce_dim_size, workspace_strides))
+    {
+        throw std::runtime_error("Wrong! Reduction arguments not supported!\n");
+    }
+
+    if(s.log_level_ > 0)
+    {
+        std::cout << "Stage 2 - Launching Reduction kernel" << '\n'
+                  << "workspace shape: [" << args.k_batch << ", " << args.M << ", " << args.N
+                  << "]" << '\n'
+                  << "output shape: [" << args.M << ", " << args.N << "]" << '\n'
+                  << "grid size: " << kGridSize << std::endl;
+    }
+
+    float ave_time =
+        ck_tile::launch_kernel(s,
+                               ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
+                                   Kernel{},
+                                   kGridSize,
+                                   kBlockSize,
+                                   0,                                              // LDS size
+                                   static_cast<CDataType*>(args.e_ptr),            // workspace input
+                                   static_cast<CDataType*>(args.final_output_ptr), // final output
+                                   workspace_shape,
+                                   workspace_strides,
+                                   kept_dim,
+                                   reduce_dims));
+
+    return ave_time;
+}
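+
+// The reduction computed by stage 2, written out elementwise for a row-major E:
+//   E[m][n] = sum_{k = 0}^{k_batch - 1} ws[k * M * N + m * final_stride_E + n]
+// which matches workspace_shape = (k_batch, M, N), strides
+// (M * N, final_stride_E, 1), and reduce_dims = sequence<0>.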
+
+/**
+ * @brief Orchestrator for the two-stage split-K GEMM implementation.
+ *
+ * This function coordinates the two-stage approach:
+ * 1. Stage 1: Execute the GEMM with each K-split writing to the workspace.
+ * 2. Stage 2: Reduce the workspace results into the final output (if k_batch > 1).
+ *
+ * @param args Extended arguments for two-stage execution.
+ * @param s Stream configuration.
+ * @return Total execution time (GEMM + reduction).
+ */
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout,
+          bool Persistent>
+float gemm_splitk_two_stage(const GemmSplitKHostArgs& args, const ck_tile::stream_config& s)
+{
+    float gemm_time   = 0.0f;
+    float reduce_time = 0.0f;
+
+    if(s.log_level_ > 0)
+    {
+        std::cout << "Starting Two-Stage GEMM+SplitK with k_batch=" << args.k_batch << std::endl;
+        std::cout << "Workspace size: " << args.k_batch << " x " << args.M << " x " << args.N
+                  << " = " << args.k_batch * args.M * args.N * sizeof(CDataType) << " bytes"
+                  << std::endl;
+    }
+
+    // Stage 1: GEMM to workspace
+    gemm_time = gemm_stage1<GemmConfig,
+                            ADataType,
+                            BDataType,
+                            DsDataType,
+                            AccDataType,
+                            CDataType,
+                            DsLayout,
+                            ALayout,
+                            BLayout,
+                            CLayout,
+                            Persistent>(args, s);
+
+    // Synchronize before stage 2
+    auto sync_result = hipStreamSynchronize(s.stream_id_);
+    if(sync_result != hipSuccess)
+    {
+        throw std::runtime_error("Stream synchronization failed");
+    }
+
+    // Stage 2: Reduction from workspace to final output (if needed)
+    if(args.k_batch > 1)
+    {
+        // Use appropriate precision for the reduction computations
+        using ComputeDataType = std::conditional_t<
+            std::is_same_v<CDataType, ck_tile::half_t>,
+            float,
+            std::conditional_t<std::is_same_v<CDataType, ck_tile::bf16_t>, float, CDataType>>;
+        reduce_time = reduce_stage2<CDataType, ComputeDataType, CLayout>(args, s);
+    }
+    else
+    {
+        // Single K-split: plain copy from workspace to final output
+        auto copy_result = hipMemcpyAsync(args.final_output_ptr,
+                                          args.e_ptr,
+                                          args.M * args.N * sizeof(CDataType),
+                                          hipMemcpyDeviceToDevice,
+                                          s.stream_id_);
+        if(copy_result != hipSuccess)
+        {
+            throw std::runtime_error("Memory copy failed");
+        }
+    }
+
+    if(s.log_level_ > 0)
+    {
+        std::cout << "GEMM stage time: " << gemm_time << " ms" << std::endl;
+        if(args.k_batch > 1)
+        {
+            std::cout << "Reduction stage time: " << reduce_time << " ms" << std::endl;
+        }
+        std::cout << "Total time: " << gemm_time + reduce_time << " ms" << std::endl;
+    }
+
+    return gemm_time + reduce_time;
+}
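+
+// Control flow of the orchestrator above, summarized:
+//   k_batch == 1 : gemm_stage1 -> hipMemcpyAsync(workspace -> C)
+//   k_batch >  1 : gemm_stage1 -> hipStreamSynchronize -> reduce_stage2
+// fp16/bf16 outputs are reduced in float so that rounding error does not
+// accumulate across the k_batch partial sums.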
+
+/**
+ * @brief High-level interface for two-stage split-K GEMM execution.
+ *
+ * @param a_m_k_dev_buf Input matrix A device buffer.
+ * @param b_k_n_dev_buf Input matrix B device buffer.
+ * @param c_m_n_dev_buf Output matrix C device buffer.
+ * @param M Matrix M dimension.
+ * @param N Matrix N dimension.
+ * @param K Matrix K dimension.
+ * @param stride_A Memory stride for matrix A.
+ * @param stride_B Memory stride for matrix B.
+ * @param stride_C Memory stride for matrix C.
+ * @param kbatch Number of K-splits for split-K execution.
+ * @param n_warmup Number of warmup iterations.
+ * @param n_repeat Number of repeat iterations for benchmarking.
+ * @param persistent Whether to use persistent kernel execution.
+ * @return Average execution time in milliseconds.
+ */
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename DsDataType,
+          typename AccDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename DsLayout,
+          typename CLayout>
+float invoke_gemm_splitk_two_stage(ck_tile::DeviceMem& a_m_k_dev_buf,
+                                   ck_tile::DeviceMem& b_k_n_dev_buf,
+                                   ck_tile::DeviceMem& c_m_n_dev_buf,
+                                   ck_tile::index_t M,
+                                   ck_tile::index_t N,
+                                   ck_tile::index_t K,
+                                   ck_tile::index_t stride_A,
+                                   ck_tile::index_t stride_B,
+                                   ck_tile::index_t stride_C,
+                                   ck_tile::index_t kbatch,
+                                   int n_warmup,
+                                   int n_repeat,
+                                   bool persistent)
+{
+    // Workspace size: kbatch * M * N elements
+    const ck_tile::index_t workspace_size   = kbatch * M * N * sizeof(CDataType);
+    const ck_tile::index_t workspace_stride = stride_C; // Each slice uses the same row stride as C
+
+    // Allocate workspace memory
+    ck_tile::DeviceMem workspace_buf(workspace_size);
+    workspace_buf.SetZero();
+
+    // Create extended args for the two-stage approach
+    GemmSplitKHostArgs args{
+        a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
+        b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
+        workspace_buf.GetDeviceBuffer(), // workspace_ptr (used as e_ptr in stage 1)
+        c_m_n_dev_buf.GetDeviceBuffer(), // final_output_ptr
+        kbatch,                          // k_batch
+        M,
+        N,
+        K,                // dimensions
+        stride_A,
+        stride_B,         // input strides
+        workspace_stride, // workspace stride
+        stride_C          // final output stride
+    };
+
+    float ave_time;
+    ck_tile::stream_config config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50};
+
+    if(persistent)
+    {
+        ave_time = gemm_splitk_two_stage<GemmConfig,
+                                         ADataType,
+                                         BDataType,
+                                         DsDataType,
+                                         AccDataType,
+                                         CDataType,
+                                         ALayout,
+                                         BLayout,
+                                         DsLayout,
+                                         CLayout,
+                                         true>(args, config);
+    }
+    else
+    {
+        ave_time = gemm_splitk_two_stage<GemmConfig,
+                                         ADataType,
+                                         BDataType,
+                                         DsDataType,
+                                         AccDataType,
+                                         CDataType,
+                                         ALayout,
+                                         BLayout,
+                                         DsLayout,
+                                         CLayout,
+                                         false>(args, config);
+    }
+
+    std::size_t flop = std::size_t(2) * M * N * K;
+    std::size_t num_byte =
+        sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
+    float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+    float gb_per_sec = num_byte / 1.E6 / ave_time;
+
+    std::cout << "Run Two-Stage GEMM+SplitK with M=" << M << " N=" << N << " K=" << K
+              << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
+              << " kbatch=" << kbatch << " WorkspaceSize=" << workspace_size << " bytes"
+              << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name
+              << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
+              << " B_Type=" << DataTypeTraits<BDataType>::name
+              << " C_Type=" << DataTypeTraits<CDataType>::name
+              << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
+              << " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
+              << tflops << " TFlops, " << gb_per_sec << " GB/s" << std::endl;
+
+    return ave_time;
+}
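+
+// Sanity check for the performance arithmetic above (illustrative numbers):
+// with M = N = K = 4096, flop = 2 * 4096^3 ~ 1.374e11. Since flop / 1e9 is
+// GFLOP and ave_time is in milliseconds, GFLOP per ms equals TFLOP per s,
+// so at ave_time = 10 ms this reports ~13.7 TFlops.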
+
+// Two-stage implementation of run_gemm_example_with_layouts
+template <typename GemmConfig,
+          typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename ALayout,
+          typename BLayout,
+          typename CLayout>
+int run_gemm_example_with_layouts_two_stage(int argc,
+                                            char* argv[],
+                                            const ALayout a_layout                  = ALayout{},
+                                            const BLayout b_layout                  = BLayout{},
+                                            [[maybe_unused]] const CLayout c_layout = CLayout{})
+{
+    auto [result, arg_parser] = create_args(argc, argv);
+    if(!result)
+        return -1;
+
+    using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
+
+    ck_tile::index_t M = arg_parser.get_int("m");
+    ck_tile::index_t N = arg_parser.get_int("n");
+    ck_tile::index_t K = arg_parser.get_int("k");
+
+    ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
+    ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
+    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
+
+    ck_tile::index_t kbatch      = arg_parser.get_int("split_k");
+    int n_warmup                 = arg_parser.get_int("warmup");
+    int n_repeat                 = arg_parser.get_int("repeat");
+    ck_tile::index_t init_method = arg_parser.get_int("init");
+    bool persistent              = arg_parser.get_int("persistent");
+
+    const bool preshuffle = GemmConfig::Preshuffle;
+
+    stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
+    stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
+    stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
+
+    ck_tile::HostTensor<ADataType> a_m_k(
+        ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
+    ck_tile::HostTensor<BDataType> b_k_n(
+        ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
+    ck_tile::HostTensor<CDataType> c_m_n_dev_result(
+        ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
+
+    if(init_method == 0)
+    {
+        if constexpr(preshuffle)
+        {
+            ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_m_k);
+            ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n);
+        }
+        else
+        {
+            ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
+            ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
+        }
+    }
+    else if(init_method == 1)
+    {
+        ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
+        ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
+    }
+    else if(init_method == 2)
+    {
+        ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
+        ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
+    }
+    else
+    {
+        a_m_k.SetZero();
+        b_k_n.SetZero();
+    }
+
+    if(!preshuffle && GemmConfig::UseStructuredSparsity)
+    {
+        ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
+    }
+
+    ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
+    ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
+
+    static_assert(!GemmConfig::PermuteA, "Not implemented");
+
+    if constexpr(preshuffle)
+    {
+        // shuffled B buffer for the device implementation
+        ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b(b_k_n);
+        b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
+    }
+    else
+    {
+        if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
+        {
+            // Permute vector pk_i4x4 data for the device implementation
+            ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
+            if constexpr(GemmConfig::PermuteB)
+            {
+                permute_tensor_b(b_k_n_dev);
+            }
+            permute_vectors_i4x4_b(b_k_n_dev);
+            b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
+        }
+        else
+        {
+            if constexpr(GemmConfig::PermuteB)
+            {
+                std::cout << "Permute for this DataType is not implemented." << std::endl;
+                return false;
+            }
+            b_k_n_dev_buf.ToDevice(b_k_n.data());
+        }
+    }
+
+    a_m_k_dev_buf.ToDevice(a_m_k.data());
+    c_m_n_dev_buf.SetZero();
+    c_m_n_dev_result.SetZero();
+
+    std::cout << "Using Workspace Split-K Mode (Two-Stage with Reduction)" << std::endl;
+    // Use the new two-stage approach
+    invoke_gemm_splitk_two_stage<GemmConfig,
+                                 ADataType,
+                                 BDataType,
+                                 ck_tile::tuple<>,
+                                 AccDataType,
+                                 CDataType,
+                                 ALayout,
+                                 BLayout,
+                                 ck_tile::tuple<>,
+                                 CLayout>(a_m_k_dev_buf,
+                                          b_k_n_dev_buf,
+                                          c_m_n_dev_buf,
+                                          M,
+                                          N,
+                                          K,
+                                          stride_A,
+                                          stride_B,
+                                          stride_C,
+                                          kbatch,
+                                          n_warmup,
+                                          n_repeat,
+                                          persistent);
+
+    c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
+    bool pass = true;
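+
+    // Verification note: calculate_rtol_atol derives the error thresholds from
+    // K, kbatch, and the largest accumulated value, since more accumulation
+    // steps (plus the separate stage-2 reduction) mean more rounding error in
+    // low-precision outputs; the split-K result is compared with these relaxed
+    // tolerances rather than bit-exactly.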
"correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + if constexpr(GemmConfig::Preshuffle) + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + // memory on host to store gpu reference result + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + // memory on device to store gpu reference result + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + + return pass; +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + auto [result, arg_parser] = create_args(argc, argv); + bool preshuffle = GemmConfig::Preshuffle; + + if(preshuffle && std::is_same_v) + { + throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); + } + + if(preshuffle && a_layout != "R" && b_layout != "C") + { + throw std::runtime_error( + "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); + } + + // Use new two-stage approach for both int4 and other data types + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage(argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Row{}, Row{}); + } + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_example_with_layouts_two_stage( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } + return 0; +} + +template