Merge commit '22b945e06ea4b4de188d7ff4ec7ae4bf127be9f9' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-14 22:12:40 +00:00
parent ca5fb0a3b7
commit ea731b5f29
13 changed files with 524 additions and 70 deletions

View File

@@ -8,7 +8,8 @@
namespace ck_tile {
enum StreamKReductionStrategy : uint32_t
{
Atomic = 0u,
Reduction = 1u
Atomic = 0u,
Reduction = 1u,
TreeReduction = 2u
};
} // namespace ck_tile

View File

@@ -33,9 +33,10 @@
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"

View File

@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core/arch/arch.hpp"

namespace ck_tile {

/// @brief Selects the buffer coherence mode used for Stream-K partial-result
///        traffic, specialized per compiler target.
/// @note  Targets without a dedicated specialization fall back to the default
///        coherence mode.
template <typename CompilerTarget, typename EnableIf = void>
struct StreamKCoherency
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::coherence_default;
};

/// gfx908 / gfx90a targets use the glc_slc coherence mode.
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
                        core::arch::enable_if_target_id_t<CompilerTarget,
                                                          core::arch::amdgcn_target_id::GFX908,
                                                          core::arch::amdgcn_target_id::GFX90A>>
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::glc_slc;
};

/// gfx942 / gfx950 targets use the SYSTEM_NT0 coherence mode.
template <typename CompilerTarget>
struct StreamKCoherency<CompilerTarget,
                        core::arch::enable_if_target_id_t<CompilerTarget,
                                                          core::arch::amdgcn_target_id::GFX942,
                                                          core::arch::amdgcn_target_id::GFX950>>
{
    static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
        amd_buffer_coherence_enum::SYSTEM_NT0;
};

} // namespace ck_tile

View File

