[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)

* CK Tile Stream-K Tree Reduction

This change adds the first implementation of the Stream-K tree reduction
strategy into CK Tile. The tree reduction reduces the number of
steps for accumulating results for a tile from O(N) to O(logN) where N
is the number of workgroups contributing to a C tile.

Additionally, in the original non-atomic reduction strategy, atomics
were used to set the flags buffer and to read from the flags buffer.
However, through investigation with the tree reduction, atomics with
default (relaxed) semantics were not enough to guarantee workgroups
would not read stale data, leading to incorrect results. Stronger
acquire/release memory orderings are too expensive. So, this change
also eliminates the use of atomics for setting the flags. Instead, we
leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby
avoiding the use of atomics.

Preliminary tests were also added for the normal reduction and tree
reduction. More will be added in a future PR via tile engine.

* Move Stream-K kernel files to a subdirectory

* Cleanup Code Style & Handle Unsupported Reductions

This change makes the following small changes:
- Add an explicit else block for unimplemented reduction strategies
- Clarify type of sk_flags_ptr via auto*
- Add description for extra_iters_before_me variable

* Run new copyright script on new files

[ROCm/composable_kernel commit: 22b945e06e]
This commit is contained in:
Emily Martins
2025-12-14 14:49:49 -07:00
committed by GitHub
parent a3270d2eb0
commit eeb78c46a4
13 changed files with 524 additions and 70 deletions

View File

@@ -8,7 +8,8 @@
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
{
Atomic = 0u,
Reduction = 1u
Atomic = 0u,
Reduction = 1u,
TreeReduction = 2u
};
} // namespace ck_tile

View File

@@ -33,9 +33,10 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core/arch/arch.hpp"

namespace ck_tile {

/// @brief Selects the buffer-coherence setting the Stream-K kernel uses for its
///        partial-results / flags traffic, keyed on the compile-time target.
///
/// Primary template: any architecture without a dedicated specialization below
/// falls back to the default coherence.
template <typename CompilerTarget, typename Enabler = void>
struct StreamKCoherency
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::coherence_default;
};

/// @brief Specialization for GFX942 / GFX950 targets: use SYSTEM_NT0 coherence.
/// NOTE(review): presumably chosen so buffer accesses bypass the cache level(s)
/// that could otherwise serve stale partial results — confirm against the
/// corresponding ISA documentation.
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
                        core::arch::enable_if_target_id_t<CompilerTarget,
                                                          core::arch::amdgcn_target_id::GFX942,
                                                          core::arch::amdgcn_target_id::GFX950>>
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::SYSTEM_NT0;
};

