[CK_TILE] Stream-K operator() Reboot (#3064)

* Persistent Stream-K Kernel Implementation

This change implements an operator() function in the
reboot::StreamKKernel class that is enabled when the Persistent flag is
set to true. In this case, the data-parallel portion and the Stream-K
portion of the kernel are fully persistent.

The changes were made in the reboot namespace. A future PR will remove
the old Stream-K kernel class and remove the reboot namespace.

* Unit Tests for Persistent Stream-K Kernel

This change contains the inital test suite for the Persitent Stream-K
Kernel. The files contain "reboot" in the name; a future PR will remove
tests for the old Stream-K Kernel and remove the "reboot" naming.

A future commit will add tests for the non-persistent kernel.

Also added estimate_num_wgs_per_tile to the StreamKTilePartitionerBase
class. This allows us to estimate the number of accumulations done per
macro tile in C to use during validation when computing relative and
absolute tolerance.

* Adding implementation for the Non-Persistent Stream-K kernel

This code is adding the operator() function for the Non-Persistent Stream-K
kernel. Persistency of the kernel is determined through a template argument.
The Non-Persistent kernel will allocate additional workgroups for the data
parallel section, leading to a different structure for processing the data
parallel and Stream-K sections.

There has been an addition to the TilePartitioner to get access to the whether
Persistent has been set to true or false in the StreamKKernel.

* Adding in the tests for the Non-Persistent Stream-K kernel

* Refactor Stream-K Reboot Unit Tests

This commit makes the following changes:
- Update test cases to determine M, N, and K based on the number of CUs.
  This ensures that each test case is one of Edge Case, SK Only, DP
Only, or DP + 2 Tile SK regardless of the architecture.
- Since the DP + 2 Tile SK test case takes long to run, this change
  moves this case into a separate .inc file and labels it as an extended
test.
- Since the extended test takes > 30 seconds to run, this test is added
  to the list of regression tests.

* Fix spelling errors in comments for test cases

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Changes based on review

Removed const volatile for typenames
Set up alias for is_tuple_t
Naming changes for clarity: GemmCommon -> BaseGemm
Moved std::enable_if_t out of template parameters and changed to a return type for operator()
Added constructor for StreamKKernelArgs to clarify UniversalGemm inheritance

---------

Co-authored-by: Emily Martins <emily.martins@amd.com>
Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

[ROCm/composable_kernel commit: 054fdb765c]
This commit is contained in:
arai713
2025-10-27 09:14:17 -07:00
committed by GitHub
parent 3dd0779bf7
commit d06d23ab11
20 changed files with 1122 additions and 0 deletions

View File

@@ -8,6 +8,478 @@
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
namespace reboot {
/// @brief The Stream K GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
/// arguments object. It contains all necessary information required to build proper kernel
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M,N,K sizes and respective strides.
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
{
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
const void* b_ptr_,
void* c_ptr_,
index_t M_,
index_t N_,
index_t K_,
index_t stride_A_,
index_t stride_B_,
index_t stride_C_,
StreamKReductionStrategy reduction_strategy_)
: UniversalGemmHostArgs<>({a_ptr_},
{b_ptr_},
{/*ds_ptr*/},
c_ptr_,
/*k_batch_ =*/1,
M_,
N_,
K_,
{stride_A_},
{stride_B_},
{/*stride_Ds_*/},
stride_C_),
reduction_strategy{reduction_strategy_}
{
}
ck_tile::StreamKReductionStrategy reduction_strategy;
};
/// @brief The Stream K GEMM kernel class.
///
/// @par Overview
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
// The main kernel functions are the operator() functions. There is one for Persistent
// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
//
// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct StreamKKernel
{
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
/// functions.
using UniversalGemmKernel =
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
using TilePartitioner = TilePartitioner_;
using GemmPipeline = GemmPipeline_;
using EpiloguePipeline = EpiloguePipeline_;
static_assert(
TilePartitioner::PERSISTENT == PersistentDP,
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
/// @brief Specify the layout configurations for A, B, and C
using ALayout = typename GemmPipeline::ALayout;
using BLayout = typename GemmPipeline::BLayout;
using CLayout = typename GemmPipeline::CLayout;
/// @brief Specify the data type configurations for A, B, and C
using ADataType = typename GemmPipeline::ADataType;
using BDataType = typename GemmPipeline::BDataType;
using CDataType = typename EpiloguePipeline::ODataType;
template <typename T>
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
"ALayout and ADataType must be scalars.");
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
"BLayout and BDataType must be scalars.");
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
"CLayout and CDataType must be scalars.");
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
{
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
: UniversalGemmKernelArgs{host_args.as_ptr,
host_args.bs_ptr,
host_args.ds_ptr,
host_args.e_ptr,
host_args.M,
host_args.N,
host_args.K,
host_args.stride_As,
host_args.stride_Bs,
host_args.stride_Ds,
host_args.stride_E,
host_args.k_batch},
reduction_strategy{host_args.reduction_strategy},
// The workspace pointer is set to nullptr because we must first
// instantiate the TilePartitioner to get the necessary size
workspace_ptr{nullptr},
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
{
}
/// @brief The strategy used by work groups to compute final results in C tensor.
StreamKReductionStrategy reduction_strategy;
/// @brief A pointer to a buffer in device memory for accumulating partial via reduction
/// strategy.
void* workspace_ptr;
/// @brief An instance of the TilePartioner class for assisting with mapping workgroups to
/// the C tensor.
TilePartitioner tile_partitioner;
};
using KernelArgs = StreamKKernelArgs;
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
using P_ = GemmPipeline;
using WarpTile = typename P_::BlockGemmShape::WarpTile;
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
// clang-format on
}
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
/// @return The grid size.
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
{
return tile_partitioner.grid_size();
}
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
/// @return The maximum occupancy grid size.
/// @note This function queries the maximum occupancy of the kernel using
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
{
return UniversalGemmKernel::MaxOccupancyGridSize(s);
}
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
{
return UniversalGemmKernel::BlockSize();
}
/// @brief Constructs kernel arguments for the Stream-K kernel.
/// @param host_args Stream-K host arguments.
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
/// The caller may select their own to assist with test reproducibility, etc.
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
/// select their own to assist with test reproducibility, etc.
/// @return The kernel arguments for Stream-K.
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
int num_cu = NumCU(),
int occupancy = Occupancy())
{
const index_t grid = num_cu * occupancy;
return StreamKKernelArgs{host_args, grid};
}
template <bool UseDefaultScheduler = true>
CK_TILE_DEVICE static void
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
CDataType* c_ptr,
void* smem_ptr_0,
const typename UniversalGemmKernel::KernelArgs& kargs,
const index_t num_loop,
const index_t block_idx_m,
const index_t block_idx_n,
const index_t k_size)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows =
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
// Run GEMM cooperatively by whole workgroup.
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
// tail_num.
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
bs_block_window[UniversalGemmKernel::I0],
num_loop,
has_hot_loop,
tail_num,
smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
}
return false;
}
return UniversalGemmKernel::IsSupportedArgument(kargs);
}
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
/// @return The buffer size needed.
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
{
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
}
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
/// @note Assumes that the given workspace_ptr points to allocated device memory.
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
{
kargs.workspace_ptr = workspace_ptr;
}
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
/// @param kargs Stream-K kernel arguments.
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
/// workgroup will perform in the C tile.
/// @param i_k_a The K offset in the A tensor.
/// @param i_k_b The K offset in the B tensor.
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
/// `tile_idx`.
/// @param smem_ptr_0 Pointer to LDS.
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
index_t tile_idx,
index_t num_loop,
index_t i_k_a,
index_t i_k_b,
index_t k_size,
void* smem_ptr_0) const
{
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
// Run the GEMM pipeline and Epilogue.
RunGemm(
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
}
/// @brief Runs the main Stream-K algorithm.
/// @param kargs Stream-K kernel arguments.
/// @param cta_idx The current Stream-K workgroup's index.
/// @param smem_ptr_0 Pointer to LDS.
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
/// should be something like `blockIdx.x` minus number of DP workgroups.
CK_TILE_DEVICE void
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
{
index_t iter_start, iter_end;
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
while(iter_start < iter_end)
{
// Get the 1D tile index in the C tensor that this workgroup will work in for this
// iteration of the loop.
index_t tile_idx =
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
// Get the start and end boundaries for the current tile.
index_t tile_iter_start, tile_iter_end;
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
// Get the start and end iteration within the current tile for the workgroup.
index_t local_iter_start = amd_wave_read_first_lane(
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
index_t local_iter_end =
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
tile_iter_start, iter_end, tile_iter_end));
// Get the iteration length.
index_t num_loop_sk = local_iter_end - local_iter_start;
// Determine the total size along the K dimension the workgroup is using in this
// iteration (used to construct tensor views).
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
// Get the K offsets for the A and B tensors
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
{
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
}
else
{
// TODO: Apply reduction logic.
}
// Prepare for next Stream-K loop iteration.
iter_start = tile_iter_end;
block_sync_lds();
}
}
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
///
/// @par Overview
/// For the Non-Persistent kernel, each data parallel workgroup will
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
/// The Stream-K workgroups will do their assigned work by calling
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
// Check if at the data parallel section
if(is_dp_ctas)
{
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
else
{
// Stream-K
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
}
}
/// @brief Entry point for the Stream-K Kernel with persistent DP.
///
/// @par Overview
/// For the Persistent kernel, each workgroup will first compute their
/// assigned data-parallel tiles. Each data parallel tile will be computed
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
/// in the Stream-K loop.
template <bool U = PersistentDP>
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
{
// Allocate LDS
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
index_t block_idx = ck_tile::get_block_1d_id();
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
// Data-parallel section
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
tile_idx += kargs.tile_partitioner.get_grid())
{
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
}
// Stream-K section
StreamKGemm(kargs, block_idx, smem_ptr_0);
}
private:
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
/// the starting macro tile index in the K dimension for the workgroup.
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
/// of A and B.
/// @note The default case is that A is assumed to be row major and B is assumed to be column
/// major.
template <typename ALayout, typename BLayout>
CK_TILE_DEVICE static tuple<index_t, index_t>
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
{
index_t stride_offset_a;
index_t stride_offset_b;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
stride_offset_a = stride_a;
}
else
{
stride_offset_a = 1;
}
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
stride_offset_b = stride_b;
}
else
{
stride_offset_b = 1;
}
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
}
CK_TILE_HOST static int NumCU()
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
hip_check_error(hipGetDevice(&dev));
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
int num_cu = dev_prop.multiProcessorCount;
return num_cu;
}
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
/// @return The occupancy
/// @note This function queries the maximum occupancy of the kernel using
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
CK_TILE_HOST static int Occupancy()
{
int occupancy;
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
constexpr int min_block_per_cu = 1;
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
hip_check_error(
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
return occupancy;
}
};
} // namespace reboot
/// @brief The Stream K GEMM kernel host arguments.
///

View File

@@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
*/
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
/**
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
*/
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;
protected:
index_t num_tiles_;
index_t grid_;
@@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
ck_tile::index_t grid);
public:
static constexpr bool PERSISTENT = true;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
* case, no extra workgroups are allocated for the data parallel section, making the grid
@@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
ck_tile::index_t grid);
public:
static constexpr bool PERSISTENT = false;
/**
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
* case, extra workgroups are allocated for the data parallel section, making the grid

View File

@@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
return n_;
}
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
CK_TILE_HOST index_t
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
const noexcept
{
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
// writing final results to a given macro tile in C.
int num_wgs_per_tile = 1;
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
{
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
// Estimate the number of workgroups per macro tile.
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
}
return std::max(num_wgs_per_tile, 1);
}
template <typename BlockGemmShapeType,
StreamKReductionStrategy ReductionStrategyType,
bool Persistent>

View File

@@ -46,6 +46,7 @@ set(REGRESSION_TESTS
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8
test_ck_tile_streamk_reboot_extended
)
function(add_test_executable TEST_NAME)

View File

@@ -117,6 +117,18 @@ if(GPU_TARGETS MATCHES "gfx9")
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
# )
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
add_gtest_executable(test_ck_tile_streamk_reboot_extended
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
test_gemm_streamk_reboot_util.cpp)
else()
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
endif()

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
#include "test_gemm_streamk_reboot_extended_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
#include "test_gemm_streamk_reboot_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
#include "test_gemm_streamk_reboot_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
#include "test_gemm_streamk_reboot_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,19 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_streamk_reboot_types.hpp"
#include "test_gemm_streamk_reboot_util.hpp"
#include "gtest/gtest.h"
template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
#include "test_gemm_streamk_reboot_smoke_cases.inc"
#undef TEST_SUITE_NAME

View File

@@ -0,0 +1,24 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_DP2TSK)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
// For DP 2-Tile SK, there are 2 important terms:
// Term 1: (M_Tile * num_cu * 2) - This ensures we have at least 2 cycles that will fully
// saturate all CUs. This assumes tile sizes are large enough such that occupancy is 1.
// Term 2: (M_Tile * 2) - This ensures we have 1 cycle that does not fully saturate all CUs
// (i.e., we will have remainder tiles). This guarantees we have 1 full tile cycle plus
// remainder tiles for the 2 Tile SK portion; the rest of the tiles will fully saturate all CUs
// for the DP portion.
ck_tile::index_t M = (M_Tile * num_cu * 2) + (M_Tile * 2);
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = 2048;
this->Run(M, N, K);
}

View File

@@ -0,0 +1,47 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
{
ck_tile::index_t M = 256;
ck_tile::index_t N = 256;
ck_tile::index_t K = 256;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
// assumes tile sizes are large enough such that occupancy is 1.
ck_tile::index_t M = M_Tile * num_cu;
ck_tile::index_t N = N_Tile;
ck_tile::index_t K = K_Tile;
this->Run(M, N, K);
}
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
{
const ck_tile::index_t num_cu = get_cu_count();
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
// is 1.
ck_tile::index_t M = M_Tile * 2;
ck_tile::index_t N = N_Tile * 2;
ck_tile::index_t K = K_Tile * num_cu;
this->Run(M, N, K);
}

View File

@@ -0,0 +1,56 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include <type_traits>
#include "gtest/gtest.h"
#include "ck_tile/host.hpp"
using F16 = ck_tile::half_t;
using F32 = float;
using BF16 = ck_tile::bf16_t;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Persistent = std::true_type;
using NonPersistent = std::false_type;
using I32 = ck_tile::number<32>;
using I256 = ck_tile::number<256>;
// clang-format off
using KernelTypesStreamKFp16Persistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
>;
using KernelTypesStreamKBf16Persistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
>;
using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
>;
using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
>;
// clang-format on

View File

@@ -0,0 +1,10 @@
#include "test_gemm_streamk_reboot_util.hpp"
ck_tile::index_t get_cu_count()
{
hipDeviceProp_t dev_prop;
hipDevice_t dev;
ck_tile::hip_check_error(hipGetDevice(&dev));
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}

View File

@@ -0,0 +1,283 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include <hip/hip_runtime.h>
#include <iostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are
// resolved. Because the number of WGs contributing to a macro tile in C may not be the same for
// all macro tiles in C.
// Calculate error due to more than 1 WG contributing to the same macro tile in C
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
ck_tile::index_t get_cu_count();
template <typename Tuple>
class TestCkTileStreamKReboot : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
using DsLayout = ck_tile::tuple<>;
using DsDataType = ck_tile::tuple<>;
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
bool PadM = true,
bool PadN = true,
bool PadK = true,
bool Preshuffle = false,
bool TransposeC = false>
ck_tile::index_t invoke_streamk(const ck_tile::reboot::StreamKHostArgs& args,
const ck_tile::stream_config& s)
{
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr bool kPadM = PadM;
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
constexpr bool preshuffle = Preshuffle;
constexpr bool DoubleSmemBuffer = false;
constexpr int kBlockPerCu = 1;
constexpr bool StructuredSparsity = false;
constexpr bool NumWaveGroup = 1;
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner =
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, Persistent>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
StructuredSparsity,
Persistent,
NumWaveGroup,
preshuffle>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
// For initial testing, we will just test with one pipeline.
// More extensive testing is coming later and will test other pipelines.
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel =
ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
if(!Kernel::IsSupportedArgument(kargs))
{
EXPECT_TRUE(false);
}
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
dim3 block_dims = Kernel::BlockSize();
ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
return kargs.tile_partitioner.estimate_num_wgs_per_tile();
};
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
// Since we are doing stream K, in the case of
// atomics, multiple workgroups may write to the same
// output tile in the C tensor, so we must atomic add
// the results (not set)
ck_tile::memory_operation_enum::atomic_add>{});
}
public:
void Run(ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::StreamKReductionStrategy reduction_strategy =
ck_tile::StreamKReductionStrategy::Atomic,
ck_tile::index_t stride_A = 0,
ck_tile::index_t stride_B = 0,
ck_tile::index_t stride_C = 0)
{
// Since M, N, and K will vary depending on the number of CUs, we print it here to
// facilitate test output readability.
std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl;
using namespace ck_tile::literals;
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
{
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
}
auto f_host_tensor_descriptor = [](std::size_t row,
std::size_t col,
std::size_t stride,
auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
{
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(stride == 0)
{
if constexpr(std::is_same_v<decltype(layout),
ck_tile::tensor_layout::gemm::RowMajor>)
{
return col;
}
else
{
return row;
}
}
else
return stride;
};
stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, /*seed*/ 11939}(a_m_k);
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, /*seed*/ 11940}(b_k_n);
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
b_k_n_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
M,
N,
K,
stride_A,
stride_B,
stride_C,
reduction_strategy};
ck_tile::index_t num_accumulations_per_tile =
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, num_accumulations_per_tile, max_accumulated_value);
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass);
};
};

View File

@@ -77,6 +77,26 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
expected_partials_size + expected_flags_size);
}
TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLowerValue)
{
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1);
}
TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue)
{
using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
Config::M, Config::N, Config::K, Config::GRID};
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
}
TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter)
{
// Types

View File

@@ -194,6 +194,23 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
: public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 16;
static constexpr ck_tile::index_t N = 4;
static constexpr ck_tile::index_t K = 16;
static constexpr ck_tile::index_t GRID = 8;
static constexpr ck_tile::index_t M_TILE = 4;
static constexpr ck_tile::index_t N_TILE = 4;
static constexpr ck_tile::index_t K_TILE = 8;
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
};
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
{
static constexpr ck_tile::index_t M = 12;