@@ -6,6 +6,7 @@
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
#include "streamk_gemm_coherency.hpp"
namespace ck_tile {
@@ -318,37 +319,58 @@ struct StreamKKernel
* results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the current thread block (CTA).
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
* CTA index.
* @note This function utilizes a scalar store to write to the flags buffer.
*/
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_set(0, 1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t offset = cta_idx * sizeof(index_t);
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the appropriate
// cache level(s) to ensure the write is visible to other workgroups. See the
// appropriate ISA for details about the GLC modifier.
"s_store_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
:
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
: "memory");
}
/**
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
* set by the given CTA index.
* @note This function utilizes a scalar load to read from the flags
* buffer.
*/
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
{
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
workgroup_barrier sk_flags(sk_flags_ptr);
sk_flags.wait_eq(1, cta_idx);
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
index_t result;
index_t offset = cta_idx * sizeof(index_t);
do
{
asm volatile("s_mov_b32 m0, %2\n\t"
// Depending on the architecture, the GLC flag will bypass the
// appropriate cache level(s) to avoid reading stale flags. See the
// appropriate ISA for details about the GLC modifier.
"s_load_dword %0, %1, %2 glc\n\t"
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
: "=s"(result)
: "s"(sk_flags_ptr), "s"(offset)
: "memory");
} while(result != 1);
}
/**
* @brief Adds the values of a block tile to an output block tile.
* @param in_out_block_tile The output block tile to which values are added.
* @param in_block_tile The input block tile whose values are added.
* @note This function iterates over the distributed spans of the block tiles and updates the
* output block tile with accumulated values.
* @note This function iterates over the distributed spans of the block tiles and updates
* the output block tile with accumulated values.
*/
template <typename OAccTile>
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
@@ -370,8 +392,8 @@ struct StreamKKernel
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile_dist The tile distribution for the block.
* @return The loaded partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for loading
* the partial block tile.
* @note This function calculates the buffer pointer and uses the tile distribution for
* loading the partial block tile.
*/
template <typename DataType, typename OAccTileDist>
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
@@ -405,8 +427,8 @@ struct StreamKKernel
* @param kargs Kernel arguments, including the workspace pointer.
* @param cta_idx The index of the thread block (CTA).
* @param c_block_tile The block tile to be stored.
* @note This function calculates the buffer pointer and uses the tile window for storing the
* partial block tile.
* @note This function calculates the buffer pointer and uses the tile window for storing
* the partial block tile.
*/
template <typename OAccTile>
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
@@ -420,7 +442,10 @@ struct StreamKKernel
kargs.tile_partitioner.get_flags_buffer_size() +
cta_idx * c_block_tile_buffer_size;
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
const auto& partial_tensor_view = make_naive_tensor_view<
address_space_enum::global,
memory_operation_enum::set,
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
make_tuple(TilePartitioner::NPerBlock, 1),
@@ -431,8 +456,11 @@ struct StreamKKernel
partial_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{0, 0});
store_tile(partial_tile_window, c_block_tile);
// Wait for all vector stores for this wavefront to complete
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
// Wait for all wavefronts in this workgroup to arrive here before continuing
__builtin_amdgcn_s_barrier();
}
/**
@@ -483,7 +511,8 @@ struct StreamKKernel
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
const auto c_macro_tile_idx =
kargs.tile_partitioner.get_output_tile_index(tile_idx);
@@ -528,46 +557,107 @@ struct StreamKKernel
auto tile_started = iter_start == tile_iter_start;
auto tile_ended = iter_end >= tile_iter_end;
if(!tile_started)
if constexpr(TilePartitioner::ReductionStrategy ==
StreamKReductionStrategy::Reduction)
{
StorePartial(kargs, cta_idx, c_block_tile);
// Ensure device-wide visibility of partial results stored in global memory
// before signaling completion. __threadfence() guarantees that all global
// memory writes by this thread are visible to other threads on the device.
__threadfence(); // send signal when the store is done
SignalStorePartialDone(kargs, cta_idx);
if(!tile_started)
{
StorePartial(kargs, cta_idx, c_block_tile);
SignalStorePartialDone(kargs, cta_idx);
}
else
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
{
const index_t iter_per_tile =
kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta =
kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
while(accum_iters < iter_per_tile)
{
WaitStorePartialDone(kargs, next_cta);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
else // Tree Reduction
{
auto accum_block_tile = c_block_tile;
if(!tile_ended)
index_t tile_local_cta_idx =
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
for(index_t stride = 1;; stride <<= 1)
{
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
int accum_iters = local_iter_end - local_iter_start;
int next_cta = cta_idx + 1;
const index_t partner_cta_idx = cta_idx + stride;
const index_t partner_start_iter =
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
bool partner_in_tile = partner_start_iter < tile_iter_end;
while(accum_iters < iter_per_tile)
// If the partner of the workgroup who started the tile is not in this tile,
// then the work for this tile is done and results can be stored in the C
// tensor.
if(tile_started && !partner_in_tile)
{
WaitStorePartialDone(kargs, next_cta);
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
break;
}
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(
accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs, next_cta, c_block_tile.get_tile_distribution()));
accum_iters += iter_per_cta + (next_cta < extra_iters);
++next_cta;
// It's this workgroup's turn to read from partials.
if(tile_local_cta_idx % (stride << 1) == 0)
{
// If this workgroup's partner is in the tile then it can read from
// partials and accumulate results.
if(partner_in_tile)
{
WaitStorePartialDone(kargs, partner_cta_idx);
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
AddBlockTile(accum_block_tile,
LoadPartial<typename BlockType::DataType>(
kargs,
partner_cta_idx,
c_block_tile.get_tile_distribution()));
}
}
// Otherwise, it's this workgroup's turn to write to partials. All
// workgroups, except the workgroup who starts the tile, will write to
// partials.
else
{
StorePartial(kargs, cta_idx, accum_block_tile);
SignalStorePartialDone(kargs, cta_idx);
// Once the workgroup writes to partials, it has no more work to do for
// this tile.
break;
}
}
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
}
}
else
{
static_assert(
"An implementation does not exist for the chosen reduction strategy.");
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
@@ -640,10 +730,10 @@ struct StreamKKernel
private:
/**
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
* the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
* of A and B.
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
* is the starting macro tile index in the K dimension for the workgroup.
* @return A tuple containing the offsets into the A and B tensors accounting for the
* layouts of A and B.
* @note The default case is that A is assumed to be row major and B is assumed to be column
* major.
*/
@@ -688,7 +778,8 @@ struct StreamKKernel
}
/**
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
* kernel
* @return The occupancy
* @note This function queries the maximum occupancy of the kernel using
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.

View File

@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
public:
/**
* @brief Calculates the start iteration for the given cta_idx.
* @param cta_idx The current Stream-K workgroup's index.
* @return index_t The start iteration.
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
* like `blockIdx.x` minus number of DP workgroups.
*/
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
/**
* @brief Calculates the start and end iteration given the cta_idx.
*
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
/**
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
* @brief Calculates the workgroup's local CTA idx within the given tile.
*
* @param tile_iter_start The starting tile iteration.
* @param cta_idx The Stream-K workgroup index.
* @return index_t The tile local workgroup index in the tile.
*/
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
index_t cta_idx) const noexcept;
/**
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
*
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.

View File

@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
return sizeof(index_t) * sk_ctas_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
    index_t cta_idx) const noexcept
{
    // Each of the first extra_iters_ CTAs performs one additional iteration, so the
    // number of extra iterations issued before this CTA is its index, capped at
    // extra_iters_.
    const index_t preceding_extra = (cta_idx < extra_iters_) ? cta_idx : extra_iters_;
    return total_dp_iters_ + preceding_extra + cta_idx * iters_per_sk_cta_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE void
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
{
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
iter = get_start_iter(cta_idx);
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
}
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
    index_t tile_iter_start, index_t cta_idx) const noexcept
{
    // Rebase the tile's start iteration into the Stream-K region: the DP tiles
    // occupy the first dp_tiles_ * iters_per_tile_ iterations.
    tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
    // Compute how many WGs fit before this tile starts assuming each WG does an
    // extra_iter
    const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
    // Compute how many WGs fit before this tile starts excluding extra iters
    const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
    // Compute the CTA idx for the CTA that starts this tile. If fewer than
    // extra_iters_ CTAs precede the tile, every preceding CTA carried an extra
    // iteration; otherwise use the count computed without extra iterations.
    const index_t coop_group_start =
        num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
    // The tile-local index is this CTA's offset from the CTA that starts the tile.
    return cta_idx - coop_group_start;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_DEVICE auto
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
index_t acc_element_bytes) const noexcept
{
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
{
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();

View File

@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
#TODO: support all arches
#TODO: current c-shuffle only supports C layout as R
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reduction
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
test_gemm_streamk_util.cpp)
add_gtest_executable(test_ck_tile_streamk_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp

View File

@@ -0,0 +1,17 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_common_includes.hpp"
// Typed-test fixture for the fp16 Stream-K reduction-strategy smoke tests;
// inherits all test machinery from TestCkTileStreamK.
template <typename Tuple>
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
{
};
// Bind the shared reduction test cases (.inc file) to this fixture via
// TEST_SUITE_NAME, then clean the macro up to avoid leaking it to other TUs.
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
#include "test_gemm_streamk_reduction_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // Single output tile; K holds one macro-tile per CU so the whole grid
    // cooperates on that tile using tree reduction.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile, kNTile, kKTile * cu_count, ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // Single output tile; K holds one macro-tile per CU so the whole grid
    // cooperates on that tile using the flat reduction strategy.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile, kNTile, kKTile * cu_count, ck_tile::StreamKReductionStrategy::Reduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // Four output tiles along M; K is padded 25 macro-tiles past one-per-CU so
    // iterations do not divide evenly among the workgroups.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile * 4,
              kNTile,
              kKTile * cu_count + 25 * kKTile,
              ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // Four output tiles along M; K is padded 25 macro-tiles past one-per-CU so
    // iterations do not divide evenly among the workgroups.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile * 4,
              kNTile,
              kKTile * cu_count + 25 * kKTile,
              ck_tile::StreamKReductionStrategy::Reduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // 3 x 7 = 21 output tiles; K is padded 30 macro-tiles past one-per-CU so
    // iterations do not divide evenly among the workgroups.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile * 3,
              kNTile * 7,
              kKTile * cu_count + 30 * kKTile,
              ck_tile::StreamKReductionStrategy::TreeReduction);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
{
    // Macro-tile extents are encoded at indices 7/8/9 of the kernel type tuple.
    constexpr ck_tile::index_t kMTile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t kNTile = std::tuple_element_t<8, TypeParam>::value;
    constexpr ck_tile::index_t kKTile = std::tuple_element_t<9, TypeParam>::value;

    // 3 x 7 = 21 output tiles; K is padded 30 macro-tiles past one-per-CU so
    // iterations do not divide evenly among the workgroups.
    const ck_tile::index_t cu_count = get_cu_count();
    this->Run(kMTile * 3,
              kNTile * 7,
              kKTile * cu_count + 30 * kKTile,
              ck_tile::StreamKReductionStrategy::Reduction);
}

View File

@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;
using KernelTypesStreamKFp16Reduction = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,

View File

@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
auto kargs = Kernel::MakeKernelArgs(args);
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
ck_tile::DeviceMem workspace_data(workspace_size);
workspace_data.SetZero();
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
if(!Kernel::IsSupportedArgument(kargs))
{
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
using namespace ck_tile::literals;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
}
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
stride_B,
stride_C};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
ck_tile::index_t num_accumulations_per_tile;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
{
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
else
{
num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
}
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());

View File

@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
}
}
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
{
    /*
       StreamKTilePartitionerBaseConfigSKOnlyLargeK layout:
         - 2 output tiles, 5 iters per tile, grid of 5, 0 DP tiles, 2 SK tiles,
           2 iters per SK CTA, 0 extra iters.

         tile_idx:  __________0_________|_________1_________|
         tile_iter: | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
                    <---------------SK Tiles--------------->|

       Hence:
         - CTAs 0, 1, 2 participate in tile 0 (tile_iter_start 0) with local
           indices 0, 1, 2.
         - CTAs 2, 3, 4 participate in tile 1 (tile_iter_start 5) with local
           indices 0, 1, 2. CTA 2 straddles both tiles.
    */
    // Each case is (tile_iter_start, cta_idx, expected tile-local CTA index).
    const std::array<std::array<ck_tile::index_t, 3>, 6> cases{
        {{0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}}};
    for(const auto& [tile_iter_start, cta_idx, expected_local_idx] : cases)
    {
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
            tile_iter_start, cta_idx, expected_local_idx);
    }
}
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
{
    /*
       The StreamKTilePartitionerBaseConfigDP2TileSK has the following form:
       - tiles in the C tensor: 7
       - iters_per_tile: 2
       - grid: 3
       - dp_tiles: 3
       - sk_tiles: 4
       - iters_per_sk_cta: 2
       - extra_iters: 2
       (iters_per_tile must be 2: the diagram below spans iterations 0-13 over
       7 tiles, and the expected local indices only hold for 2 iters per tile.)
       The tiles with iters are as follows:
       tile_idx:  ____0___|___1___|___2___|___3___|___4___|____5____|____6____|
       tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
                 |   |   |   |   |   |   |   |   |   |   |    |    |    |    |
                 |<-------DP Tiles------>|<------------SK Tiles------------->|
       From the above configuration, we get the following:
       - SK CTA 0: tile_iter_start is 6 with local CTA index of 0 in tile 3
       - SK CTA 0: tile_iter_start is 8 with local CTA index of 0 in tile 4
       - SK CTA 1: tile_iter_start is 8 with local CTA index of 1 in tile 4
       - SK CTA 1: tile_iter_start is 10 with local CTA index of 0 in tile 5
       - SK CTA 2: tile_iter_start is 12 with local CTA index of 0 in tile 6
    */
    // Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
    std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
        {6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}};
    for(const auto& triplet : sk_triplets)
    {
        const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
        test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
            tile_iter_start, cta_idx, tile_local_cta_idx);
    }
}
// Persistent
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
{

View File

@@ -4,6 +4,7 @@
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "gtest/gtest.h"
#include <array>
enum StreamKTilePartitionerBaseMethodId
{
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
GET_TILE_BOUNDARIES,
GET_TILE_INDEX,
GET_ITER_BOUNDARIES,
GET_OUTPUT_TILE_INDEX
GET_OUTPUT_TILE_INDEX,
GET_TILE_LOCAL_CTA_INDEX
};
// Base kernel wrapper class to facilitate testing class device functions.
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
}
};
// Device-side wrapper that exercises get_tile_local_cta_index:
// arg1 = tile_iter_start, arg2 = cta_idx; the result is written to result1 for
// host-side validation.
template <typename TilePartitioner>
struct KernelWrapperSpecialized<TilePartitioner,
                                StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
    : public KernelWrapper<TilePartitioner>
{
    using Base = KernelWrapper<TilePartitioner>;

    CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
    {
        auto* result_ptr = static_cast<ck_tile::index_t*>(kargs.result1);
        *result_ptr = kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);
    }
};
struct StreamKTilePartitionerBaseExpected
{
ck_tile::index_t sk_tiles_;
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
// Stream-K-only configuration with a K dimension large enough that output
// tiles span multiple CTAs: M/M_TILE * N/N_TILE = 2 output tiles, each with
// K/K_TILE = 5 iterations, distributed over a grid of 5 CTAs (no DP tiles).
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
{
    static constexpr ck_tile::index_t M = 8;  // 2 macro-tiles along M
    static constexpr ck_tile::index_t N = 2;  // 1 macro-tile along N
    static constexpr ck_tile::index_t K = 10; // 5 K-iterations per tile
    static constexpr ck_tile::index_t GRID = 5;
    static constexpr ck_tile::index_t M_TILE = 4;
    static constexpr ck_tile::index_t N_TILE = 2;
    static constexpr ck_tile::index_t K_TILE = 2;
    using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
                                             ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
{
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
EXPECT_EQ(in, in_expected);
};
/**
 * @brief Verifies StreamKTilePartitionerBase::get_tile_local_cta_index on the device.
 *
 * Launches a single-workgroup wrapper kernel that evaluates the method for the
 * given (tile_iter_start, cta_idx) pair and checks the device result against
 * the expected tile-local CTA index.
 *
 * @param tile_iter_start The starting tile iteration passed to the method.
 * @param cta_idx The Stream-K workgroup index passed to the method.
 * @param expected_tile_local_cta_idx The expected tile-local workgroup index.
 */
template <typename Config>
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
                                 ck_tile::index_t cta_idx,
                                 ck_tile::index_t expected_tile_local_cta_idx)
{
    // Types
    using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
    using Kernel =
        KernelWrapperSpecialized<TilePartitioner,
                                 StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;

    // Test parameters. Reuse the TilePartitioner alias instead of re-spelling
    // the full template type.
    TilePartitioner tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
    ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));

    // Launch kernel: arg1 = tile_iter_start, arg2 = cta_idx; result1 receives
    // the computed tile-local index.
    auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
                                        cta_idx,
                                        Config::UNUSED,
                                        tile_local_cta_idx_dev.GetDeviceBuffer(),
                                        nullptr,
                                        tile_partitioner);
    ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
                           ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));

    // Validate results
    ck_tile::index_t tile_local_cta_idx;
    tile_local_cta_idx_dev.FromDevice(&tile_local_cta_idx);
    EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
}
// Configs for TilePartitioner Child structs
struct StreamKTilePartitionerV2PersistentExpected
{