diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 58bce4795f..915aebd1e6 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -8,6 +8,478 @@ #include "ck_tile/host/concat.hpp" namespace ck_tile { +namespace reboot { + +/// @brief The Stream K GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel +/// arguments object. It contains all necessary information required to build proper kernel +/// arguments and launch the kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<> +{ + CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_, + const void* b_ptr_, + void* c_ptr_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_C_, + StreamKReductionStrategy reduction_strategy_) + : UniversalGemmHostArgs<>({a_ptr_}, + {b_ptr_}, + {/*ds_ptr*/}, + c_ptr_, + /*k_batch_ =*/1, + M_, + N_, + K_, + {stride_A_}, + {stride_B_}, + {/*stride_Ds_*/}, + stride_C_), + reduction_strategy{reduction_strategy_} + { + } + + ck_tile::StreamKReductionStrategy reduction_strategy; +}; + +/// @brief The Stream K GEMM kernel class. +/// +/// @par Overview +/// This class is responsible for the Stream-K kernel, making use of UniversalGemm. +// The main kernel functions are the operator() functions. There is one for Persistent +// and one for Non-Persistent data parallel sections of the Stream-K algorithm. +// +// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and +// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls +// `RunGemm()` which runs the GEMM pipeline and epilogue. 
`StreamKGemm()` performs the +// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`. +template +struct StreamKKernel +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; + static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel; + + using TilePartitioner = TilePartitioner_; + using GemmPipeline = GemmPipeline_; + using EpiloguePipeline = EpiloguePipeline_; + + static_assert( + TilePartitioner::PERSISTENT == PersistentDP, + "Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm."); + + /// @brief Specify the layout configurations for A, B, and C + using ALayout = typename GemmPipeline::ALayout; + using BLayout = typename GemmPipeline::BLayout; + using CLayout = typename GemmPipeline::CLayout; + + /// @brief Specify the data type configurations for A, B, and C + using ADataType = typename GemmPipeline::ADataType; + using BDataType = typename GemmPipeline::BDataType; + using CDataType = typename EpiloguePipeline::ODataType; + + template + static constexpr bool is_tuple_v = is_detected::value; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert(!is_tuple_v && !is_tuple_v, + "ALayout and ADataType must be scalars."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert(!is_tuple_v && !is_tuple_v, + "BLayout and BDataType must be scalars."); + + /// @brief CLayout and CDataType are expected to be scalars, not a tuple. 
+ static_assert(!is_tuple_v && !is_tuple_v, + "CLayout and CDataType must be scalars."); + + struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<> + { + StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid) + : UniversalGemmKernelArgs{host_args.as_ptr, + host_args.bs_ptr, + host_args.ds_ptr, + host_args.e_ptr, + host_args.M, + host_args.N, + host_args.K, + host_args.stride_As, + host_args.stride_Bs, + host_args.stride_Ds, + host_args.stride_E, + host_args.k_batch}, + reduction_strategy{host_args.reduction_strategy}, + // The workspace pointer is set to nullptr because we must first + // instantiate the TilePartitioner to get the necessary size + workspace_ptr{nullptr}, + tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}} + + { + } + + /// @brief The strategy used by work groups to compute final results in the C tensor. + StreamKReductionStrategy reduction_strategy; + /// @brief A pointer to a buffer in device memory for accumulating partial results via the + /// reduction strategy. + void* workspace_ptr; + /// @brief An instance of the TilePartitioner class for assisting with mapping workgroups to + /// the C tensor. + TilePartitioner tile_partitioner; + }; + + using KernelArgs = StreamKKernelArgs; + using Kernel = StreamKKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + using P_ = GemmPipeline; + using WarpTile = typename P_::BlockGemmShape::WarpTile; + + return concat('_', "streamk", gemm_prec_str(), + concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock), + concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})), + concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()), + concat('x', P_::kPadM, P_::kPadN, P_::kPadK)); + // clang-format on + } + + /// @brief Compute the grid size for the Stream K kernel using the tile_partitioner. + /// @return The grid size. 
+ CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3 + { + return tile_partitioner.grid_size(); + } + + /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + /// @return The maximum occupancy grid size. + /// @note This function queries the maximum occupancy of the kernel using + /// `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + /// @brief Constructs kernel arguments for the Stream-K kernel. + /// @param host_args Stream-K host arguments. + /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device. + /// The caller may select their own to assist with test reproducibility, etc. + /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may + /// select their own to assist with test reproducibility, etc. + /// @return The kernel arguments for Stream-K. 
+ CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args, + int num_cu = NumCU(), + int occupancy = Occupancy()) + { + const index_t grid = num_cu * occupancy; + + return StreamKKernelArgs{host_args, grid}; + } + + template + CK_TILE_DEVICE static void + RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const typename UniversalGemmKernel::KernelArgs& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t k_size) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + UniversalGemmKernel::template MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); + + const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // Run GEMM cooperatively by the whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + + // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute + // has_hot_loop and tail_num here. This is similar to the pattern used by grouped GEMM. In this + // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and + // tail_num. 
+ const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], + bs_block_window[UniversalGemmKernel::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) + { + if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); + } + return false; + } + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + /// @brief Computes the buffer size needed to store accumulation results for Stream K. + /// @return The buffer size needed. + CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs) + { + return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType)); + } + + /// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr. + /// @note Assumes that the given workspace_ptr points to allocated device memory. + CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr) + { + kargs.workspace_ptr = workspace_ptr; + } + + /// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue. + /// @param kargs Stream-K kernel arguments. + /// @param tile_idx The 1D tile index in the C tensor for this workgroup. + /// @param num_loop The number of iterations (at the macro tile level) in the K dimension this + /// workgroup will perform in the C tile. + /// @param i_k_a The K offset in the A tensor. 
+ /// @param i_k_b The K offset in the B tensor. + /// @param k_size The portion of the K dimension this workgroup processes in the assigned + /// `tile_idx`. + /// @param smem_ptr_0 Pointer to LDS. + CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs, + index_t tile_idx, + index_t num_loop, + index_t i_k_a, + index_t i_k_b, + index_t k_size, + void* smem_ptr_0) const + { + const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx); + index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock; + index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock; + + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + i_k_a; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k_b; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // Run the GEMM pipeline and Epilogue. + RunGemm( + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size); + } + + /// @brief Runs the main Stream-K algorithm. + /// @param kargs Stream-K kernel arguments. + /// @param cta_idx The current Stream-K workgroup's index. + /// @param smem_ptr_0 Pointer to LDS. + /// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a + /// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx` + /// should be something like `blockIdx.x` minus number of DP workgroups. + CK_TILE_DEVICE void + StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const + { + index_t iter_start, iter_end; + kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx); + + while(iter_start < iter_end) + { + // Get the 1D tile index in the C tensor that this workgroup will work in for this + // iteration of the loop. + index_t tile_idx = + amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start)); + + // Get the start and end boundaries for the current tile. 
+ index_t tile_iter_start, tile_iter_end; + kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx); + + // Get the start and end iteration within the current tile for the workgroup. + index_t local_iter_start = amd_wave_read_first_lane( + kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start)); + index_t local_iter_end = + amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end( + tile_iter_start, iter_end, tile_iter_end)); + + // Get the iteration length. + index_t num_loop_sk = local_iter_end - local_iter_start; + + // Determine the total size along the K dimension the workgroup is using in this + // iteration (used to construct tensor views). + index_t k_size = num_loop_sk * TilePartitioner::KPerBlock; + + // Get the K offsets for the A and B tensors + auto [i_k_a, i_k_b] = GetKOffsets( + local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]); + + if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic) + { + BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0); + } + else + { + // TODO: Apply reduction logic. + } + + // Prepare for next Stream-K loop iteration. + iter_start = tile_iter_end; + block_sync_lds(); + } + } + + /// @brief Entry point for the Stream-K Kernel with non-persistent DP. + /// + /// @par Overview + /// For the Non-Persistent kernel, each data parallel workgroup will + /// compute the results for their assigned macro-tile by calling `BaseGemm()`. + /// The Stream-K workgroups will do their assigned work by calling + /// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop. 
+ template + CK_TILE_DEVICE typename std::enable_if_t operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + index_t block_idx = ck_tile::get_block_1d_id(); + index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); + index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas(); + bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas(); + + // Check if at the data parallel section + if(is_dp_ctas) + { + BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); + } + else + { + // Stream-K + StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0); + } + } + + /// @brief Entry point for the Stream-K Kernel with persistent DP. + /// + /// @par Overview + /// For the Persistent kernel, each workgroup will first compute their + /// assigned data-parallel tiles. Each data parallel tile will be computed + /// by calling `BaseGemm()`. Then the workgroups will proceed with the + /// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()` + /// in the Stream-K loop. + template + CK_TILE_DEVICE typename std::enable_if_t operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + index_t block_idx = ck_tile::get_block_1d_id(); + index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile(); + + // Data-parallel section + for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles(); + tile_idx += kargs.tile_partitioner.get_grid()) + { + BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0); + } + + // Stream-K section + StreamKGemm(kargs, block_idx, smem_ptr_0); + } + + private: + /// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is + /// the starting macro tile index in the K dimension for the workgroup. 
+ /// @return A tuple containing the offsets into the A and B tensors accounting for the layouts + /// of A and B. + /// @note The default case is that A is assumed to be row major and B is assumed to be column + /// major. + template + CK_TILE_DEVICE static tuple + GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b) + { + index_t stride_offset_a; + index_t stride_offset_b; + if constexpr(std::is_same_v) + { + stride_offset_a = stride_a; + } + else + { + stride_offset_a = 1; + } + + if constexpr(std::is_same_v) + { + stride_offset_b = stride_b; + } + else + { + stride_offset_b = 1; + } + + index_t base_offset = iter_offset * TilePartitioner::KPerBlock; + + return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b); + } + + CK_TILE_HOST static int NumCU() + { + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + int num_cu = dev_prop.multiProcessorCount; + + return num_cu; + } + + /// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel. + /// @return The occupancy. + /// @note This function queries the maximum occupancy of the kernel using + /// `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + CK_TILE_HOST static int Occupancy() + { + int occupancy; + + // Since an occupancy of 1 is valid for Stream-K, we set min_block_per_cu to 1 + constexpr int min_block_per_cu = 1; + const auto kernel = kentry; + + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0)); + + return occupancy; + } +}; +} // namespace reboot /// @brief The Stream K GEMM kernel host arguments. 
/// diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp index 1962f3518a..e98c60e5f0 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp @@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase */ CK_TILE_HOST_DEVICE index_t get_n() const noexcept; + /** + * @brief Returns an estimate of the number of workgroups writing to the same macro tile in C. + */ + CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept; + protected: index_t num_tiles_; index_t grid_; @@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2::get_n() c return n_; } +template +CK_TILE_HOST index_t +StreamKTilePartitionerBase::estimate_num_wgs_per_tile() + const noexcept +{ + // In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup + // writing final results to a given macro tile in C. + int num_wgs_per_tile = 1; + + // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C. + if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic) + { + ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1); + // Estimate the number of workgroups per macro tile. 
+ num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) + + ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0); + } + + return std::max(num_wgs_per_tile, 1); +} + template diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 47b776f401..810ae8d231 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -46,6 +46,7 @@ set(REGRESSION_TESTS test_ck_tile_fmha_fwd_bf16 test_ck_tile_fmha_fwd_fp16 test_ck_tile_fmha_fwd_fp8 + test_ck_tile_streamk_reboot_extended ) function(add_test_executable TEST_NAME) diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index 331118da59..eba411e271 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -117,6 +117,18 @@ if(GPU_TARGETS MATCHES "gfx9") # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp # ) add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) + add_gtest_executable(test_ck_tile_streamk_reboot_smoke + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp + test_gemm_streamk_reboot_util.cpp) + add_gtest_executable(test_ck_tile_streamk_reboot_extended + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp + test_gemm_streamk_reboot_util.cpp) else() message(DEBUG "Skipping test_ck_tile_streamk tests for current 
target") endif() diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp new file mode 100644 index 0000000000..eb4478f3d6 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); + +#include "test_gemm_streamk_reboot_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp new file mode 100644 index 0000000000..c42ada1a98 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent); + +#include "test_gemm_streamk_reboot_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp new file mode 100644 index 0000000000..664c16a5e6 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); + +#include "test_gemm_streamk_reboot_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp new file mode 100644 index 0000000000..39c79b4180 --- /dev/null +++ b/test/ck_tile/gemm_streamk/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent); + +#include "test_gemm_streamk_reboot_extended_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp new file mode 100644 index 0000000000..0c1813fb65 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent); + +#include "test_gemm_streamk_reboot_smoke_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp new file mode 100644 index 0000000000..e78092c4ba --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent); + +#include "test_gemm_streamk_reboot_smoke_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp new file mode 100644 index 0000000000..5e6118bd0c --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent); + +#include "test_gemm_streamk_reboot_smoke_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp new file mode 100644 index 0000000000..9f9c8f8234 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp @@ -0,0 +1,19 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_reboot_types.hpp" +#include "test_gemm_streamk_reboot_util.hpp" +#include "gtest/gtest.h" + +template +class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent + +TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent); + +#include "test_gemm_streamk_reboot_smoke_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_extended_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_extended_cases.inc new file mode 100644 index 0000000000..8b6522bd75 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_extended_cases.inc @@ -0,0 +1,24 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_DP2TSK) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + + // For DP 2-Tile SK, there are 2 important terms: + // Term 1: (M_Tile * num_cu * 2) - This ensures we have at least 2 cycles that will fully + // saturate all CUs. This assumes tile sizes are large enough such that occupancy is 1. + // Term 2: (M_Tile * 2) - This ensures we have 1 cycle that does not fully saturate all CUs + // (i.e., we will have remainder tiles). This guarantees we have 1 full tile cycle plus + // remainder tiles for the 2 Tile SK portion; the rest of the tiles will fully saturate all CUs + // for the DP portion. 
+ ck_tile::index_t M = (M_Tile * num_cu * 2) + (M_Tile * 2); + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = 2048; + + this->Run(M, N, K); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_smoke_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_smoke_cases.inc new file mode 100644 index 0000000000..d714b3446c --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_smoke_cases.inc @@ -0,0 +1,47 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase) +{ + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + // For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This + // assumes tile sizes are large enough such that occupancy is 1. + ck_tile::index_t M = M_Tile * num_cu; + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = K_Tile; + + this->Run(M, N, K); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + // For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along + // the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu + // macro tiles in the K dimension. 
This assumes tile sizes are large enough such that occupancy + // is 1. + ck_tile::index_t M = M_Tile * 2; + ck_tile::index_t N = N_Tile * 2; + ck_tile::index_t K = K_Tile * num_cu; + + this->Run(M, N, K); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp new file mode 100644 index 0000000000..1db53ddd64 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp @@ -0,0 +1,56 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +using Persistent = std::true_type; +using NonPersistent = std::false_type; + +using I32 = ck_tile::number<32>; +using I256 = ck_tile::number<256>; + +// clang-format off +using KernelTypesStreamKFp16Persistent = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent + + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> +>; + +using KernelTypesStreamKBf16Persistent = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent> +>; + +using KernelTypesStreamKFp16NonPersistent = ::testing::Types< +// 
ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent + + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent> +>; + +using KernelTypesStreamKBf16NonPersistent = ::testing::Types< + std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, + std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>, + std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent> +>; +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp new file mode 100644 index 0000000000..39a92d622d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp @@ -0,0 +1,10 @@ +#include "test_gemm_streamk_reboot_util.hpp" + +ck_tile::index_t get_cu_count() +{ + hipDeviceProp_t dev_prop; + hipDevice_t dev; + ck_tile::hip_check_error(hipGetDevice(&dev)); + ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + return dev_prop.multiProcessorCount; +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp new file mode 100644 index 0000000000..85863989b0 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp @@ -0,0 +1,283 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are + // resolved. Because the number of WGs contributing to a macro tile in C may not be the same for + // all macro tiles in C. + + // Calculate error due to more than 1 WG contributing to the same macro tile in C + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +ck_tile::index_t get_cu_count(); + +template +class TestCkTileStreamKReboot : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value; + static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value; + static constexpr ck_tile::index_t K_Tile = 
std::tuple_element_t<9, Tuple>::value; + static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value; + + template + ck_tile::index_t invoke_streamk(const ck_tile::reboot::StreamKHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; + constexpr bool preshuffle = Preshuffle; + + constexpr bool DoubleSmemBuffer = false; + constexpr int kBlockPerCu = 1; + constexpr bool StructuredSparsity = false; + constexpr bool NumWaveGroup = 1; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = + ck_tile::StreamKTilePartitioner_v2; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. 
+ using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = + ck_tile::reboot::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } + + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + + return kargs.tile_partitioner.estimate_num_wgs_per_tile(); + }; + + return Run(ck_tile::integral_constant{}); + } + + public: + void Run(ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::StreamKReductionStrategy reduction_strategy = + ck_tile::StreamKReductionStrategy::Atomic, + ck_tile::index_t stride_A = 0, + ck_tile::index_t stride_B = 0, + ck_tile::index_t stride_C = 0) + { + // Since M, N, and K will vary depending on the number of CUs, we print it here to + // facilitate test output readability. 
+ std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl; + + using namespace ck_tile::literals; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + throw std::runtime_error("Reduction Strategy is current unsupported!\n"); + } + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11940}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + 
stride_C, + reduction_strategy}; + + ck_tile::index_t num_accumulations_per_tile = + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, num_accumulations_per_tile, max_accumulated_value); + + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass); + }; +}; diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index 89d72d844b..9028f7bf10 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -77,6 +77,26 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy) expected_partials_size + expected_flags_size); } +TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLowerValue) +{ + using Config = StreamKTilePartitionerBaseConfigDP2TileSK; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1); +} + +TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue) +{ + using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile; + + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + + EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2); +} + TEST(StreamKTilePartitionerBaseGetLocalIter, 
GetLocalIter) { // Types diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 4fc654a7ea..eb62f4253b 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -194,6 +194,23 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner ck_tile::sequence>; }; +struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile + : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 16; + static constexpr ck_tile::index_t N = 4; + static constexpr ck_tile::index_t K = 16; + static constexpr ck_tile::index_t GRID = 8; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 4; + static constexpr ck_tile::index_t K_TILE = 8; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig { static constexpr ck_tile::index_t M = 12;