diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp
index b986f2cb37..ac83babeb6 100644
--- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp
+++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp
@@ -75,6 +75,8 @@ struct StreamKKernel
     using TilePartitioner  = TilePartitioner_;
     using GemmPipeline     = GemmPipeline_;
     using EpiloguePipeline = EpiloguePipeline_;
+    using WarpGemm         = typename GemmPipeline::BlockGemm::WarpGemm;
+    using BlockGemmShape   = typename GemmPipeline::BlockGemmShape;
 
     static_assert(
         TilePartitioner::PERSISTENT == PersistentDP,
@@ -156,7 +158,7 @@ struct StreamKKernel
     {
         // clang-format off
         using P_ = GemmPipeline;
-        using WarpTile = typename P_::BlockGemmShape::WarpTile;
+        using WarpTile = typename BlockGemmShape::WarpTile;
         return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
                       concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
@@ -407,16 +409,84 @@ struct StreamKKernel
             static_cast<AccDataType*>(partial_buffer_ptr),
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             make_tuple(TilePartitioner::NPerBlock, 1),
-            number<EpiloguePipeline::GetVectorSizeC()>{},
+            number<GetVectorSizePartials()>{},
             number<1>{});
 
         auto partial_tile_window = make_tile_window(
             partial_tensor_view,
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             {0, 0},
-            c_block_tile_dist);
+            MakePartialsDistribution());
 
-        return load_tile(partial_tile_window);
+        auto partials_tile = load_tile(partial_tile_window);
+
+        // Since the partials distribution is not the same as the C block distribution, we must
+        // describe the contents of the partials tile with the C block distribution.
+        // Note: The data assigned to threads does not change between distributions.
+        auto partials_tile_with_c_distr = make_static_distributed_tensor<AccDataType>(
+            c_block_tile_dist, partials_tile.get_thread_buffer());
+
+        return partials_tile_with_c_distr;
+    }
+
+    /**
+     * @brief Returns the vector size to be used for reading from and writing to partials.
+     * @return The vector size.
+     */
+    CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
+    {
+        // We use kCM1PerLane from the C register layout of the warp GEMM, which corresponds
+        // to the maximum vector width.
+        return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
+    }
+
+    /**
+     * @brief Returns the distribution used for reading from and writing to partials.
+     * @return The distribution.
+     * @note This will result in optimized reads from and writes to partials when C is row major.
+     * Additional functionality should be added to ensure optimized accesses to partials when C is
+     * column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
+     * current limitation.
+     */
+    CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
+    {
+        // Create the encoding to describe waves within a block
+        constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
+        constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
+
+        constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
+        constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
+
+        constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
+            sequence<>,
+            tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
+            tuple<sequence<1, 2>>,
+            tuple<sequence<1, 1>>,
+            sequence<1, 2>,
+            sequence<0, 0>>{};
+
+        // Create the encoding to describe threads within a wave
+        constexpr index_t vector_size         = GetVectorSizePartials();
+        constexpr index_t m_warp_repeat       = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
+        constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
+        constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
+
+        // This inner encoding ensures that contiguous threads perform vectorized writes along the
+        // same row in C.
+        constexpr auto partials_inner_dstr_encoding =
+            tile_distribution_encoding<sequence<>,
+                                       tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
+                                             sequence<warp_tile_n_threads, vector_size>>,
+                                       tuple<sequence<1, 2>>,
+                                       tuple<sequence<1, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<0, 1>>{};
+
+        // Combine the outer and inner encoding
+        constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
+            partials_outer_dstr_encoding, partials_inner_dstr_encoding);
+
+        return make_static_tile_distribution(partials_dstr_encode);
     }
 
     /**
@@ -446,14 +516,22 @@ struct StreamKKernel
             static_cast<AccDataType*>(partial_buffer_ptr),
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             make_tuple(TilePartitioner::NPerBlock, 1),
-            number<EpiloguePipeline::GetVectorSizeC()>{},
+            number<GetVectorSizePartials()>{},
             number<1>{});
 
         auto partial_tile_window = make_tile_window(
             partial_tensor_view,
             make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {0, 0});
-        store_tile(partial_tile_window, c_block_tile);
+            {0, 0},
+            MakePartialsDistribution());
+
+        // Since the C block distribution is not the same as the partials distribution, we must
+        // describe the contents of the c_block_tile with the partials distribution.
+        // Note: The data assigned to threads does not change between distributions.
+        auto c_with_partials_dist = make_static_distributed_tensor<AccDataType>(
+            MakePartialsDistribution(), c_block_tile.get_thread_buffer());
+
+        store_tile(partial_tile_window, c_with_partials_dist);
 
         // Wait for all vector stores for this wavefront to complete
         s_waitcnt();
         // Wait for all wavefronts in this workgroup to arrive here before continuing
@@ -591,16 +669,19 @@ struct StreamKKernel
             }
             else // Tree Reduction
             {
-                auto accum_block_tile = c_block_tile;
-                index_t tile_local_cta_idx =
-                    kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
+                auto accum_block_tile      = c_block_tile;
+                index_t tile_local_cta_idx = amd_wave_read_first_lane(
+                    kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx));
 
-                for(index_t stride = 1;; stride <<= 1)
+                index_t stride = amd_wave_read_first_lane(1);
+
+                for(;; stride <<= 1)
                 {
-                    const index_t partner_cta_idx = cta_idx + stride;
-                    const index_t partner_start_iter =
-                        kargs.tile_partitioner.get_start_iter(partner_cta_idx);
-                    bool partner_in_tile = partner_start_iter < tile_iter_end;
+                    const index_t partner_cta_idx = amd_wave_read_first_lane(cta_idx + stride);
+                    const index_t partner_start_iter = amd_wave_read_first_lane(
+                        kargs.tile_partitioner.get_start_iter(partner_cta_idx));
+                    bool partner_in_tile =
+                        amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
 
                     // If the partner of the workgroup that started the tile is not in this tile,
                     // then the work for this tile is done and results can be stored in the C
diff --git a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py
index a485f64ade..166ea940c3 100644
--- a/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py
+++ b/test/ck_tile/gemm_streamk_tile_engine/generate_configs.py
@@ -21,9 +21,9 @@ class TileConfig:
     warp_m: List[int] = field(default_factory=lambda: [2])
     warp_n: List[int] = field(default_factory=lambda: [2])
     warp_k: List[int] = field(default_factory=lambda: [1])
-    warp_tile_m: List[int] = field(default_factory=lambda: [32])
-    warp_tile_n: List[int] = field(default_factory=lambda: [32])
-    warp_tile_k: List[int] = field(default_factory=lambda: [16])
+    warp_tile_m: List[int] = field(default_factory=lambda: [16, 32])
+    warp_tile_n: List[int] = field(default_factory=lambda: [16, 32])
+    warp_tile_k: List[int] = field(default_factory=lambda: [8, 16, 32])
 
     def to_dict(self) -> Dict:
         return {k: {"values": v} for k, v in asdict(self).items()}
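
Note on MakePartialsDistribution: the sketch below (standalone, not part of the PR) models the lane-to-element mapping the inner encoding produces for a single warp tile. The 64-lane wavefront, 32x32 warp tile, and vector size of 4 are assumed example values standing in for get_warp_size(), WarpGemm::kM / WarpGemm::kN, and GetVectorSizePartials().

    // Standalone model of the inner partials encoding's lane -> (row, cols) mapping.
    #include <cstdio>

    int main()
    {
        constexpr int warp_size   = 64; // assumed wavefront size
        constexpr int kM          = 32; // assumed warp tile M
        constexpr int kN          = 32; // assumed warp tile N
        constexpr int vector_size = 4;  // stand-in for kCM1PerLane

        constexpr int n_threads     = kN / vector_size;      // lanes sharing one row: 8
        constexpr int m_threads     = warp_size / n_threads; // rows covered at once: 8
        constexpr int m_warp_repeat = kM / m_threads;        // stand-in for kCM0PerLane: 4

        std::printf("each lane's columns repeat in %d rows spaced %d apart\n",
                    m_warp_repeat, m_threads);
        for(int lane = 0; lane < 2 * n_threads; ++lane)
        {
            // Contiguous lanes keep the same row and take adjacent vector_size-wide
            // column chunks, so a wavefront's stores are both vectorized and coalesced.
            const int row = lane / n_threads;
            const int col = (lane % n_threads) * vector_size;
            std::printf("lane %2d -> row %2d, cols [%2d..%2d]\n",
                        lane, row, col, col + vector_size - 1);
        }
        return 0;
    }

With these numbers, lanes 0-7 cover row 0 in 4-wide chunks and lanes 8-15 cover row 1, which is the property the "contiguous threads perform vectorized writes along the same row in C" comment claims; kCM1PerLane is the natural vector width because it is the size of a lane's contiguous chunk in the warp GEMM C register layout.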
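
Note on the tree-reduction hunk: amd_wave_read_first_lane broadcasts the first active lane's value to the whole wavefront, so the wrapped indices stay wave-uniform and can live in scalar registers. The sketch below (standalone, not part of the PR) models only the partner walk and its termination condition; get_start_iter here is a hypothetical stand-in that gives every CTA the same number of iterations, and the inter-CTA synchronization the kernel needs before reading a partner's partials is not modeled.

    #include <cstdio>

    int main()
    {
        constexpr int iters_per_cta = 4;  // assumed uniform split across CTAs
        constexpr int tile_iter_end = 32; // current tile spans iterations [0, 32)
        auto get_start_iter = [](int cta) { return cta * iters_per_cta; };

        const int cta_idx = 0; // the CTA that owns the tile's first iteration
        for(int stride = 1;; stride <<= 1)
        {
            const int  partner_cta_idx    = cta_idx + stride;
            const int  partner_start_iter = get_start_iter(partner_cta_idx);
            const bool partner_in_tile    = partner_start_iter < tile_iter_end;

            // Once the partner starts outside this tile, every contribution has
            // been folded in and the accumulated result can be written to C.
            if(!partner_in_tile)
            {
                std::printf("stride %d: partner %d is outside the tile -> store C\n",
                            stride, partner_cta_idx);
                break;
            }
            std::printf("stride %d: fold in partials from CTA %d\n",
                        stride, partner_cta_idx);
        }
        return 0;
    }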