[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)

* CK Tile Stream-K Tree Reduction

This change adds the first implementation of the Stream-K tree reduction
strategy into CK Tile. The tree reduction reduces the the number of
steps for accumulating results for a tile from O(N) to O(logN) where N
is the number of workgroups contributing to a C tile.

Additionally, in the original non-atomic reduction strategy, atomics
were used to set the flags buffer and to read from the flags buffer.
Howeover, through investigation with the tree reduciton, atomics with
default (relaxed) semantics were not enough to guarantee workgroups
would not read stale data, leading to incorrect results. Stronger
acquire/release memory orderings are too expensive. So, this change
also eliminates the use of atomics for setting the flags. Instead, we
leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby
avoiding the use of atomics.

Prelimiary tests were also added for the normal reduction and tree
reduction. More will be added in a future PR via tile engine.

* Move Stream-K kernel files to a subdirectory

* Cleanup Code Style & Handle Unsupported Reductions

This change makes the following small changes:
- Add an explicit else block for unimplemented reduction strategies
- Clarify type of sk_flags_ptr via auto*
- Add description for extra_iters_before_me variable

* Run new copyright script on new files
This commit is contained in:
Emily Martins
2025-12-14 14:49:49 -07:00
committed by GitHub
parent 9ac51aa0f4
commit 22b945e06e
13 changed files with 524 additions and 70 deletions

View File

@@ -8,7 +8,8 @@
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
{
Atomic = 0u,
Reduction = 1u
Atomic = 0u,
Reduction = 1u,
TreeReduction = 2u
};
} // namespace ck_tile

View File

@@ -33,9 +33,10 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core/arch/arch.hpp"
namespace ck_tile {
template <typename CompilerTarget, typename Enabler = void>
struct StreamKCoherency
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::coherence_default;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX942,
core::arch::amdgcn_target_id::GFX950>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::SYSTEM_NT0;
};
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
core::arch::enable_if_target_id_t<CompilerTarget,
core::arch::amdgcn_target_id::GFX908,
core::arch::amdgcn_target_id::GFX90A>>
{
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
amd_buffer_coherence_enum::glc_slc;
};
} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "streamk_gemm_coherency.hpp"
namespace ck_tile {
@@ -318,37 +319,58 @@ struct StreamKKernel
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
* CTA index.
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the approproriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
* set by the given CTA index.
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// approproriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates the
* output block tile with accumulated values.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
@@ -370,8 +392,8 @@ struct StreamKKernel
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for loading
* the partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
@@ -405,8 +427,8 @@ struct StreamKKernel
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing the
* partial block tile.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
@@ -420,7 +442,10 @@ struct StreamKKernel
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
@@ -431,8 +456,11 @@ struct StreamKKernel
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
@@ -483,7 +511,8 @@ struct StreamKKernel
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
@@ -528,46 +557,107 @@ struct StreamKKernel
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
if constexpr(TilePartitioner::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile =
kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta =
kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
else // Tree Reduction
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
index_t tile_local_cta_idx =
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
for(index_t stride = 1;; stride <<= 1)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
const index_t partner_cta_idx = cta_idx + stride;
const index_t partner_start_iter =
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
bool partner_in_tile = partner_start_iter < tile_iter_end;
while(accum_iters < iter_per_tile)
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C
// tensor.
if(tile_started && !partner_in_tile)
{
WaitStorePartialDone(kargs, next_cta);
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;
}
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
// It's this workgroup's turn to read from partials.
if(tile_local_cta_idx % (stride << 1) == 0)
{
// If this workgroup's partner is in the tile then it can read from
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
// workgroups, except the workgroup who starts the tile, will write to
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
{
static_assert(
"An implementation does not exist for the chosen reduction strategy.");
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
@@ -640,10 +730,10 @@ struct StreamKKernel
private:
/**
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
* the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
* of A and B.
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
* is the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the
* layouts of A and B.
* @note The default case is that A is assumed to be row major and B is assumed to be column
* major.
*/
@@ -688,7 +778,8 @@ struct StreamKKernel
}
/**
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
* kernel
* @return The occupancy
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.

View File

@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start iteration for the given the cta_idx.
* @param cta_idx The current Stream-K workgroup's index.
* @return index_t The start iteration.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
* @brief Calculates the workgroup's local CTA idx within the given tile.
*
* @param tile_iter_start The starting tile iteration.
* @param cta_idx The Stream-K workgroup index.
* @return index_t The tile local workgroup index in the tile.
*/
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
index_t cta_idx) const noexcept;
/**
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.

View File

@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
index_t cta_idx) const noexcept
{
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
// it is extra_iters.
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
iter = get_start_iter(cta_idx);
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
index_t tile_iter_start, index_t cta_idx) const noexcept
{
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
// Compute how many WGs fit before this tile starts assuming each WG does an
// extra_iter
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
// Compute how many WGs fit before this tile starts excluding extra iters
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
// Compute the CTA idx for the CTA that starts this tile
const index_t coop_group_start =
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
return cta_idx - coop_group_start;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();