diff --git a/include/ck_tile/ops/common/streamk_common.hpp b/include/ck_tile/ops/common/streamk_common.hpp index c97282a8be..c723251112 100644 --- a/include/ck_tile/ops/common/streamk_common.hpp +++ b/include/ck_tile/ops/common/streamk_common.hpp @@ -8,7 +8,8 @@ namespace ck_tile { enum StreamKReductionStrategy : uint32_t { - Atomic = 0u, - Reduction = 1u + Atomic = 0u, + Reduction = 1u, + TreeReduction = 2u }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index d518a15b7e..0eaedbfb3a 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -33,9 +33,10 @@ #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" -#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp" -#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp" +#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp" #include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp" diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp new file mode 100644 index 0000000000..65e29c7fd5 --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once +#include "ck_tile/core/arch/arch.hpp" +namespace ck_tile { + +template +struct StreamKCoherency +{ + static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE = + amd_buffer_coherence_enum::coherence_default; +}; + +template +struct StreamKCoherency> +{ + static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE = + amd_buffer_coherence_enum::SYSTEM_NT0; +}; + +template +struct StreamKCoherency> +{ + static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE = + amd_buffer_coherence_enum::glc_slc; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp similarity index 79% rename from include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp rename to include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index 6130107cfe..d1fd32dc1b 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/common.hpp" #include "ck_tile/host/concat.hpp" +#include "streamk_gemm_coherency.hpp" namespace ck_tile { @@ -318,37 +319,58 @@ struct StreamKKernel * results. * @param kargs Kernel arguments, including the workspace pointer. * @param cta_idx The index of the current thread block (CTA). - * @note This function utilizes a workgroup barrier to set a synchronization flag for the given - * CTA index. + * @note This function utilizes a scalar store to write to the flags buffer. 
*/ CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const { - auto sk_flags_ptr = static_cast(kargs.workspace_ptr); - workgroup_barrier sk_flags(sk_flags_ptr); - sk_flags.wait_set(0, 1, cta_idx); + auto* sk_flags_ptr = static_cast(kargs.workspace_ptr); + index_t offset = cta_idx * sizeof(index_t); + + asm volatile("s_mov_b32 m0, %2\n\t" + // Depending on the architecture, the GLC flag will bypass the appropriate + // cache level(s) to ensure the write is visible to other workgroups. See the + // appropriate ISA for details about the GLC modifier. + "s_store_dword %0, %1, %2 glc\n\t" + "s_waitcnt lgkmcnt(0)" // Wait for the store to complete + : + : "s"(1), "s"(sk_flags_ptr), "s"(offset) + : "memory"); } /** * @brief Waits for the thread block (cta_idx) to complete storing its partial results. * @param kargs Kernel arguments, including the workspace pointer. * @param cta_idx The index of the thread block (CTA). - * @note This function utilizes a workgroup barrier to wait for the synchronization flag to be - * set by the given CTA index. + * @note This function utilizes a scalar load to read from the flags + * buffer. */ CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const { - auto sk_flags_ptr = static_cast(kargs.workspace_ptr); - workgroup_barrier sk_flags(sk_flags_ptr); - sk_flags.wait_eq(1, cta_idx); + auto* sk_flags_ptr = static_cast(kargs.workspace_ptr); + index_t result; + index_t offset = cta_idx * sizeof(index_t); + + do + { + asm volatile("s_mov_b32 m0, %2\n\t" + // Depending on the architecture, the GLC flag will bypass the + // appropriate cache level(s) to avoid reading stale flags. See the + // appropriate ISA for details about the GLC modifier. 
+ "s_load_dword %0, %1, %2 glc\n\t" + "s_waitcnt lgkmcnt(0)" // Wait for the load to complete + : "=s"(result) + : "s"(sk_flags_ptr), "s"(offset) + : "memory"); + } while(result != 1); } /** * @brief Adds the values of a block tile to an output block tile. * @param in_out_block_tile The output block tile to which values are added. * @param in_block_tile The input block tile whose values are added. - * @note This function iterates over the distributed spans of the block tiles and updates the - * output block tile with accumulated values. + * @note This function iterates over the distributed spans of the block tiles and updates + * the output block tile with accumulated values. */ template CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile, @@ -370,8 +392,8 @@ struct StreamKKernel * @param cta_idx The index of the thread block (CTA). * @param c_block_tile_dist The tile distribution for the block. * @return The loaded partial block tile. - * @note This function calculates the buffer pointer and uses the tile distribution for loading - * the partial block tile. + * @note This function calculates the buffer pointer and uses the tile distribution for + * loading the partial block tile. */ template CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs, @@ -405,8 +427,8 @@ struct StreamKKernel * @param kargs Kernel arguments, including the workspace pointer. * @param cta_idx The index of the thread block (CTA). * @param c_block_tile The block tile to be stored. - * @note This function calculates the buffer pointer and uses the tile window for storing the - * partial block tile. + * @note This function calculates the buffer pointer and uses the tile window for storing + * the partial block tile. 
*/ template CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs, @@ -420,7 +442,10 @@ struct StreamKKernel kargs.tile_partitioner.get_flags_buffer_size() + cta_idx * c_block_tile_buffer_size; - const auto& partial_tensor_view = make_naive_tensor_view( + const auto& partial_tensor_view = make_naive_tensor_view< + address_space_enum::global, + memory_operation_enum::set, + StreamKCoherency::BUFFER_COHERENCE>( static_cast(partial_buffer_ptr), make_tuple(number{}, number{}), make_tuple(TilePartitioner::NPerBlock, 1), @@ -431,8 +456,11 @@ struct StreamKKernel partial_tensor_view, make_tuple(number{}, number{}), {0, 0}); - store_tile(partial_tile_window, c_block_tile); + // Wait for all vector stores for this wavefront to complete + s_waitcnt(); + // Wait for all wavefronts in this workgroup to arrive here before continuing + __builtin_amdgcn_s_barrier(); } /** @@ -483,7 +511,8 @@ struct StreamKKernel { BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0); } - else + else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction || + TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction) { const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx); @@ -528,46 +557,107 @@ struct StreamKKernel auto tile_started = iter_start == tile_iter_start; auto tile_ended = iter_end >= tile_iter_end; - if(!tile_started) + + if constexpr(TilePartitioner::ReductionStrategy == + StreamKReductionStrategy::Reduction) { - StorePartial(kargs, cta_idx, c_block_tile); - // Ensure device-wide visibility of partial results stored in global memory - // before signaling completion. __threadfence() guarantees that all global - // memory writes by this thread are visible to other threads on the device. 
- __threadfence(); // send signal when the store is done - SignalStorePartialDone(kargs, cta_idx); + if(!tile_started) + { + StorePartial(kargs, cta_idx, c_block_tile); + SignalStorePartialDone(kargs, cta_idx); + } + else + { + auto accum_block_tile = c_block_tile; + if(!tile_ended) + { + const index_t iter_per_tile = + kargs.tile_partitioner.get_iters_per_tile(); + const index_t iter_per_cta = + kargs.tile_partitioner.get_iters_per_sk_cta(); + const index_t extra_iters = kargs.tile_partitioner.get_extra_iters(); + int accum_iters = local_iter_end - local_iter_start; + int next_cta = cta_idx + 1; + + while(accum_iters < iter_per_tile) + { + WaitStorePartialDone(kargs, next_cta); + + using BlockType = remove_cvref_t; + AddBlockTile( + accum_block_tile, + LoadPartial( + kargs, next_cta, c_block_tile.get_tile_distribution())); + + accum_iters += iter_per_cta + (next_cta < extra_iters); + ++next_cta; + } + } + + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + EpiloguePipeline{}( + c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); + } } - else + else // Tree Reduction { auto accum_block_tile = c_block_tile; - if(!tile_ended) + index_t tile_local_cta_idx = + kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx); + + for(index_t stride = 1;; stride <<= 1) { - const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile(); - const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta(); - const index_t extra_iters = kargs.tile_partitioner.get_extra_iters(); - int accum_iters = local_iter_end - local_iter_start; - int next_cta = cta_idx + 1; + const index_t partner_cta_idx = cta_idx + stride; + const index_t partner_start_iter = + kargs.tile_partitioner.get_start_iter(partner_cta_idx); + bool partner_in_tile = partner_start_iter < tile_iter_end; - while(accum_iters < iter_per_tile) + // If the partner of the workgroup who started the tile is not in this tile, + // then the work for this tile 
is done and results can be stored in the C + // tensor. + if(tile_started && !partner_in_tile) { - WaitStorePartialDone(kargs, next_cta); + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + EpiloguePipeline{}( + c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); + break; + } - using BlockType = remove_cvref_t; - AddBlockTile( - accum_block_tile, - LoadPartial( - kargs, next_cta, c_block_tile.get_tile_distribution())); - - accum_iters += iter_per_cta + (next_cta < extra_iters); - ++next_cta; + // It's this workgroup's turn to read from partials. + if(tile_local_cta_idx % (stride << 1) == 0) + { + // If this workgroup's partner is in the tile then it can read from + // partials and accumulate results. + if(partner_in_tile) + { + WaitStorePartialDone(kargs, partner_cta_idx); + using BlockType = remove_cvref_t; + AddBlockTile(accum_block_tile, + LoadPartial( + kargs, + partner_cta_idx, + c_block_tile.get_tile_distribution())); + } + } + // Otherwise, it's this workgroup's turn to write to partials. All + // workgroups, except the workgroup who starts the tile, will write to + // partials. + else + { + StorePartial(kargs, cta_idx, accum_block_tile); + SignalStorePartialDone(kargs, cta_idx); + // Once the workgroup writes to partials, it has no more work to do for + // this tile. + break; } } - - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); - EpiloguePipeline{}( - c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); } } + else + { + static_assert( + "An implementation does not exist for the chosen reduction strategy."); + } // Prepare for next Stream-K loop iteration. iter_start = tile_iter_end; @@ -640,10 +730,10 @@ struct StreamKKernel private: /** - * @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is - * the starting macro tile index in the K dimension for the workgroup. 
- * @return A tuple containing the offsets into the A and B tensors accounting for the layouts - * of A and B. + * @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset + * is the starting macro tile index in the K dimension for the workgroup. + * @return A tuple containing the offsets into the A and B tensors accounting for the + * layouts of A and B. * @note The default case is that A is assumed to be row major and B is assumed to be column * major. */ @@ -688,7 +778,8 @@ struct StreamKKernel } /** - * @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel + * @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the + * kernel * @return The occupancy * @note This function queries the maximum occupancy of the kernel using * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp similarity index 92% rename from include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp rename to include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index 9ab75fbdbf..a6022e8b8e 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept; public: + /** + * @brief Calculates the start iteration for the given cta_idx. + * @param cta_idx The current Stream-K workgroup's index. + * @return index_t The start iteration. + * @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a + * non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something like `blockIdx.x` minus number of DP workgroups. 
+ */ + CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept; + /** * @brief Calculates the start and end iteration given the cta_idx. * @@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept; /** - * @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index. + * @brief Calculates the workgroup's local CTA idx within the given tile. + * + * @param tile_iter_start The starting tile iteration. + * @param cta_idx The Stream-K workgroup index. + * @return index_t The tile local workgroup index in the tile. + */ + CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start, + index_t cta_idx) const noexcept; + + /** + * @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index. * * @param tile_idx The 1D tile index in the C tensor for the workgroup. * @return index_t The corresponding 2D tile index in the C tensor for the workgroup. diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp similarity index 88% rename from include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp rename to include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp index acc1860f1f..1764a1ce83 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp @@ -61,13 +61,24 @@ StreamKTilePartitionerBase::get_flags return sizeof(index_t) * sk_ctas_; } +template +CK_TILE_DEVICE index_t +StreamKTilePartitionerBase::get_start_iter( + index_t cta_idx) const noexcept +{ + // Compute the number of extra iterations done before this CTA. If the cta_idx is less than + // extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. 
Otherwise, + // it is extra_iters. + index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_); + return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me; +} + template CK_TILE_DEVICE void StreamKTilePartitionerBase::get_iter_boundaries( index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept { - index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_); - iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me; + iter = get_start_iter(cta_idx); iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_); } @@ -104,6 +115,24 @@ StreamKTilePartitionerBase::get_local return ck_tile::min(iter_end, tile_iter_end) - tile_iter; } +template +CK_TILE_DEVICE index_t +StreamKTilePartitionerBase::get_tile_local_cta_index( + index_t tile_iter_start, index_t cta_idx) const noexcept +{ + tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_); + + // Compute how many WGs fit before this tile starts assuming each WG does an + // extra_iter + const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1); + // Compute how many WGs fit before this tile starts excluding extra iters + const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_; + // Compute the CTA idx for the CTA that starts this tile + const index_t coop_group_start = + num_extra_iter_ctas < extra_iters_ ? 
num_extra_iter_ctas : num_non_extra_iter_ctas; + return cta_idx - coop_group_start; +} + template CK_TILE_DEVICE auto StreamKTilePartitionerBase::get_output_tile_index( @@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t StreamKTilePartitionerBase::get_workspace_size( index_t acc_element_bytes) const noexcept { - if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction || + ReductionStrategy == StreamKReductionStrategy::TreeReduction) { return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size(); diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt index d8b4ff945f..1390e5ee07 100644 --- a/test/ck_tile/gemm_streamk/CMakeLists.txt +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950") #TODO: support all arches #TODO: current c-shuffle only supports C layout as R add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp) + add_gtest_executable(test_ck_tile_streamk_reduction + ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp + test_gemm_streamk_util.cpp) add_gtest_executable(test_ck_tile_streamk_smoke ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp ${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp diff --git a/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp new file mode 100644 index 0000000000..bcd4583da2 --- /dev/null +++ b/test/ck_tile/gemm_streamk/smoke_tests/test_gemm_streamk_fp16_reduction.cpp @@ -0,0 +1,17 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_common_includes.hpp" + +template +class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK +{ +}; + +#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction + +TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction); + +#include "test_gemm_streamk_reduction_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc new file mode 100644 index 0000000000..66c3e3b5e9 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_reduction_cases.inc @@ -0,0 +1,88 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile; + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = K_Tile * num_cu; + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile; + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = K_Tile * num_cu; + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr 
ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile * 4; + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile * 4; + ck_tile::index_t N = N_Tile; + ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile); + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile * 3; + ck_tile::index_t N = N_Tile * 7; + ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile); + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles) +{ + const ck_tile::index_t num_cu = get_cu_count(); + constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value; + constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value; + constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value; + + ck_tile::index_t M = M_Tile * 3; + ck_tile::index_t N = N_Tile * 7; + ck_tile::index_t K = 
K_Tile * num_cu + (30 * K_Tile); + + this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp index efb7416580..ece313b8aa 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent> >; +using KernelTypesStreamKFp16Reduction = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent + std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>; + using KernelTypesStreamKBf16Persistent = ::testing::Types< std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>, diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 213702551a..540109a999 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); + auto kargs = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = 
workspace_data.GetDeviceBuffer(); if(!Kernel::IsSupportedArgument(kargs)) { @@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test using namespace ck_tile::literals; - if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) - { - throw std::runtime_error("Reduction Strategy is current unsupported!\n"); - } - auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, @@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test stride_B, stride_C}; - ck_tile::index_t num_accumulations_per_tile = - invoke_streamk( + ck_tile::index_t num_accumulations_per_tile; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic) + { + num_accumulations_per_tile = invoke_streamk( args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); + } + else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + num_accumulations_per_tile = + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); + } + else + { + num_accumulations_per_tile = + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}); + } c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp index dd74efc27a..637f71c04f 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp @@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings) } } +TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK) +{ + /* + The StreamKTilePartitionerBaseConfigSKOnlyLargeK has the following form: + - tiles in the C tensor: 2 + - iters_per_tile: 5 + - grid: 5 + - dp_tiles: 0 + - sk_tiles: 2 + - iters_per_sk_cta: 2 + - extra_iters: 0 + + The tiles with iters are as follows: + + tile_idx: __________0_________|_________1_________| + tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 
| 6 | 7 | 8 | 9 | + | | | | | | | | | | | + <---------------SK Tiles--------------->| + + From the above configuration, we get the following: + - SK CTA 0: tile_iter_start is 0 with local CTA index of 0 in tile 0 + - SK CTA 1: tile_iter_start is 0 with local CTA index of 1 in tile 0 + - SK CTA 2: tile_iter_start is 0 with local CTA index of 2 in tile 0 + - SK CTA 2: tile_iter_start is 5 with local CTA index of 0 in tile 1 + - SK CTA 3: tile_iter_start is 5 with local CTA index of 1 in tile 1 + - SK CTA 4: tile_iter_start is 5 with local CTA index of 2 in tile 1 + */ + + // Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test + std::vector> sk_triplets{ + {0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}}; + + for(const auto& triplet : sk_triplets) + { + const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet; + test_get_tile_local_cta_idx( + tile_iter_start, cta_idx, tile_local_cta_idx); + } +} + +TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK) +{ + /* + The StreamKTilePartitionerBaseConfigDP2TileSK has the following form: + - tiles in the C tensor: 7 + - iters_per_tile: 3 + - grid: 3 + - dp_tiles: 3 + - sk_tiles: 4 + - iters_per_sk_cta: 2 + - extra_iters: 2 + + The tiles with iters are as follows: + + tile_idx: ____0___|___1___|___2___|___3___|___4___|____5____|____6____| + tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | + | | | | | | | | | | | | | | | + |<-------DP Tiles------>|<------------SK Tiles------------->| + + From the above configuration, we get the following: + - SK CTA 0: tile_iter_start is 6 with local CTA index of 0 in tile 3 + - SK CTA 0: tile_iter_start is 8 with local CTA index of 0 in tile 4 + - SK CTA 1: tile_iter_start is 8 with local CTA index of 1 in tile 4 + - SK CTA 1: tile_iter_start is 10 with local CTA index of 0 in tile 5 + - SK CTA 2: tile_iter_start is 12 with local CTA index of 0 in tile 6 + */ + + // Now we create a vector of 
triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test + std::vector> sk_triplets{ + {6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}}; + + for(const auto& triplet : sk_triplets) + { + const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet; + test_get_tile_local_cta_idx( + tile_iter_start, cta_idx, tile_local_cta_idx); + } +} + // Persistent TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly) { diff --git a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp index 0bb0940651..3daec049a7 100644 --- a/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp +++ b/test/ck_tile/gemm_streamk/test_streamk_tile_partitioner_common.hpp @@ -4,6 +4,7 @@ #include "ck_tile/host.hpp" #include "ck_tile/ops/gemm.hpp" #include "gtest/gtest.h" +#include enum StreamKTilePartitionerBaseMethodId { @@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId GET_TILE_BOUNDARIES, GET_TILE_INDEX, GET_ITER_BOUNDARIES, - GET_OUTPUT_TILE_INDEX + GET_OUTPUT_TILE_INDEX, + GET_TILE_LOCAL_CTA_INDEX }; // Base kernel wrapper class to facilitate testing class device functions. 
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized +struct KernelWrapperSpecialized + : public KernelWrapper +{ + + using Base = KernelWrapper; + + CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs) + { + ck_tile::index_t tile_local_cta_index = + kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2); + *(static_cast(kargs.result1)) = tile_local_cta_index; + } +}; + struct StreamKTilePartitionerBaseExpected { ck_tile::index_t sk_tiles_; @@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas ck_tile::sequence>; }; +struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig +{ + static constexpr ck_tile::index_t M = 8; + static constexpr ck_tile::index_t N = 2; + static constexpr ck_tile::index_t K = 10; + static constexpr ck_tile::index_t GRID = 5; + + static constexpr ck_tile::index_t M_TILE = 4; + static constexpr ck_tile::index_t N_TILE = 2; + static constexpr ck_tile::index_t K_TILE = 2; + + using GemmShape = ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; +}; + struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig { @@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx, EXPECT_EQ(in, in_expected); }; +template +void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start, + ck_tile::index_t cta_idx, + ck_tile::index_t expected_tile_local_cta_idx) +{ + // Types + using TilePartitioner = ck_tile::StreamKTilePartitionerBase; + using Kernel = + KernelWrapperSpecialized; + + // Test parameters + ck_tile::StreamKTilePartitionerBase tile_partitioner{ + Config::M, Config::N, Config::K, Config::GRID}; + ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t)); + + // Launch kernel + auto kargs = Kernel::MakeKernelArgs(tile_iter_start, + cta_idx, + Config::UNUSED, + tile_local_cta_idx_dev.GetDeviceBuffer(), + nullptr, + tile_partitioner); + 
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1}, + ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs)); + + // Validate results + ck_tile::index_t tile_local_cta_idx; + tile_local_cta_idx_dev.FromDevice(&tile_local_cta_idx); + EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx); +} + // Configs for TilePartitioner Child structs struct StreamKTilePartitionerV2PersistentExpected {