diff --git a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc index 5fdf6b29ef..6dd054ee11 100644 --- a/example/ck_tile/40_streamk_gemm/run_gemm_example.inc +++ b/example/ck_tile/40_streamk_gemm/run_gemm_example.inc @@ -2,29 +2,6 @@ // SPDX-License-Identifier: MIT #pragma once -// Estimate the number of WGs contributing to the same macro tile in C -template -int estimate_num_wgs_per_tile(const TilePartitioner& tile_partitioner) -{ - // In the case of non-atomic reduction or DP only, there will always be 1 WG contributing to a - // macro time in C - int num_wgs_per_tile = 1; - - // Otherwise, for atomics, multiple WGs may be contributing to the same macro tile in C - if(tile_partitioner.sk_num_blocks > 0 && - ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) - { - // Determine the number of iterations per WG for a given macro tile in C - uint32_t k_iters_per_block = tile_partitioner.k_iters_per_big_block - 1; - - // Estimate the number of WGs per macro tile - num_wgs_per_tile = (tile_partitioner.k_iters_per_tile.get() / (k_iters_per_block)) + - ((tile_partitioner.k_iters_per_tile.get() % k_iters_per_block) != 0); - } - - return std::max(num_wgs_per_tile, 1); -} - template static constexpr inline auto is_row_major(Layout) { @@ -65,7 +42,8 @@ template -std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s); +std::tuple gemm(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s); template -std::tuple invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t stride_A, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - int n_warmup, - int n_repeat, - bool flush_cache, - ck_tile::StreamKReductionStrategy reduction_strategy, - uint32_t num_sk_blocks) +std::tuple +invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + int n_warmup, + int n_repeat, + bool flush_cache, + ck_tile::StreamKReductionStrategy reduction_strategy, + uint32_t num_sk_blocks) { ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), @@ -105,7 +84,7 @@ std::tuple invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, reduction_strategy, num_sk_blocks}; - std::tuple ave_time_and_batch; + std::tuple ave_time_and_batch; if(args.reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) { diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index bb6b1eb413..40709e38e2 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -3,6 +3,7 @@ #include "gemm_utils.hpp" #include "run_gemm_example.inc" +#include "ck_tile/ops/common.hpp" template -std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& s) +std::tuple gemm(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s) { using GemmShape = ck_tile::TileGemmShape< @@ -42,7 +44,7 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile: GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - const auto Run = [&](const auto memory_operation) -> std::tuple { + const auto Run = [&](const auto memory_operation) -> std::tuple { // We create the GEMM pipeline without specifying has_hot_loop or tail_num. // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K @@ -113,7 +115,13 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, const ck_tile: preprocess, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - int num_wgs_per_tile = estimate_num_wgs_per_tile(kargs.tile_partitioner); + ck_tile::index_t num_wgs_per_tile = ck_tile::estimate_num_wgs_per_tile( + kargs.tile_partitioner.sk_num_blocks, + // k_iters_per_big_block could be 1, which indicates that all Stream-K workgroups are + // big and each does one iteration. Thus, we ensure the value passed in is at least 1 to + // avoid division by zero errors. + ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u), + kargs.tile_partitioner.k_iters_per_tile.get()); return std::tuple{ave_time, num_wgs_per_tile}; }; diff --git a/include/ck_tile/ops/common/streamk_common.hpp b/include/ck_tile/ops/common/streamk_common.hpp index 5dbe6223c4..c01e967dcd 100644 --- a/include/ck_tile/ops/common/streamk_common.hpp +++ b/include/ck_tile/ops/common/streamk_common.hpp @@ -11,4 +11,33 @@ enum StreamKReductionStrategy : uint32_t Atomic = 0u, Reduction = 1u }; + +/** + * @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor. + * + * @param sk_ctas Number of Stream-K workgroups. + * @param iters_per_sk_cta Number of iterations per Stream-K workgroup. + * @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K + * dimension). + * @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor. + * @note It is assumed that `iters_per_sk_cta` > 0. + */ +template +ck_tile::index_t +estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile) +{ + // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup + // writing final results to a given macro tile in C. + int num_wgs_per_tile = 1; + + // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C. + if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + { + // Estimate the number of workgroups per macro tile. + num_wgs_per_tile = + (iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0); + } + + return std::max(num_wgs_per_tile, 1); +} } // namespace ck_tile diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk.hpp index da0b8d153d..c341789435 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk.hpp @@ -10,6 +10,7 @@ #include #include "ck_tile/host.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" @@ -50,10 +51,10 @@ class TestCkTileStreamK : public ::testing::Test bool PadK = true, bool Preshuffle = false, bool TransposeC = false> - bool invoke_streamk(const ck_tile::StreamKHostArgs& args, - const ck_tile::stream_config& s, - int num_cu, - int occupancy) + std::tuple invoke_streamk(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s, + int num_cu, + int occupancy) { constexpr bool kPadM = PadM; constexpr bool kPadN = PadN; @@ -129,7 +130,7 @@ class TestCkTileStreamK : public ::testing::Test if(!Kernel::IsSupportedArgument(kargs)) { - return false; + return std::tuple{false, -1}; } dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); @@ -138,7 +139,16 @@ class TestCkTileStreamK : public ::testing::Test ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); - return true; + ck_tile::index_t num_accumulations_per_tile = + ck_tile::estimate_num_wgs_per_tile( + kargs.tile_partitioner.sk_num_blocks, + // k_iters_per_big_block could be 1, which indicates that all blocks are + // big and each does one iteration. Thus, we ensure the value passed in is at + // least 1 to avoid division by zero errors. + ck_tile::max(kargs.tile_partitioner.k_iters_per_big_block - 1, 1u), + kargs.tile_partitioner.k_iters_per_tile.get()); + + return std::tuple{true, num_accumulations_per_tile}; }; return Run(ck_tile::integral_constant( - args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy)) + const auto [is_valid_instance, num_accumulations_per_tile] = + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy); + + if(!is_valid_instance) { GTEST_SKIP() << "Skipping this test: The kernel cannot solve the problem\n"; } @@ -256,7 +269,7 @@ class TestCkTileStreamK : public ::testing::Test const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( - K, /*kbatch*/ 1, max_accumulated_value); + K, num_accumulations_per_tile, max_accumulated_value); bool pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref,