mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit '054fdb765cd74c0f7bbb6561ea58713df82ed85f' into develop
This commit is contained in:
@@ -8,6 +8,478 @@
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reboot {
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
///
/// @par Overview
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
/// arguments object. It contains all necessary information required to build proper kernel
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
/// configuration by stating all required information like M,N,K sizes and respective strides.
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
{
    /// @brief Construct Stream-K host arguments from a single-A/single-B GEMM description.
    ///
    /// @param a_ptr_ Pointer to the A input tensor data.
    /// @param b_ptr_ Pointer to the B input tensor data.
    /// @param c_ptr_ Pointer to the C output tensor data.
    /// @param M_ Number of rows of A and C.
    /// @param N_ Number of columns of B and C.
    /// @param K_ Shared (reduction) dimension of A and B.
    /// @param stride_A_ Leading-dimension stride of A.
    /// @param stride_B_ Leading-dimension stride of B.
    /// @param stride_C_ Leading-dimension stride of C.
    /// @param reduction_strategy_ Strategy used by workgroups to produce final C results.
    CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
                                          const void* b_ptr_,
                                          void* c_ptr_,
                                          index_t M_,
                                          index_t N_,
                                          index_t K_,
                                          index_t stride_A_,
                                          index_t stride_B_,
                                          index_t stride_C_,
                                          StreamKReductionStrategy reduction_strategy_)
        // Stream-K does not use split-K batching, so k_batch_ is fixed at 1; the Ds
        // (auxiliary input) tensor lists are unused and therefore passed empty.
        : UniversalGemmHostArgs<>({a_ptr_},
                                  {b_ptr_},
                                  {/*ds_ptr*/},
                                  c_ptr_,
                                  /*k_batch_ =*/1,
                                  M_,
                                  N_,
                                  K_,
                                  {stride_A_},
                                  {stride_B_},
                                  {/*stride_Ds_*/},
                                  stride_C_),
          reduction_strategy{reduction_strategy_}
    {
    }

    /// @brief The strategy used by work groups to compute final results in the C tensor.
    ck_tile::StreamKReductionStrategy reduction_strategy;
};
|
||||
|
||||
/// @brief The Stream K GEMM kernel class.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
|
||||
// The main kernel functions are the operator() functions. There is one for Persistent
|
||||
// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
|
||||
//
|
||||
// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
|
||||
// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
|
||||
// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
|
||||
// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct StreamKKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
|
||||
|
||||
using TilePartitioner = TilePartitioner_;
|
||||
using GemmPipeline = GemmPipeline_;
|
||||
using EpiloguePipeline = EpiloguePipeline_;
|
||||
|
||||
static_assert(
|
||||
TilePartitioner::PERSISTENT == PersistentDP,
|
||||
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, and C
|
||||
using ALayout = typename GemmPipeline::ALayout;
|
||||
using BLayout = typename GemmPipeline::BLayout;
|
||||
using CLayout = typename GemmPipeline::CLayout;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
|
||||
"CLayout and CDataType must be scalars.");
|
||||
|
||||
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
|
||||
: UniversalGemmKernelArgs{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
reduction_strategy{host_args.reduction_strategy},
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
workspace_ptr{nullptr},
|
||||
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
|
||||
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief The strategy used by work groups to compute final results in C tensor.
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial via reduction
|
||||
/// strategy.
|
||||
void* workspace_ptr;
|
||||
/// @brief An instance of the TilePartioner class for assisting with mapping workgroups to
|
||||
/// the C tensor.
|
||||
TilePartitioner tile_partitioner;
|
||||
};
|
||||
|
||||
using KernelArgs = StreamKKernelArgs;
|
||||
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
using WarpTile = typename P_::BlockGemmShape::WarpTile;
|
||||
|
||||
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
|
||||
/// @return The grid size.
|
||||
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
|
||||
{
|
||||
return tile_partitioner.grid_size();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
/// @return The maximum occupancy grid size.
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
const index_t grid = num_cu * occupancy;
|
||||
|
||||
return StreamKKernelArgs{host_args, grid};
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const typename UniversalGemmKernel::KernelArgs& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
|
||||
// tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
/// @return The buffer size needed.
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
|
||||
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
/// @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
|
||||
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
|
||||
/// workgroup will perform in the C tile.
|
||||
/// @param i_k_a The K offset in the A tensor.
|
||||
/// @param i_k_b The K offset in the B tensor.
|
||||
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
|
||||
/// `tile_idx`.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
|
||||
index_t tile_idx,
|
||||
index_t num_loop,
|
||||
index_t i_k_a,
|
||||
index_t i_k_b,
|
||||
index_t k_size,
|
||||
void* smem_ptr_0) const
|
||||
{
|
||||
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
|
||||
index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
RunGemm(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
|
||||
}
|
||||
|
||||
/// @brief Runs the main Stream-K algorithm.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param cta_idx The current Stream-K workgroup's index.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
|
||||
/// should be something like `blockIdx.x` minus number of DP workgroups.
|
||||
CK_TILE_DEVICE void
|
||||
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
|
||||
{
|
||||
index_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
|
||||
|
||||
while(iter_start < iter_end)
|
||||
{
|
||||
// Get the 1D tile index in the C tensor that this workgroup will work in for this
|
||||
// iteration of the loop.
|
||||
index_t tile_idx =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
|
||||
|
||||
// Get the start and end boundaries for the current tile.
|
||||
index_t tile_iter_start, tile_iter_end;
|
||||
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
|
||||
|
||||
// Get the start and end iteration within the current tile for the workgroup.
|
||||
index_t local_iter_start = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
|
||||
index_t local_iter_end =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
|
||||
tile_iter_start, iter_end, tile_iter_end));
|
||||
|
||||
// Get the iteration length.
|
||||
index_t num_loop_sk = local_iter_end - local_iter_start;
|
||||
|
||||
// Determine the total size along the K dimension the workgroup is using in this
|
||||
// iteration (used to construct tensor views).
|
||||
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
|
||||
|
||||
// Get the K offsets for the A and B tensors
|
||||
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
|
||||
local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Apply reduction logic.
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Non-Persistent kernel, each data parallel workgroup will
|
||||
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
|
||||
/// The Stream-K workgroups will do their assigned work by calling
|
||||
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
|
||||
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
|
||||
|
||||
// Check if at the data parallel section
|
||||
if(is_dp_ctas)
|
||||
{
|
||||
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Stream-K
|
||||
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Persistent kernel, each workgroup will first compute their
|
||||
/// assigned data-parallel tiles. Each data parallel tile will be computed
|
||||
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
|
||||
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
|
||||
/// in the Stream-K loop.
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
|
||||
// Data-parallel section
|
||||
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
|
||||
// Stream-K section
|
||||
StreamKGemm(kargs, block_idx, smem_ptr_0);
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
{
|
||||
index_t stride_offset_a;
|
||||
index_t stride_offset_b;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
stride_offset_a = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_a = 1;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_offset_b = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_b = 1;
|
||||
}
|
||||
|
||||
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
|
||||
|
||||
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static int NumCU()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
int num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
/// @return The occupancy
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static int Occupancy()
|
||||
{
|
||||
int occupancy;
|
||||
|
||||
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
};
|
||||
} // namespace reboot
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
|
||||
@@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
|
||||
*/
|
||||
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t num_tiles_;
|
||||
index_t grid_;
|
||||
@@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = true;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
|
||||
* case, no extra workgroups are allocated for the data parallel section, making the grid
|
||||
@@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = false;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
|
||||
* case, extra workgroups are allocated for the data parallel section, making the grid
|
||||
|
||||
@@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
|
||||
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
|
||||
@@ -46,6 +46,7 @@ set(REGRESSION_TESTS
|
||||
test_ck_tile_fmha_fwd_bf16
|
||||
test_ck_tile_fmha_fwd_fp16
|
||||
test_ck_tile_fmha_fwd_fp8
|
||||
test_ck_tile_streamk_reboot_extended
|
||||
)
|
||||
|
||||
function(add_test_executable TEST_NAME)
|
||||
|
||||
@@ -117,6 +117,18 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_extended
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the bf16 non-persistent Stream-K extended tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the bf16 persistent Stream-K extended tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the fp16 non-persistent Stream-K extended tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the fp16 persistent Stream-K extended tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the bf16 non-persistent Stream-K smoke tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_smoke_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the bf16 persistent Stream-K smoke tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_smoke_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the fp16 non-persistent Stream-K smoke tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_smoke_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,19 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture for the fp16 persistent Stream-K smoke tests; all test
// behavior is inherited from TestCkTileStreamKReboot.
template <typename Tuple>
class TestCkTileStreamKRebootFp16Persistent : public TestCkTileStreamKReboot<Tuple>
{
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16Persistent, KernelTypesStreamKFp16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_reboot_smoke_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,24 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
// Exercises the "data-parallel + 2-tile Stream-K" split: M is sized so the device gets
// two fully-saturating DP cycles plus a remainder handled by the Stream-K section.
TYPED_TEST(TEST_SUITE_NAME, StreamK_DP2TSK)
{
    const ck_tile::index_t num_cu = get_cu_count();
    // Tile-shape parameters sit at fixed positions (7 = M_Tile, 8 = N_Tile) in the
    // kernel-type tuple — keep in sync with the type list definition.
    constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
    constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;

    // For DP 2-Tile SK, there are 2 important terms:
    // Term 1: (M_Tile * num_cu * 2) - This ensures we have at least 2 cycles that will fully
    // saturate all CUs. This assumes tile sizes are large enough such that occupancy is 1.
    // Term 2: (M_Tile * 2) - This ensures we have 1 cycle that does not fully saturate all CUs
    // (i.e., we will have remainder tiles). This guarantees we have 1 full tile cycle plus
    // remainder tiles for the 2 Tile SK portion; the rest of the tiles will fully saturate all CUs
    // for the DP portion.
    ck_tile::index_t M = (M_Tile * num_cu * 2) + (M_Tile * 2);
    ck_tile::index_t N = N_Tile;
    ck_tile::index_t K = 2048;

    this->Run(M, N, K);
}
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_EdgeCase)
|
||||
{
|
||||
ck_tile::index_t M = 256;
|
||||
ck_tile::index_t N = 256;
|
||||
ck_tile::index_t K = 256;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_DPOnly)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
// For DP only, we ensure that the number of tiles is a multiple of the number of CUs. This
|
||||
// assumes tile sizes are large enough such that occupancy is 1.
|
||||
ck_tile::index_t M = M_Tile * num_cu;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
// For SK only, we have 4 macro tiles in C. But, we need to make sure there is enough work along
|
||||
// the K dimension to avoid falling into the edge case. Thus, we always have at least num_cu
|
||||
// macro tiles in the K dimension. This assumes tile sizes are large enough such that occupancy
|
||||
// is 1.
|
||||
ck_tile::index_t M = M_Tile * 2;
|
||||
ck_tile::index_t N = N_Tile * 2;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
56
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp
Normal file
56
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_types.hpp
Normal file
@@ -0,0 +1,56 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using BF16 = ck_tile::bf16_t;
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using Persistent = std::true_type;
|
||||
using NonPersistent = std::false_type;
|
||||
|
||||
using I32 = ck_tile::number<32>;
|
||||
using I256 = ck_tile::number<256>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypesStreamKFp16Persistent = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16Persistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16NonPersistent = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKBf16NonPersistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>,
|
||||
std::tuple< Col, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, NonPersistent>
|
||||
>;
|
||||
// clang-format on
|
||||
10
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp
Normal file
10
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.cpp
Normal file
@@ -0,0 +1,10 @@
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
|
||||
ck_tile::index_t get_cu_count()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
ck_tile::hip_check_error(hipGetDevice(&dev));
|
||||
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}
|
||||
283
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp
Normal file
283
test/ck_tile/gemm_streamk/test_gemm_streamk_reboot_util.hpp
Normal file
@@ -0,0 +1,283 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
|
||||
// The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are
|
||||
// resolved. Because the number of WGs contributing to a macro tile in C may not be the same for
|
||||
// all macro tiles in C.
|
||||
|
||||
// Calculate error due to more than 1 WG contributing to the same macro tile in C
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
ck_tile::index_t get_cu_count();
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKReboot : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
static constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, Tuple>::value;
|
||||
static constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, Tuple>::value;
|
||||
static constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, Tuple>::value;
|
||||
static constexpr bool Persistent = std::tuple_element_t<10, Tuple>::value;
|
||||
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy,
|
||||
bool PadM = true,
|
||||
bool PadN = true,
|
||||
bool PadK = true,
|
||||
bool Preshuffle = false,
|
||||
bool TransposeC = false>
|
||||
ck_tile::index_t invoke_streamk(const ck_tile::reboot::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
constexpr bool kPadM = PadM;
|
||||
constexpr bool kPadN = PadN;
|
||||
constexpr bool kPadK = PadK;
|
||||
constexpr bool preshuffle = Preshuffle;
|
||||
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr bool StructuredSparsity = false;
|
||||
constexpr bool NumWaveGroup = 1;
|
||||
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, Persistent>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC,
|
||||
StructuredSparsity,
|
||||
Persistent,
|
||||
NumWaveGroup,
|
||||
preshuffle>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
// We create the GEMM pipeline without specifying has_hot_loop or tail_num.
|
||||
// This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K
|
||||
// while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K
|
||||
// Kernel's RunGemm function. This is a similar pattern used by grouped GEMM.
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
// For initial testing, we will just test with one pipeline.
|
||||
// More extensive testing is coming later and will test other pipelines.
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
EXPECT_TRUE(false);
|
||||
}
|
||||
|
||||
dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner);
|
||||
dim3 block_dims = Kernel::BlockSize();
|
||||
|
||||
ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grid_dims, block_dims, 0, kargs));
|
||||
|
||||
return kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
};
|
||||
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
// Since we are doing stream K, in the case of
|
||||
// atomics, multiple workgroups may write to the same
|
||||
// output tile in the C tensor, so we must atomic add
|
||||
// the results (not set)
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy =
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
ck_tile::index_t stride_A = 0,
|
||||
ck_tile::index_t stride_B = 0,
|
||||
ck_tile::index_t stride_C = 0)
|
||||
{
|
||||
// Since M, N, and K will vary depending on the number of CUs, we print it here to
|
||||
// facilitate test output readability.
|
||||
std::cout << "M: " << M << ", N: " << N << ", K: " << K << std::endl;
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride =
|
||||
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<decltype(layout),
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
stride_A = f_get_default_stride(M, K, stride_A, ALayout{});
|
||||
stride_B = f_get_default_stride(K, N, stride_B, BLayout{});
|
||||
stride_C = f_get_default_stride(M, N, stride_C, CLayout{});
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{}));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
|
||||
ck_tile::FillUniformDistributionIntegerValue<ADataType>{-5, 5, /*seed*/ 11939}(a_m_k);
|
||||
ck_tile::FillUniformDistributionIntegerValue<BDataType>{-5, 5, /*seed*/ 11940}(b_k_n);
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(M, N, stride_C, CLayout{}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, num_accumulations_per_tile, max_accumulated_value);
|
||||
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
};
|
||||
};
|
||||
@@ -77,6 +77,26 @@ TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
|
||||
expected_partials_size + expected_flags_size);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileLowerValue)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 1);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseEstimateNumWgsPerTile, EstimateNumWgsPerTileEqualValue)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.estimate_num_wgs_per_tile(), 2);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter)
|
||||
{
|
||||
// Types
|
||||
|
||||
@@ -194,6 +194,23 @@ struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitioner
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyWith2WgsPerSKTile
|
||||
: public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 16;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 8;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
|
||||
Reference in New Issue
Block a user