mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Stream-K operator() Reboot (#3064)
* Persistent Stream-K Kernel Implementation This change implements an operator() function in the reboot::StreamKKernel class that is enabled when the Persistent flag is set to true. In this case, the data-parallel portion and the Stream-K portion of the kernel are fully persistent. The changes were made in the reboot namespace. A future PR will remove the old Stream-K kernel class and remove the reboot namespace. * Unit Tests for Persistent Stream-K Kernel This change contains the inital test suite for the Persitent Stream-K Kernel. The files contain "reboot" in the name; a future PR will remove tests for the old Stream-K Kernel and remove the "reboot" naming. A future commit will add tests for the non-persistent kernel. Also added estimate_num_wgs_per_tile to the StreamKTilePartitionerBase class. This allows us to estimate the number of accumulations done per macro tile in C to use during validation when computing relative and absolute tolerance. * Adding implementation for the Non-Persistent Stream-K kernel This code is adding the operator() function for the Non-Persistent Stream-K kernel. Persistency of the kernel is determined through a template argument. The Non-Persistent kernel will allocate additional workgroups for the data parallel section, leading to a different structure for processing the data parallel and Stream-K sections. There has been an addition to the TilePartitioner to get access to the whether Persistent has been set to true or false in the StreamKKernel. * Adding in the tests for the Non-Persistent Stream-K kernel * Refactor Stream-K Reboot Unit Tests This commit makes the following changes: - Update test cases to determine M, N, and K based on the number of CUs. This ensures that each test case is one of Edge Case, SK Only, DP Only, or DP + 2 Tile SK regardless of the architecture. - Since the DP + 2 Tile SK test case takes long to run, this change moves this case into a separate .inc file and labels it as an extended test. - Since the extended test takes > 30 seconds to run, this test is added to the list of regression tests. * Fix spelling errors in comments for test cases Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Changes based on review Removed const volatile for typenames Set up alias for is_tuple_t Naming changes for clarity: GemmCommon -> BaseGemm Moved std::enable_if_t out of template parameters and changed to a return type for operator() Added constructor for StreamKKernelArgs to clarify UniversalGemm inheritance --------- Co-authored-by: Emily Martins <emily.martins@amd.com> Co-authored-by: Christopher Millette <63608002+cgmillette@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -8,6 +8,478 @@
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reboot {
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
|
||||
/// arguments object. It contains all necessary information required to build proper kernel
|
||||
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
StreamKReductionStrategy reduction_strategy_)
|
||||
: UniversalGemmHostArgs<>({a_ptr_},
|
||||
{b_ptr_},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr_,
|
||||
/*k_batch_ =*/1,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{/*stride_Ds_*/},
|
||||
stride_C_),
|
||||
reduction_strategy{reduction_strategy_}
|
||||
{
|
||||
}
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy;
|
||||
};
|
||||
|
||||
/// @brief The Stream K GEMM kernel class.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
|
||||
// The main kernel functions are the operator() functions. There is one for Persistent
|
||||
// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
|
||||
//
|
||||
// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
|
||||
// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
|
||||
// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
|
||||
// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct StreamKKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
static constexpr bool PersistentDP = UniversalGemmKernel::PersistentKernel;
|
||||
|
||||
using TilePartitioner = TilePartitioner_;
|
||||
using GemmPipeline = GemmPipeline_;
|
||||
using EpiloguePipeline = EpiloguePipeline_;
|
||||
|
||||
static_assert(
|
||||
TilePartitioner::PERSISTENT == PersistentDP,
|
||||
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, and C
|
||||
using ALayout = typename GemmPipeline::ALayout;
|
||||
using BLayout = typename GemmPipeline::BLayout;
|
||||
using CLayout = typename GemmPipeline::CLayout;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
|
||||
"CLayout and CDataType must be scalars.");
|
||||
|
||||
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
StreamKKernelArgs(const StreamKHostArgs& host_args, index_t grid)
|
||||
: UniversalGemmKernelArgs{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
reduction_strategy{host_args.reduction_strategy},
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
workspace_ptr{nullptr},
|
||||
tile_partitioner{TilePartitioner{host_args.M, host_args.N, host_args.K, grid}}
|
||||
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief The strategy used by work groups to compute final results in C tensor.
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial via reduction
|
||||
/// strategy.
|
||||
void* workspace_ptr;
|
||||
/// @brief An instance of the TilePartioner class for assisting with mapping workgroups to
|
||||
/// the C tensor.
|
||||
TilePartitioner tile_partitioner;
|
||||
};
|
||||
|
||||
using KernelArgs = StreamKKernelArgs;
|
||||
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
using P_ = GemmPipeline;
|
||||
using WarpTile = typename P_::BlockGemmShape::WarpTile;
|
||||
|
||||
return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
|
||||
concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
|
||||
concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})),
|
||||
concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
|
||||
concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
|
||||
/// @return The grid size.
|
||||
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
|
||||
{
|
||||
return tile_partitioner.grid_size();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
/// @return The maximum occupancy grid size.
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
const index_t grid = num_cu * occupancy;
|
||||
|
||||
return StreamKKernelArgs{host_args, grid};
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const typename UniversalGemmKernel::KernelArgs& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
|
||||
// tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
/// @return The buffer size needed.
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
|
||||
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
/// @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
|
||||
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
|
||||
/// workgroup will perform in the C tile.
|
||||
/// @param i_k_a The K offset in the A tensor.
|
||||
/// @param i_k_b The K offset in the B tensor.
|
||||
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
|
||||
/// `tile_idx`.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
|
||||
index_t tile_idx,
|
||||
index_t num_loop,
|
||||
index_t i_k_a,
|
||||
index_t i_k_b,
|
||||
index_t k_size,
|
||||
void* smem_ptr_0) const
|
||||
{
|
||||
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
index_t i_m = c_macro_tile_idx[UniversalGemmKernel::I0] * TilePartitioner::MPerBlock;
|
||||
index_t i_n = c_macro_tile_idx[UniversalGemmKernel::I1] * TilePartitioner::NPerBlock;
|
||||
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
RunGemm(
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
|
||||
}
|
||||
|
||||
/// @brief Runs the main Stream-K algorithm.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param cta_idx The current Stream-K workgroup's index.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
|
||||
/// should be something like `blockIdx.x` minus number of DP workgroups.
|
||||
CK_TILE_DEVICE void
|
||||
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
|
||||
{
|
||||
index_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
|
||||
|
||||
while(iter_start < iter_end)
|
||||
{
|
||||
// Get the 1D tile index in the C tensor that this workgroup will work in for this
|
||||
// iteration of the loop.
|
||||
index_t tile_idx =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
|
||||
|
||||
// Get the start and end boundaries for the current tile.
|
||||
index_t tile_iter_start, tile_iter_end;
|
||||
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
|
||||
|
||||
// Get the start and end iteration within the current tile for the workgroup.
|
||||
index_t local_iter_start = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
|
||||
index_t local_iter_end =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
|
||||
tile_iter_start, iter_end, tile_iter_end));
|
||||
|
||||
// Get the iteration length.
|
||||
index_t num_loop_sk = local_iter_end - local_iter_start;
|
||||
|
||||
// Determine the total size along the K dimension the workgroup is using in this
|
||||
// iteration (used to construct tensor views).
|
||||
index_t k_size = num_loop_sk * TilePartitioner::KPerBlock;
|
||||
|
||||
// Get the K offsets for the A and B tensors
|
||||
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
|
||||
local_iter_start, kargs.stride_As[0], kargs.stride_Bs[0]);
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: Apply reduction logic.
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Non-Persistent kernel, each data parallel workgroup will
|
||||
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
|
||||
/// The Stream-K workgroups will do their assigned work by calling
|
||||
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
index_t dp_ctas = kargs.tile_partitioner.get_dp_ctas();
|
||||
bool is_dp_ctas = block_idx < kargs.tile_partitioner.get_dp_ctas();
|
||||
|
||||
// Check if at the data parallel section
|
||||
if(is_dp_ctas)
|
||||
{
|
||||
BaseGemm(kargs, block_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Stream-K
|
||||
StreamKGemm(kargs, block_idx - dp_ctas, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Persistent kernel, each workgroup will first compute their
|
||||
/// assigned data-parallel tiles. Each data parallel tile will be computed
|
||||
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
|
||||
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
|
||||
/// in the Stream-K loop.
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
index_t block_idx = ck_tile::get_block_1d_id();
|
||||
index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
|
||||
// Data-parallel section
|
||||
for(index_t tile_idx = block_idx; tile_idx < kargs.tile_partitioner.get_dp_tiles();
|
||||
tile_idx += kargs.tile_partitioner.get_grid())
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, dp_num_loop, 0, 0, kargs.K, smem_ptr_0);
|
||||
}
|
||||
|
||||
// Stream-K section
|
||||
StreamKGemm(kargs, block_idx, smem_ptr_0);
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
{
|
||||
index_t stride_offset_a;
|
||||
index_t stride_offset_b;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
stride_offset_a = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_a = 1;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_offset_b = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_b = 1;
|
||||
}
|
||||
|
||||
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
|
||||
|
||||
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static int NumCU()
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
hip_check_error(hipGetDevice(&dev));
|
||||
hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
int num_cu = dev_prop.multiProcessorCount;
|
||||
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
/// @return The occupancy
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static int Occupancy()
|
||||
{
|
||||
int occupancy;
|
||||
|
||||
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
};
|
||||
} // namespace reboot
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
|
||||
@@ -186,6 +186,11 @@ struct StreamKTilePartitionerBase
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns an estimate of the number of workgroups writing to the same macro tile in C.
|
||||
*/
|
||||
CK_TILE_HOST index_t estimate_num_wgs_per_tile() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t num_tiles_;
|
||||
index_t grid_;
|
||||
@@ -246,6 +251,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = true;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
|
||||
* case, no extra workgroups are allocated for the data parallel section, making the grid
|
||||
@@ -292,6 +298,7 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, fals
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = false;
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
|
||||
* case, extra workgroups are allocated for the data parallel section, making the grid
|
||||
|
||||
@@ -214,6 +214,27 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() c
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_num_wgs_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
|
||||
((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
|
||||
Reference in New Issue
Block a user