/// @brief Specialization for GFX908 / GFX90A targets: use the GLC+SLC bits.
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
                        core::arch::enable_if_target_id_t<CompilerTarget,
                                                          core::arch::amdgcn_target_id::GFX908,
                                                          core::arch::amdgcn_target_id::GFX90A>>
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::glc_slc;
};

} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "streamk_gemm_coherency.hpp"
namespace ck_tile {
@@ -318,37 +319,58 @@ struct StreamKKernel
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
* CTA index.
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the appropriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
* set by the given CTA index.
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// appropriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates the
* output block tile with accumulated values.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
@@ -370,8 +392,8 @@ struct StreamKKernel
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for loading
* the partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
@@ -405,8 +427,8 @@ struct StreamKKernel
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing the
* partial block tile.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
@@ -420,7 +442,10 @@ struct StreamKKernel
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
@@ -431,8 +456,11 @@ struct StreamKKernel
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
@@ -483,7 +511,8 @@ struct StreamKKernel
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
@@ -528,46 +557,107 @@ struct StreamKKernel
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
if constexpr(TilePartitioner::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile =
kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta =
kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
else // Tree Reduction
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
index_t tile_local_cta_idx =
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
for(index_t stride = 1;; stride <<= 1)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
const index_t partner_cta_idx = cta_idx + stride;
const index_t partner_start_iter =
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
bool partner_in_tile = partner_start_iter < tile_iter_end;
while(accum_iters < iter_per_tile)
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C
// tensor.
if(tile_started && !partner_in_tile)
{
WaitStorePartialDone(kargs, next_cta);
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;
}
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
// It's this workgroup's turn to read from partials.
if(tile_local_cta_idx % (stride << 1) == 0)
{
// If this workgroup's partner is in the tile then it can read from
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
// workgroups, except the workgroup who starts the tile, will write to
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
{
static_assert(
"An implementation does not exist for the chosen reduction strategy.");
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
@@ -640,10 +730,10 @@ struct StreamKKernel
private:
/**
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
* the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
* of A and B.
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
* is the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the
* layouts of A and B.
* @note The default case is that A is assumed to be row major and B is assumed to be column
* major.
*/
@@ -688,7 +778,8 @@ struct StreamKKernel
}
/**
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
* kernel
* @return The occupancy
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.

View File

@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start iteration for the given cta_idx.
* @param cta_idx The current Stream-K workgroup's index.
* @return index_t The start iteration.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
* @brief Calculates the workgroup's local CTA idx within the given tile.
*
* @param tile_iter_start The starting tile iteration.
* @param cta_idx The Stream-K workgroup index.
* @return index_t The tile local workgroup index in the tile.
*/
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
index_t cta_idx) const noexcept;
/**
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.

View File

@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
    index_t cta_idx) const noexcept
{
    // The first `extra_iters_` Stream-K CTAs each carry one additional iteration.
    // A CTA whose index is below `extra_iters_` is therefore preceded by exactly
    // `cta_idx` extra iterations; every later CTA is preceded by all of them.
    const index_t preceding_extra_iters =
        (cta_idx < extra_iters_) ? cta_idx : extra_iters_;
    // Offset past the data-parallel iterations, then past this CTA's predecessors.
    const index_t base_iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_;
    return base_iter + preceding_extra_iters;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
iter = get_start_iter(cta_idx);
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
    index_t tile_iter_start, index_t cta_idx) const noexcept
{
    // Shift into Stream-K iteration space: the first dp_tiles_ * iters_per_tile_
    // iterations belong to the data-parallel tiles and carry no Stream-K CTAs.
    tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
    // How many whole CTAs fit before this tile starts if every one of them
    // performs an extra iteration (iters_per_sk_cta_ + 1 iterations each).
    const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
    // How many whole CTAs fit before this tile starts once the extra iterations
    // have been exhausted by earlier CTAs.
    const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
    // Index of the CTA that begins this tile: if the tile starts within the run
    // of CTAs still carrying extra iterations, the first estimate is valid;
    // otherwise use the estimate that discounts the extra iterations.
    const index_t coop_group_start =
        num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
    // This CTA's position within the cooperative group working on the tile.
    return cta_idx - coop_group_start;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();

View File

@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Typed-test fixture for the FP16 Stream-K reduction suite; all behavior is
// inherited from the shared TestCkTileStreamK base.
template <typename Tuple>
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
{
};

// The shared cases file below expects TEST_SUITE_NAME to name the fixture.
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction

TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);

// Stamp out the common reduction test cases for this suite.
#include "test_gemm_streamk_reduction_cases.inc"

#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
// Single output tile; K spans num_cu macro tiles so every CU contributes
// iterations to that one tile. Tree-reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
{
    const ck_tile::index_t num_cu = get_cu_count();
    // Macro-tile extents are carried in the typed-test parameter tuple.
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile;
    ck_tile::index_t N = N_Tile;
    ck_tile::index_t K = K_Tile * num_cu;
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}

// Same single-tile shape as above, but with the linear reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
    const ck_tile::index_t num_cu = get_cu_count();
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile;
    ck_tile::index_t N = N_Tile;
    ck_tile::index_t K = K_Tile * num_cu;
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}

// Four output tiles along M; K is padded past num_cu macro tiles so Stream-K
// CTAs straddle tile boundaries. Tree-reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
{
    const ck_tile::index_t num_cu = get_cu_count();
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile * 4;
    ck_tile::index_t N = N_Tile;
    ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}

// Same four-tile shape as above, but with the linear reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
    const ck_tile::index_t num_cu = get_cu_count();
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile * 4;
    ck_tile::index_t N = N_Tile;
    ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}

// 3 x 7 = 21 output tiles with a padded K dimension. Tree-reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
{
    const ck_tile::index_t num_cu = get_cu_count();
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile * 3;
    ck_tile::index_t N = N_Tile * 7;
    ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
}

// Same 21-tile shape as above, but with the linear reduction strategy.
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
    const ck_tile::index_t num_cu = get_cu_count();
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
    ck_tile::index_t M = M_Tile * 3;
    ck_tile::index_t N = N_Tile * 7;
    ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
    this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
}

View File

@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;
using KernelTypesStreamKFp16Reduction = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,

View File

@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
using namespace ck_tile::literals;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
}
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
stride_B,
stride_C};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
ck_tile::index_t num_accumulations_per_tile;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
}
}
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
{
    /*
        StreamKTilePartitionerBaseConfigSKOnlyLargeK describes this layout:
        - tiles in the C tensor: 2
        - iters_per_tile: 5
        - grid: 5
        - dp_tiles: 0
        - sk_tiles: 2
        - iters_per_sk_cta: 2
        - extra_iters: 0

        Iteration layout across the tiles:
        tile_idx:  __________0_________|_________1_________|
        tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
                  |   |   |   |   |   |   |   |   |   |   |
                  <---------------SK Tiles--------------->|

        Expected local CTA indices:
        - SK CTA 0: tile_iter_start 0 -> local index 0 in tile 0
        - SK CTA 1: tile_iter_start 0 -> local index 1 in tile 0
        - SK CTA 2: tile_iter_start 0 -> local index 2 in tile 0
        - SK CTA 2: tile_iter_start 5 -> local index 0 in tile 1
        - SK CTA 3: tile_iter_start 5 -> local index 1 in tile 1
        - SK CTA 4: tile_iter_start 5 -> local index 2 in tile 1
    */
    // Each case is (tile_iter_start, cta_idx, expected tile-local CTA index).
    const std::array<std::array<ck_tile::index_t, 3>, 6> cases{
        {{0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}}};
    for(const auto& c : cases)
    {
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
            c[0], c[1], c[2]);
    }
}
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
{
    /*
        StreamKTilePartitionerBaseConfigDP2TileSK describes this layout:
        - tiles in the C tensor: 7
        - iters_per_tile: 3
        - grid: 3
        - dp_tiles: 3
        - sk_tiles: 4
        - iters_per_sk_cta: 2
        - extra_iters: 2

        Iteration layout across the tiles:
        tile_idx:  ____0___|___1___|___2___|___3___|___4___|____5____|____6____|
        tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
                  |   |   |   |   |   |   |   |   |   |   |    |    |    |    |
                  |<-------DP Tiles------>|<------------SK Tiles------------->|

        Expected local CTA indices:
        - SK CTA 0: tile_iter_start  6 -> local index 0 in tile 3
        - SK CTA 0: tile_iter_start  8 -> local index 0 in tile 4
        - SK CTA 1: tile_iter_start  8 -> local index 1 in tile 4
        - SK CTA 1: tile_iter_start 10 -> local index 0 in tile 5
        - SK CTA 2: tile_iter_start 12 -> local index 0 in tile 6
    */
    // Each case is (tile_iter_start, cta_idx, expected tile-local CTA index).
    const std::array<std::array<ck_tile::index_t, 3>, 5> cases{
        {{6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}}};
    for(const auto& c : cases)
    {
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
            c[0], c[1], c[2]);
    }
}
// Persistent
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
{

View File

@@ -4,6 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gtest/gtest.h"
#include <array>
enum StreamKTilePartitionerBaseMethodId
{
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
GET_TILE_BOUNDARIES,
GET_TILE_INDEX,
GET_ITER_BOUNDARIES,
GET_OUTPUT_TILE_INDEX
GET_OUTPUT_TILE_INDEX,
GET_TILE_LOCAL_CTA_INDEX
};
// Base kernel wrapper class to facilitate testing class device functions.
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
}
};
// Device-side wrapper that exercises get_tile_local_cta_index so host tests can
// validate the partitioner's device function in isolation.
template <typename TilePartitioner>
struct KernelWrapperSpecialized<TilePartitioner,
                                StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
    : public KernelWrapper<TilePartitioner>
{
    using Base = KernelWrapper<TilePartitioner>;

    // Calls the method under test with arg1 = tile_iter_start and arg2 = cta_idx,
    // then writes the result to the result1 output buffer for host readback.
    CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
    {
        ck_tile::index_t tile_local_cta_index =
            kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);
        *(static_cast<ck_tile::index_t*>(kargs.result1)) = tile_local_cta_index;
    }
};
struct StreamKTilePartitionerBaseExpected
{
ck_tile::index_t sk_tiles_;
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
// SK-only configuration with a K dimension larger than the grid can cover in one
// pass: 5 CTAs over M/M_TILE * N/N_TILE = 2 * 1 = 2 output tiles, presumably
// giving K/K_TILE = 5 iterations per tile (see the GetTileLocalCtaIndex tests
// for the derived partitioning — confirm against the partitioner's math).
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
{
    static constexpr ck_tile::index_t M      = 8;
    static constexpr ck_tile::index_t N      = 2;
    static constexpr ck_tile::index_t K      = 10;
    static constexpr ck_tile::index_t GRID   = 5;
    static constexpr ck_tile::index_t M_TILE = 4;
    static constexpr ck_tile::index_t N_TILE = 2;
    static constexpr ck_tile::index_t K_TILE = 2;
    // Only the macro-tile extents matter for these tests; the warp/thread tile
    // sequences are unused placeholders.
    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
{
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
EXPECT_EQ(in, in_expected);
};
// Launches the GET_TILE_LOCAL_CTA_INDEX kernel wrapper for the given Config and
// checks that the device-computed tile-local CTA index matches the expectation.
template <typename Config>
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
                                 ck_tile::index_t cta_idx,
                                 ck_tile::index_t expected_tile_local_cta_idx)
{
    // Types under test.
    using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
    using Kernel =
        KernelWrapperSpecialized<TilePartitioner,
                                 StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;

    // Build the partitioner and a one-element device buffer for the result.
    TilePartitioner partitioner{Config::M, Config::N, Config::K, Config::GRID};
    ck_tile::DeviceMem result_dev(sizeof(ck_tile::index_t));

    // Launch a single workgroup running the wrapper kernel.
    auto kernel_args = Kernel::MakeKernelArgs(tile_iter_start,
                                              cta_idx,
                                              Config::UNUSED,
                                              result_dev.GetDeviceBuffer(),
                                              nullptr,
                                              partitioner);
    ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
                           ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kernel_args));

    // Copy the device result back and validate it.
    ck_tile::index_t host_result;
    result_dev.FromDevice(&host_result);
    EXPECT_EQ(host_result, expected_tile_local_cta_idx);
}
// Configs for TilePartitioner Child structs
struct StreamKTilePartitionerV2PersistentExpected
{