From 78f2779870915c3e71bcd4e2e8684ea5c0c873be Mon Sep 17 00:00:00 2001 From: "assistant-librarian[bot]" Date: Mon, 29 Sep 2025 22:11:28 +0000 Subject: [PATCH] Merge commit 'bebf0e9d158c13d34c9f263a9551f60fa463bc66' into develop --- example/ck_tile/17_grouped_gemm/README.md | 10 +- .../17_grouped_gemm/grouped_gemm_multi_d.cpp | 106 ++++++ .../17_grouped_gemm/grouped_gemm_multi_d.hpp | 5 +- .../run_grouped_gemm_multi_d_example.inc | 46 ++- .../ck_tile/40_streamk_gemm/CMakeLists.txt | 5 + example/ck_tile/40_streamk_gemm/README.md | 37 ++ .../ck_tile/40_streamk_gemm/gemm_utils.hpp | 106 ++++++ .../40_streamk_gemm/run_gemm_example.inc | 351 ++++++++++++++++++ .../40_streamk_gemm/streamk_gemm_basic.cpp | 193 ++++++++++ example/ck_tile/CMakeLists.txt | 1 + .../ops/gemm/kernel/grouped_gemm_kernel.hpp | 15 +- 11 files changed, 856 insertions(+), 19 deletions(-) create mode 100644 example/ck_tile/40_streamk_gemm/CMakeLists.txt create mode 100644 example/ck_tile/40_streamk_gemm/README.md create mode 100644 example/ck_tile/40_streamk_gemm/gemm_utils.hpp create mode 100644 example/ck_tile/40_streamk_gemm/run_gemm_example.inc create mode 100644 example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp diff --git a/example/ck_tile/17_grouped_gemm/README.md b/example/ck_tile/17_grouped_gemm/README.md index 0821065098..09bf3e167a 100644 --- a/example/ck_tile/17_grouped_gemm/README.md +++ b/example/ck_tile/17_grouped_gemm/README.md @@ -10,16 +10,15 @@ The grouped GEMM examples include two advanced optimization features: Weight preshuffle is an optimization technique that reorganizes the B matrix (weights) in memory to improve data access patterns and reduce memory bandwidth requirements. This is particularly beneficial for inference workloads where the same weights are reused across multiple batches. 
- **Implementation**: Available in `grouped_gemm_preshuffle.cpp` -- **Configuration**: Uses `GemmConfigPreshuffleDecode` template configuration +- **Configuration**: Uses `GemmConfigPreshuffleDecode` and `GemmConfigPreshufflePrefill` template configuration - **Constraints**: Currently supports only A(Row major) + B(Column major) → C(Row major) layouts -- **Benefits**: Improved memory efficiency and reduced data movement + #### Persistence Mode Persistence mode is a GPU optimization where thread blocks remain active on the compute units to process multiple work items sequentially, reducing kernel launch overhead and improving occupancy. - **Template Parameter**: Controlled by the `Persistent` boolean template parameter in `invoke_gemm` - **Usage**: `invoke_gemm` enables persistence -- **Benefits**: Reduced kernel launch overhead, better resource utilization for small matrix sizes #### Multi-D Operations Multi-D operations extend the standard GEMM operation by supporting additional element-wise operations on the result tensor. This feature is particularly useful for workloads that require post-processing of the GEMM output. @@ -31,7 +30,8 @@ Multi-D operations extend the standard GEMM operation by supporting additional e - **Benefits**: Enables complex operations like scaling, activation functions, or other element-wise transformations in a single kernel call - **Build Target**: `make tile_example_grouped_gemm_multi_d -j` -Both features can be combined with different data types (fp16, fp8) and layout configurations to optimize performance for specific workloads. +Multi-D operations support both persistence and non-persistence modes. +Weight preshuffle is supported only in non-persistence mode. 
## Build ``` @@ -48,7 +48,7 @@ make tile_example_grouped_gemm_multi_d -j # The quant grouped gemm fp8 example make tile_example_quant_grouped_gemm -j ``` -This will result in an executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. +Each example will result in an corresponding executable `build/bin/tile_example_grouped_gemm`, `build/bin/tile_example_grouped_gemm_preshuffle`, `build/bin/tile_example_grouped_gemm_multi_d`, and `build/bin/tile_example_quant_grouped_gemm`. ## example diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 409eda8de4..98b0428d39 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -166,6 +166,112 @@ float grouped_gemm_multi_d(const std::vector& gemm_d return ave_time; } +template +float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, + const ck_tile::index_t num_groups, + void* kargs_ptr, + bool splitk) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence>; + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + using GemmUniversalTraits = + ck_tile::PersistentTileGemmUniversalTraits; + + float ave_time{0}; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. 
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); + + return ave_time; + }; + if(!splitk) + { + Run(ck_tile::integral_constant{}); + } + else + { + Run(ck_tile::integral_constant{}); + } + + return ave_time; +} + #include "run_grouped_gemm_multi_d_example.inc" int main(int argc, char* argv[]) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index f7727d854c..d5203a799c 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -95,6 +95,7 @@ struct GemmConfigV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; @@ -170,7 +171,7 @@ struct PipelineTypeTraits using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; }; -using grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs<2>; +using 
grouped_gemm_multi_d_kargs = ck_tile::GroupedGemmHostArgs; std::pair create_args(int argc, char* argv[]) { @@ -201,7 +202,7 @@ std::pair create_args(int argc, char* argv[]) inline std::size_t get_workspace_size(const std::vector& gemm_descs) { - return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<2>); + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); } template > kargs; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); + const bool splitk = args[0].k_batch > 1; + for(const auto& arg : args) + { + kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, 2>{{arg.a_ptr}, + {arg.b_ptr}, + arg.ds_ptr, + arg.e_ptr, + arg.M, + arg.N, + arg.K, + {arg.stride_A}, + {arg.stride_B}, + arg.stride_Ds, + arg.stride_E, + arg.k_batch}); + } + const auto stream = ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}; + HIP_CHECK_ERROR(hipMemcpyWithStream( + kargs_ptr, + kargs.data(), + kargs.size() * sizeof(ck_tile::GemmTransKernelArg), + hipMemcpyHostToDevice, + stream.stream_id_)); + ave_time = + grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); } return ave_time; } @@ -322,12 +356,6 @@ int run_grouped_gemm_multi_d_example_with_layouts(int argc, b_k_n_tensors[i], {d0_m_n_tensors[i], d1_m_n_tensors[i]}, e_m_n_host_refs[i]); - std::cout << "e_m_n_host_refs[i]: " << std::endl; - e_m_n_host_refs[i].print_first_n(std::cout, 10); - std::cout << std::endl; - std::cout << "e_m_n_tensors[i]: " << std::endl; - e_m_n_tensors[i].print_first_n(std::cout, 10); - std::cout << std::endl; const float max_accumulated_value = *std::max_element(e_m_n_host_refs[i].mData.begin(), e_m_n_host_refs[i].mData.end()); diff --git a/example/ck_tile/40_streamk_gemm/CMakeLists.txt b/example/ck_tile/40_streamk_gemm/CMakeLists.txt new file mode 100644 index 0000000000..3539dee05b --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/CMakeLists.txt @@ -0,0 +1,5 @@ +if(GPU_TARGETS MATCHES "gfx9") + add_executable(tile_example_streamk_gemm_basic 
EXCLUDE_FROM_ALL streamk_gemm_basic.cpp) +else() + message(DEBUG "Skipping ck_tile streamk gemm tests for current target") +endif() diff --git a/example/ck_tile/40_streamk_gemm/README.md b/example/ck_tile/40_streamk_gemm/README.md new file mode 100644 index 0000000000..d2ff7eabc0 --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/README.md @@ -0,0 +1,37 @@ +# Stream-K GEMM + +This folder contains examples of Stream-K GEMMs using the ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +# you can replace with the appropriate architecture (for example gfx942) or leave it blank +../script/cmake-ck-dev.sh ../ +# Compile the Stream-K kernels +make tile_example_streamk_gemm_basic -j +``` +This will result in an executable `build/bin/tile_example_streamk_gemm_basic` + +## example +``` +args: + -m m dimension (default:512) + -n n dimension (default:512) + -k k dimension (default:512) + -a_layout tensor A data layout (default: R) + -b_layout tensor B data layout (default: C) + -c_layout tensor C data layout (default: R) + -num_sk_blocks number of Stream-K blocks. -1: chosen by algorithm, or user selected (default:-1) +-reduction_strategy strategy for storing results in C tensor. atomic/reduction (default:atomic) + -stride_a tensor A stride (default:0) + -stride_b tensor B stride (default:0) + -stride_c tensor C stride (default:0) + -v validation strategy. 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:1) + -prec data type. fp16/bf16 (default:fp16) + -warmup number of iterations before benchmarking the kernel (default:50) + -repeat number of iterations to benchmark the kernel (default:100) + -timer timing mode. gpu:gpu timer, cpu:cpu timer (default:gpu) + -init data initialization strategy. 
0:random, 1:linear, 2:constant(1) (default:0) + -flush_cache flush the cache before running the kernel (default:true) +``` \ No newline at end of file diff --git a/example/ck_tile/40_streamk_gemm/gemm_utils.hpp b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp new file mode 100644 index 0000000000..e698539eea --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/gemm_utils.hpp @@ -0,0 +1,106 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +struct GemmConfigBase +{ + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + static constexpr bool Persistent = false; + + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct StreamKGemmTypeConfig +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = float; + using CDataType = CDataType_; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "512", "m dimension") + .insert("n", "512", "n dimension") + .insert("k", "512", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("num_sk_blocks", + "-1", + "number of Stream-K blocks. -1: chosen by algorithm, or user selected") + .insert("reduction_strategy", + "atomic", + "strategy for storing results in C tensor - atomic/reduction") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. 
fp16/bf16") + .insert("warmup", "50", "number of iterations before benchmarking the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc new file mode 100644 index 0000000000..5fdf6b29ef --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -0,0 +1,351 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#pragma once + +// Estimate the number of WGs contributing to the same macro tile in C +template +int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner) +{ + // In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a + // macro time in C + int num_wgs_per_tile = 1; + + // Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C + if(tile_partitioner.sk_num_blocks > 0 && + ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + { + // Determine the number of iterations per WG for a given macro tile in C + uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1; + + // Estimate the number of WGs per macro tile + num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) + + ((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0); + } + + return std::max(num_wgs_per_tile, 1); +} + +template +static constexpr inline auto is_row_major(Layout) +{ + return ck_tile::bool_constant< + std::is_same_v, ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + 
const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to multiple WGs working in the same C macro tile + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s); + +template +std::tuple invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + int n_warmup, + int n_repeat, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy, + uint32_t num_sk_blocks) +{ + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + reduction_strategy, + num_sk_blocks}; + + std::tuple ave_time_and_batch; + + if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + { + ave_time_and_batch = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + } + else /*Reduction*/ + { + ave_time_and_batch = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, flush_cache}); + } + + return ave_time_and_batch; +} + +template +bool do_verify(const ck_tile::HostTensor& c_m_n_dev_result, + const 
ck_tile::HostTensor& c_m_n_ref, + const ck_tile::tuple& rtol_atol, + const char* variant) +{ + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; + std::cout << "The " << variant << " verification result is:" << (pass ? "correct" : "fail") + << std::endl; + return pass; +} + +ck_tile::StreamKReductionStrategy get_reduction_strategy_value(const std::string& strategy) +{ + if(strategy == "atomic") + { + return ck_tile::StreamKReductionStrategy::Atomic; + } + else if(strategy == "reduction") + { + return ck_tile::StreamKReductionStrategy::Reduction; + } + else + { + throw std::runtime_error("Unsupported Stream-K reduction strategy !!!"); + } +} + +template +int run_gemm_example_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + static_assert(!GemmConfig::Preshuffle, "Not implemented"); + static_assert(!GemmConfig::UseStructuredSparsity, "Not implemented"); + static_assert(!GemmConfig::PermuteA, "Not implemented"); + static_assert(!GemmConfig::PermuteB, "Not implemented"); + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using AccDataType = typename TypeConfig::AccDataType; + using CDataType = typename TypeConfig::CDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = 
arg_parser.get_int("stride_c"); + + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + bool flush_cache = arg_parser.get_bool("flush_cache"); + + ck_tile::StreamKReductionStrategy reduction_strategy = + get_reduction_strategy_value(arg_parser.get_str("reduction_strategy")); + uint32_t num_sk_blocks = static_cast(arg_parser.get_int("num_sk_blocks")); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + auto [ave_time, num_wgs_per_tile] = invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + 
CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + n_warmup, + n_repeat, + flush_cache, + reduction_strategy, + num_sk_blocks); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout=" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " reduction_strategy=" << arg_parser.get_str("reduction_strategy") << " " + << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << std::endl; + + bool pass = true; + + // Memory on host to store gpu reference result + ck_tile::HostTensor c_m_n_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_ref.SetZero(); + + if(arg_parser.get_int("v") == 1) // Validate on the CPU + { + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, num_wgs_per_tile, max_accumulated_value); + pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); + } + else if(arg_parser.get_int("v") == 2) // Validate on the GPU + { + // Memory on device to store gpu reference result + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = 
static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); + + const float max_accumulated_value = + *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, num_wgs_per_tile, max_accumulated_value); + pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "GPU"); + } + + return pass; +} diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp new file mode 100644 index 0000000000..bb6b1eb413 --- /dev/null +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -0,0 +1,193 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "gemm_utils.hpp" +#include "run_gemm_example.inc" + +template +std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation) -> std::tuple { + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. 
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids = Kernel::GridSize(kargs.tile_partitioner); + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + // Function to clear the output C tensor results after each repetition of the kernel + auto clear_gemm_output = [&]() { + if(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + + std::function preprocess = clear_gemm_output; + + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + int num_wgs_per_tile = estimate_num_wgs_per_tile(kargs.tile_partitioner); + + return std::tuple{ave_time, num_wgs_per_tile}; + }; + + if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy) + { + return Run(ck_tile::integral_constant{}); + } + else // We are using ck_tile::StreamKReductionStrategy::Reduction + { + return Run(ck_tile::integral_constant{}); + } +} + +template +int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = 
ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_example_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported layouts."); + } + + return 0; +} + +template