mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
[CK_TILE] Remove Old CK Tile Stream-K Artifacts (#3202)
* Remove old CK Tile Stream-K implementation

The original CK Stream-K implementation was based on old CK's Stream-K block-to-C-tile map. However, that implementation did not align with the original Stream-K paper. We therefore implemented a new tile partitioner and an associated Stream-K kernel, which were placed in the reboot namespace. Now that the new Stream-K implementation is ready, this change removes all artifacts of the old implementation. Specifically, the following changes were made:
- Remove the old Stream-K tile partitioner from CK Tile.
- Remove the reboot namespace so that the new implementation resides in the ck_tile namespace only.
- Add tests for bf8 and fp8 using the new implementation.
- Remove the tests for the old implementation.
- Remove the v2 suffix from the new CK Tile tile partitioner derived classes.
- Update the Stream-K kernel ops file to use the /** commenting style.

* Remove v2 from tile partitioner validation function names
This commit is contained in:
@@ -71,16 +71,16 @@ invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
bool flush_cache,
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy)
|
||||
{
|
||||
ck_tile::reboot::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy};
|
||||
ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
reduction_strategy};
|
||||
|
||||
std::tuple<float, ck_tile::index_t> ave_time_and_batch;
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ template <typename GemmConfig,
|
||||
typename ELayout,
|
||||
typename CDEElementWise,
|
||||
ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs& args,
|
||||
std::tuple<float, ck_tile::index_t> gemm(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
@@ -28,7 +28,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
|
||||
ck_tile::StreamKTilePartitioner<GemmShape, ReductionStrategy, GemmConfig::Persistent>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -77,7 +77,7 @@ std::tuple<float, ck_tile::index_t> gemm(const ck_tile::reboot::StreamKHostArgs&
|
||||
memory_operation.value,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
|
||||
using Kernel = ck_tile::reboot::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
|
||||
@@ -11,33 +11,4 @@ enum StreamKReductionStrategy : uint32_t
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Estimates the number of Stream-K workgroups per macro tile in the C tensor.
|
||||
*
|
||||
* @param sk_ctas Number of Stream-K workgroups.
|
||||
* @param iters_per_sk_cta Number of iterations per Stream-K workgroup.
|
||||
* @param iters_per_tile Number of iterations per tile (i.e., the number of macro tiles in the K
|
||||
* dimension).
|
||||
* @return ck_tile::index_t An estimate of the number of workgroups per macro tile in the C tensor.
|
||||
* @note It is assumed that `iters_per_sk_cta` > 0.
|
||||
*/
|
||||
template <ck_tile::StreamKReductionStrategy ReductionStrategy>
|
||||
ck_tile::index_t
|
||||
estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
|
||||
{
|
||||
// In the case of non-atomic reduction or data-parallel only, there will always be 1 workgroup
|
||||
// writing final results to a given macro tile in C.
|
||||
int num_wgs_per_tile = 1;
|
||||
|
||||
// Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
|
||||
if(sk_ctas > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
// Estimate the number of workgroups per macro tile.
|
||||
num_wgs_per_tile =
|
||||
(iters_per_tile / iters_per_sk_cta) + ((iters_per_tile % iters_per_sk_cta) != 0);
|
||||
}
|
||||
|
||||
return std::max(num_wgs_per_tile, 1);
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -364,448 +364,4 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
index_t M;
|
||||
index_t N;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Stream-K tile partitioner that dynamically balances work across workgroups
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm which decomposes the GEMM problem
|
||||
* into smaller work units and distributes them more evenly across available blocks,
|
||||
* improving load balancing especially for cases where the K dimension is large.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy A StreamKReductionStrategy value that selects the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
* @tparam TileSwizzleSubM A value that defines the size of the swizzle group along the m
|
||||
* dimension, where the swizzle group denotes consecutive tiles down a column. For instance a
|
||||
* swizzle group of 8 denotes tiles 0, 1, ..., 7, map to tiles [0,0], [1,0], ..., [7,0] in the C
|
||||
* tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategy = ck_tile::StreamKReductionStrategy::Atomic,
|
||||
uint32_t TileSwizzleSubM = 8>
|
||||
struct StreamKTilePartitioner
|
||||
{
|
||||
static constexpr uint32_t MPerBlock = BlockGemmShapeType::kM;
|
||||
static constexpr uint32_t NPerBlock = BlockGemmShapeType::kN;
|
||||
static constexpr uint32_t KPerBlock = BlockGemmShapeType::kK;
|
||||
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner() noexcept = delete;
|
||||
|
||||
/**
|
||||
* @brief Construct Stream-K tile partitioner with problem dimensions
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE StreamKTilePartitioner(uint32_t M,
|
||||
uint32_t N,
|
||||
uint32_t K,
|
||||
uint32_t num_cu,
|
||||
uint32_t occupancy,
|
||||
uint32_t sk_blocks = 0xffffffff) noexcept
|
||||
: M_(M), N_(N), K_(K)
|
||||
{
|
||||
num_tile_m_ = integer_divide_ceil(M, MPerBlock);
|
||||
num_tile_n_ = integer_divide_ceil(N, NPerBlock);
|
||||
num_tile_k_ = integer_divide_ceil(K, KPerBlock);
|
||||
|
||||
constexpr uint32_t min_k_iters_per_sk_block = 2;
|
||||
uint32_t num_tiles = num_tile_m_ * num_tile_n_;
|
||||
k_iters_per_tile = mdiv(num_tile_k_);
|
||||
|
||||
// one cu can hold one wg at one time, from the whole chip's point of view
|
||||
// if number of wg is same as num_cu, we call it 1 dispatch
|
||||
// if number of wg is 2x num_cu, we call it 2 dispatches.
|
||||
// one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial
|
||||
// dispatch)
|
||||
//
|
||||
const uint32_t full_dispatches = num_tiles / num_cu;
|
||||
const uint32_t full_dispatch_tiles = full_dispatches * num_cu;
|
||||
const uint32_t partial_dispatch_tiles = num_tiles - full_dispatch_tiles;
|
||||
|
||||
uint32_t sk_occupancy = occupancy;
|
||||
uint32_t dp_tiles = full_dispatch_tiles;
|
||||
uint32_t sk_tiles = partial_dispatch_tiles;
|
||||
|
||||
if(full_dispatches < occupancy)
|
||||
{
|
||||
// in this case, we allocate all blocks as sk blocks
|
||||
// sk_occupancy = occupancy - full_dispatches;
|
||||
sk_occupancy = 1;
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatch_tiles;
|
||||
}
|
||||
else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1))
|
||||
{
|
||||
// e.g. occupancy = 2, full_dispatches = 3, 5, 7 ...
|
||||
// occupancy = 3, full_dispatches = 5, 8, 11 ...
|
||||
// occupancy = 4, full_dispatches = 7, 11 ...
|
||||
sk_occupancy = 1; // left 1 slot for sk occupancy
|
||||
dp_tiles = full_dispatch_tiles;
|
||||
sk_tiles = partial_dispatch_tiles;
|
||||
}
|
||||
else
|
||||
{
|
||||
// otherwise, we reduce 1 dispatch from dp, together with partial dispatch,
|
||||
// to construct sk dispatch
|
||||
sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy);
|
||||
dp_tiles = full_dispatch_tiles - num_cu;
|
||||
sk_tiles = partial_dispatch_tiles + num_cu;
|
||||
}
|
||||
|
||||
// uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles;
|
||||
uint32_t dp_num_blocks = 0;
|
||||
|
||||
{
|
||||
const uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1);
|
||||
const uint32_t max_sk_tiles =
|
||||
(sk_tiles >= num_cu) ? num_cu * sk_occupancy
|
||||
: min(num_cu, sk_total_iters / min_k_iters_per_sk_block);
|
||||
|
||||
// if use dp for sk-block, how many iters do we need
|
||||
const uint32_t dp_for_sk_iters = k_iters_per_tile.get();
|
||||
|
||||
uint32_t best_sk_score =
|
||||
std::numeric_limits<int>::max(); // we need to find the smallest sk iters
|
||||
for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles;
|
||||
tentative_sk_blocks++)
|
||||
{
|
||||
const uint32_t tentative_sk_iters_per_block =
|
||||
(sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks;
|
||||
const uint32_t tentative_sk_iters = tentative_sk_iters_per_block;
|
||||
const uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles;
|
||||
|
||||
// the more sk_blocks_per_tile, the worse the overhead
|
||||
uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile;
|
||||
if(tentative_sk_blocks % sk_tiles != 0)
|
||||
{
|
||||
// penalty for uneven divide
|
||||
cross_sk_blocks_overhead +=
|
||||
sk_blocks_per_tile * tentative_sk_iters_per_block / 50;
|
||||
}
|
||||
|
||||
const uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead;
|
||||
|
||||
if(tentative_sk_score < best_sk_score)
|
||||
{
|
||||
best_sk_score = tentative_sk_score;
|
||||
sk_num_blocks = tentative_sk_blocks;
|
||||
}
|
||||
}
|
||||
|
||||
if(best_sk_score >= dp_for_sk_iters)
|
||||
{
|
||||
sk_num_blocks = 0;
|
||||
}
|
||||
|
||||
// give a chance to control num of sk blocks
|
||||
sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks;
|
||||
|
||||
if(sk_num_blocks == 0)
|
||||
{
|
||||
sk_num_big_blocks = 0;
|
||||
k_iters_per_big_block = 0;
|
||||
|
||||
dp_num_blocks = num_tiles; // all tile to be dp block
|
||||
dp_start_block_idx = 0;
|
||||
sk_total_iters = 0; // clear this tiles
|
||||
}
|
||||
else
|
||||
{
|
||||
// k_iters_per_sk_block is the floor of the average number of iterations each sk block performs.
|
||||
// we need to decide how many iters for each sk block
|
||||
// let m = k_iters_per_sk_block
|
||||
// some of the sk block (little) will cover m iters, some (big) will cover m+1
|
||||
// we have
|
||||
// 1) l + b = sk_blocks
|
||||
// 2) l * m + b * (m + 1) = sk_total_iters
|
||||
// => (l + b) * m + b = sk_total_iters
|
||||
// => sk_blocks * m + b = sk_total_iters
|
||||
// => b = sk_total_iters - m * sk_blocks
|
||||
// NOTE: big could be zero
|
||||
const uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks;
|
||||
sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks;
|
||||
k_iters_per_big_block = k_iters_per_sk_block + 1;
|
||||
|
||||
dp_num_blocks = dp_tiles;
|
||||
dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu;
|
||||
}
|
||||
}
|
||||
n_tiles = mdiv2(num_tile_n_);
|
||||
reduction_start_block_idx = dp_start_block_idx + dp_num_blocks;
|
||||
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
const uint32_t upper_big = lcm(k_iters_per_big_block, k_iters_per_tile.get());
|
||||
const uint32_t upper_little = lcm(k_iters_per_big_block - 1, k_iters_per_tile.get());
|
||||
equiv_tiles_big = mdiv(upper_big / k_iters_per_tile.get());
|
||||
equiv_tiles_little = mdiv(upper_little / k_iters_per_tile.get());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate optimal grid size for Stream-K
|
||||
*/
|
||||
CK_TILE_HOST auto GridSize() const noexcept -> dim3
|
||||
{
|
||||
if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return dim3(reduction_start_block_idx + GetSkTiles(), 1, 1);
|
||||
}
|
||||
else
|
||||
return dim3(reduction_start_block_idx, 1, 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate number of loop iterations over K dimension for given work unit
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static auto GetLoopNum(uint32_t K) noexcept -> uint32_t
|
||||
{
|
||||
return integer_divide_ceil(K, KPerBlock); // Stream-K processes one K-slice at a time
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get output tile index for standard 2D mapping (compatibility)
|
||||
*/
|
||||
CK_TILE_DEVICE auto
|
||||
GetOutputTileIndex(uint32_t tile_idx) const noexcept -> tuple<uint32_t, uint32_t>
|
||||
{
|
||||
uint32_t m_tile_idx, n_tile_idx;
|
||||
n_tiles.divmod(tile_idx, num_tile_n_, m_tile_idx, n_tile_idx);
|
||||
|
||||
// swizzle tile
|
||||
|
||||
uint32_t tile_swizzle_sub_m_rem = num_tile_m_ % TileSwizzleSubM;
|
||||
|
||||
const auto sub_m_adapt = (m_tile_idx < (num_tile_m_ - tile_swizzle_sub_m_rem))
|
||||
? TileSwizzleSubM
|
||||
: tile_swizzle_sub_m_rem;
|
||||
|
||||
uint32_t m_tile_idx_sub0, m_tile_idx_sub1;
|
||||
m_tile_idx_sub0 = m_tile_idx / TileSwizzleSubM;
|
||||
m_tile_idx_sub1 = m_tile_idx % TileSwizzleSubM;
|
||||
|
||||
uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * num_tile_n_;
|
||||
|
||||
uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt;
|
||||
|
||||
n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt;
|
||||
m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt;
|
||||
return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * TileSwizzleSubM,
|
||||
n_tile_idx_with_adapt);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get work range for a given block ID
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
GetBlockItr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const noexcept
|
||||
{
|
||||
if(block_idx < sk_num_big_blocks)
|
||||
{
|
||||
iter_start = block_idx * k_iters_per_big_block;
|
||||
iter_end = iter_start + k_iters_per_big_block;
|
||||
}
|
||||
else if(block_idx < sk_num_blocks)
|
||||
{
|
||||
iter_start = (sk_num_big_blocks * k_iters_per_big_block) +
|
||||
(block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
iter_end = iter_start + (k_iters_per_big_block - 1);
|
||||
}
|
||||
else if(block_idx >= dp_start_block_idx)
|
||||
{
|
||||
uint32_t sk_total_iters = GetSkTotalIters();
|
||||
uint32_t dp_iters_per_block = k_iters_per_tile.get();
|
||||
iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block;
|
||||
iter_end = iter_start + dp_iters_per_block;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get total number of iterations for sk tiles
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetSkTotalIters() const noexcept
|
||||
{
|
||||
uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block +
|
||||
(sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1);
|
||||
return sk_total_iters;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get total number of sk tiles
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetSkTiles() const noexcept
|
||||
{
|
||||
// tiles for sk
|
||||
uint32_t sk_total_iters = GetSkTotalIters();
|
||||
return k_iters_per_tile.div(sk_total_iters);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get length of loop iterations for stream-k loop
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start,
|
||||
uint32_t iter_end) const noexcept
|
||||
{
|
||||
// A WG's iter_end is either in the current C macro tile or not.
|
||||
// If it is not, then the macro tile boundary is where the WG must stop.
|
||||
uint32_t distance_to_tile_boundary =
|
||||
k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get());
|
||||
return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get index of tile during a specified iteration
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetTileIdx(uint32_t iter) const noexcept
|
||||
{
|
||||
return k_iters_per_tile.div(iter);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the index of the tile, and the iteration offset within that tile, for a specified iteration
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept
|
||||
{
|
||||
k_iters_per_tile.divmod(iter, tile_idx, iter_offset);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the buffer space needed for accumulation
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForAcc(uint32_t acc_element_bytes) const noexcept
|
||||
{
|
||||
static constexpr uint32_t alignment = 128;
|
||||
uint32_t acc_buffer_bytes =
|
||||
MPerBlock * NPerBlock * GetTotalAccBuffers() * acc_element_bytes;
|
||||
return (acc_buffer_bytes + alignment - 1) / alignment * alignment;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the buffer space needed for the semaphore
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSizeForSemaphore() const noexcept
|
||||
{
|
||||
return GetSkTiles() * sizeof(uint32_t);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the total buffer space needed for accumulation and the semaphore
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetWorkSpaceSize(uint32_t acc_element_bytes) const noexcept
|
||||
{
|
||||
return GetWorkSpaceSizeForAcc(acc_element_bytes) + GetWorkSpaceSizeForSemaphore();
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get location of intersection of tiles for reduction
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTileIntersections(uint32_t tiles_,
|
||||
const mdiv& equiv_tiles_) const noexcept
|
||||
{
|
||||
uint32_t tile_idx_ = tiles_ == 0 ? 0 : (tiles_ - 1);
|
||||
uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1;
|
||||
uint32_t quo_, rem_;
|
||||
equiv_tiles_.divmod(tile_idx_, quo_, rem_);
|
||||
return quo_ * max_equiv_tiles_ + rem_;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate the number of tiles needed for the number of sk blocks
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTilesCoverSkBlock(uint32_t num_sk_blocks_,
|
||||
uint32_t iters_per_sk_block_) const noexcept
|
||||
{
|
||||
return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() -
|
||||
1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate the amount of total accumulation buffers required for stream-k
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE uint32_t GetTotalAccBuffers() const noexcept
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block);
|
||||
uint32_t tiles_cover_little_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1);
|
||||
|
||||
uint32_t total_intersec_big = GetTileIntersections(tiles_cover_big_blocks, equiv_tiles_big);
|
||||
uint32_t total_intersec_little =
|
||||
GetTileIntersections(tiles_cover_little_blocks, equiv_tiles_little);
|
||||
|
||||
return sk_num_blocks + total_intersec_big + total_intersec_little;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate offset based on tile index for big/little tiles
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromTile(uint32_t tile_idx_) const noexcept
|
||||
{
|
||||
uint32_t tiles_cover_big_blocks =
|
||||
GetTilesCoverSkBlock(sk_num_big_blocks, k_iters_per_big_block);
|
||||
if(tile_idx_ < tiles_cover_big_blocks)
|
||||
{
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) /
|
||||
k_iters_per_big_block;
|
||||
uint32_t current_intersec = GetTileIntersections(tile_idx_, equiv_tiles_big);
|
||||
return touched_sk_blocks + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
uint32_t tile_idx_little_reverse = GetSkTiles() - tile_idx_;
|
||||
uint32_t touched_sk_blocks =
|
||||
(tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) /
|
||||
iters_per_little_sk_block;
|
||||
uint32_t current_intersec =
|
||||
GetTileIntersections(tile_idx_little_reverse, equiv_tiles_little);
|
||||
return GetTotalAccBuffers() - (touched_sk_blocks + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculate offset based on block_idx index for big/little streamk blocks
|
||||
*/
|
||||
CK_TILE_DEVICE uint32_t GetAccBufferOffsetFromBlock(uint32_t block_idx_) const noexcept
|
||||
{
|
||||
uint32_t iters_per_big_sk_block = k_iters_per_big_block;
|
||||
uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1;
|
||||
if(block_idx_ < sk_num_big_blocks)
|
||||
{
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block +
|
||||
k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_big);
|
||||
return block_idx_ + current_intersec;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_;
|
||||
uint32_t touched_tiles = k_iters_per_tile.div(
|
||||
block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1);
|
||||
uint32_t current_intersec = GetTileIntersections(touched_tiles, equiv_tiles_little);
|
||||
return GetTotalAccBuffers() - (block_idx_little_reverse + current_intersec);
|
||||
}
|
||||
}
|
||||
|
||||
// Getters for problem dimensions
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileM() const noexcept { return num_tile_m_; }
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileN() const noexcept { return num_tile_n_; }
|
||||
CK_TILE_HOST_DEVICE uint32_t GetNumTileK() const noexcept { return num_tile_k_; }
|
||||
|
||||
uint32_t sk_num_blocks;
|
||||
uint32_t sk_num_big_blocks;
|
||||
uint32_t dp_start_block_idx;
|
||||
uint32_t reduction_start_block_idx;
|
||||
uint32_t k_iters_per_big_block;
|
||||
mdiv2 n_tiles;
|
||||
mdiv k_iters_per_tile;
|
||||
mdiv equiv_tiles_big; // for reduction
|
||||
mdiv equiv_tiles_little; // for reduction
|
||||
|
||||
private:
|
||||
uint32_t M_, N_, K_;
|
||||
uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -8,15 +8,16 @@
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace reboot {
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
|
||||
/// arguments object. It contains all necessary information required to build proper kernel
|
||||
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
/**
|
||||
* @brief The Stream K GEMM kernel host arguments.
|
||||
*
|
||||
* @par Overview
|
||||
* This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
|
||||
* arguments object. It contains all necessary information required to build proper kernel
|
||||
* arguments and launch the kernel on GPU. This structure defines the GEMM problem
|
||||
* configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
*/
|
||||
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
|
||||
@@ -48,22 +49,26 @@ struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy;
|
||||
};
|
||||
|
||||
/// @brief The Stream K GEMM kernel class.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This class is responsible for the Stream-K kernel, making use of UniversalGemm.
|
||||
// The main kernel functions are the operator() functions. There is one for Persistent
|
||||
// and one for Non-Persistent data parallel sections of the Stream-K algorithm.
|
||||
//
|
||||
// Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
|
||||
// `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
|
||||
// `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
|
||||
// main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
|
||||
/**
|
||||
* @brief The Stream K GEMM kernel class.
|
||||
*
|
||||
* @par Overview
|
||||
* This class is responsible for the Stream-K kernel, making use of UniversalGemm.
|
||||
* The main kernel functions are the operator() functions. There is one for Persistent
|
||||
* and one for Non-Persistent data parallel sections of the Stream-K algorithm.
|
||||
*
|
||||
* Both the Non-Persistent and Persistent kernels make use of `BaseGemm()` and
|
||||
* `StreamKGemm()`. `BaseGemm()` computes offsets into the A,B,C tensors, then calls
|
||||
* `RunGemm()` which runs the GEMM pipeline and epilogue. `StreamKGemm()` performs the
|
||||
* main Stream-K algorithm. Each iteration of the Stream-K loop calls `BaseGemm()`.
|
||||
*/
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct StreamKKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
/**
|
||||
* @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
* functions.
|
||||
*/
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
@@ -78,12 +83,16 @@ struct StreamKKernel
|
||||
TilePartitioner::PERSISTENT == PersistentDP,
|
||||
"Persistent flag from TilePartitioner must match Persistent flag from UniversalGemm.");
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, and C
|
||||
/**
|
||||
* @brief Specify the layout configurations for A, B, and C
|
||||
*/
|
||||
using ALayout = typename GemmPipeline::ALayout;
|
||||
using BLayout = typename GemmPipeline::BLayout;
|
||||
using CLayout = typename GemmPipeline::CLayout;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
/**
|
||||
* @brief Specify the data type configurations for A, B, and C
|
||||
*/
|
||||
using ADataType = typename GemmPipeline::ADataType;
|
||||
using BDataType = typename GemmPipeline::BDataType;
|
||||
using CDataType = typename EpiloguePipeline::ODataType;
|
||||
@@ -91,16 +100,21 @@ struct StreamKKernel
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tuple_v = is_detected<is_tuple, T>::value;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
/**
|
||||
* @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
*/
|
||||
static_assert(!is_tuple_v<ALayout> && !is_tuple_v<ADataType>,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
/**
|
||||
* @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
*/
|
||||
static_assert(!is_tuple_v<BLayout> && !is_tuple_v<BDataType>,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
/**
|
||||
* @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
*/
|
||||
static_assert(!is_tuple_v<CLayout> && !is_tuple_v<CDataType>,
|
||||
"CLayout and CDataType must be scalars.");
|
||||
|
||||
@@ -127,14 +141,19 @@ struct StreamKKernel
|
||||
|
||||
{
|
||||
}
|
||||
|
||||
/// @brief The strategy used by work groups to compute final results in C tensor.
|
||||
/**
|
||||
* @brief The strategy used by work groups to compute final results in C tensor.
|
||||
*/
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial via reduction
|
||||
/// strategy.
|
||||
/**
|
||||
* @brief A pointer to a buffer in device memory for accumulating partial results via the reduction
|
||||
* strategy.
|
||||
*/
|
||||
void* workspace_ptr;
|
||||
/// @brief An instance of the TilePartioner class for assisting with mapping workgroups to
|
||||
/// the C tensor.
|
||||
/**
|
||||
* @brief An instance of the TilePartitioner class for assisting with mapping workgroups to
|
||||
* the C tensor.
|
||||
*/
|
||||
TilePartitioner tile_partitioner;
|
||||
};
|
||||
|
||||
@@ -155,17 +174,21 @@ struct StreamKKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
/// @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
|
||||
/// @return The grid size.
|
||||
/**
|
||||
* @brief Compute the grid size for the Stream K kernel using the tile_partitioner.
|
||||
* @return The grid size.
|
||||
*/
|
||||
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
|
||||
{
|
||||
return tile_partitioner.grid_size();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
/// @return The maximum occupancy grid size.
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
* @return The maximum occupancy grid size.
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
*/
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
@@ -176,13 +199,15 @@ struct StreamKKernel
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
/**
|
||||
* @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
* @param host_args Stream-K host arguments.
|
||||
* @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
* The caller may select their own to assist with test reproducibility, etc.
|
||||
* @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
* select their own to assist with test reproducibility, etc.
|
||||
* @return The kernel arguments for Stream-K.
|
||||
*/
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
@@ -247,30 +272,35 @@ struct StreamKKernel
|
||||
return UniversalGemmKernel::IsSupportedArgument(kargs);
|
||||
}
|
||||
|
||||
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
/// @return The buffer size needed.
|
||||
/**
|
||||
* @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
* @return The buffer size needed.
|
||||
*/
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
|
||||
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
/// @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
/**
|
||||
* @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
|
||||
* @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
*/
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
/// @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param tile_idx The 1D tile index in the C tensor for this workgroup.
|
||||
/// @param num_loop The number of iterations (at the macro tile level) in the K dimension this
|
||||
/// workgroup will perform in the C tile.
|
||||
/// @param i_k_a The K offset in the A tensor.
|
||||
/// @param i_k_b The K offset in the B tensor.
|
||||
/// @param k_size The portion of the K dimension this workgroup processes in the assigned
|
||||
/// `tile_idx`.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
/**
|
||||
* @brief Computes offsets into A, B, and C tensors then runs the GEMM pipeline and epilogue.
|
||||
* @param kargs Stream-K kernel arguments.
|
||||
* @param tile_idx The 1D tile index in the C tensor for this workgroup.
|
||||
* @param num_loop The number of iterations (at the macro tile level) in the K dimension this
|
||||
* workgroup will perform in the C tile.
|
||||
* @param i_k_a The K offset in the A tensor.
|
||||
* @param i_k_b The K offset in the B tensor.
|
||||
* @param k_size The portion of the K dimension this workgroup processes in the assigned
|
||||
* `tile_idx`.
|
||||
* @param smem_ptr_0 Pointer to LDS.
|
||||
*/
|
||||
CK_TILE_DEVICE void BaseGemm(StreamKKernelArgs& kargs,
|
||||
index_t tile_idx,
|
||||
index_t num_loop,
|
||||
@@ -292,12 +322,14 @@ struct StreamKKernel
|
||||
{a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, k_size);
|
||||
}
|
||||
|
||||
/// @brief Signals that the current thread block (CTA) has completed storing its partial
|
||||
/// results.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the current thread block (CTA).
|
||||
/// @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
/// CTA index.
|
||||
/**
|
||||
* @brief Signals that the current thread block (CTA) has completed storing its partial
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
@@ -306,11 +338,13 @@ struct StreamKKernel
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
}
|
||||
|
||||
/// @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
/// set by the given CTA index.
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
@@ -318,11 +352,13 @@ struct StreamKKernel
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
}
|
||||
|
||||
/// @brief Adds the values of a block tile to an output block tile.
|
||||
/// @param in_out_block_tile The output block tile to which values are added.
|
||||
/// @param in_block_tile The input block tile whose values are added.
|
||||
/// @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
/// output block tile with accumulated values.
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
const OAccTile& in_block_tile) const
|
||||
@@ -337,13 +373,15 @@ struct StreamKKernel
|
||||
});
|
||||
}
|
||||
|
||||
/// @brief Loads a partial block tile from the workspace buffer.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @param c_block_tile_dist The tile distribution for the block.
|
||||
/// @return The loaded partial block tile.
|
||||
/// @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
/// the partial block tile.
|
||||
/**
|
||||
* @brief Loads a partial block tile from the workspace buffer.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx,
|
||||
@@ -371,12 +409,14 @@ struct StreamKKernel
|
||||
return load_tile(partial_tile_window);
|
||||
}
|
||||
|
||||
/// @brief Stores a partial block tile to the workspace buffer.
|
||||
/// @param kargs Kernel arguments, including the workspace pointer.
|
||||
/// @param cta_idx The index of the thread block (CTA).
|
||||
/// @param c_block_tile The block tile to be stored.
|
||||
/// @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
/// partial block tile.
|
||||
/**
|
||||
* @brief Stores a partial block tile to the workspace buffer.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx,
|
||||
@@ -404,15 +444,17 @@ struct StreamKKernel
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
}
|
||||
|
||||
/// @brief Runs the main Stream-K algorithm.
|
||||
/// @param kargs Stream-K kernel arguments.
|
||||
/// @param cta_idx The current Stream-K workgroup's index.
|
||||
/// @param smem_ptr_0 Pointer to LDS.
|
||||
/// @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
/// non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
|
||||
/// should be something like `blockIdx.x` minus number of DP workgroups.
|
||||
CK_TILE_DEVICE void
|
||||
StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
|
||||
/**
|
||||
* @brief Runs the main Stream-K algorithm.
|
||||
* @param kargs Stream-K kernel arguments.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @param smem_ptr_0 Pointer to LDS.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent data-parallel (DP) section is used, then a Stream-K workgroup's `cta_idx`
|
||||
* should be something like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE
|
||||
void StreamKGemm(StreamKKernelArgs& kargs, index_t cta_idx, void* smem_ptr_0) const
|
||||
{
|
||||
index_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, cta_idx);
|
||||
@@ -542,13 +584,15 @@ struct StreamKKernel
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with non-persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Non-Persistent kernel, each data parallel workgroup will
|
||||
/// compute the results for their assigned macro-tile by calling `BaseGemm()`.
|
||||
/// The Stream-K workgroups will do their assigned work by calling
|
||||
/// `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
|
||||
/**
|
||||
* @brief Entry point for the Stream-K Kernel with non-persistent DP.
|
||||
*
|
||||
* @par Overview
|
||||
* For the Non-Persistent kernel, each data parallel workgroup will
|
||||
* compute the results for their assigned macro-tile by calling `BaseGemm()`.
|
||||
* The Stream-K workgroups will do their assigned work by calling
|
||||
* `StreamKGemm()`, which calls `BaseGemm()` in the Stream-K loop.
|
||||
*/
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<!U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
@@ -572,14 +616,16 @@ struct StreamKKernel
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel with persistent DP.
|
||||
///
|
||||
/// @par Overview
|
||||
/// For the Persistent kernel, each workgroup will first compute their
|
||||
/// assigned data-parallel tiles. Each data parallel tile will be computed
|
||||
/// by calling `BaseGemm()`. Then the workgroups will proceed with the
|
||||
/// Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
|
||||
/// in the Stream-K loop.
|
||||
/**
|
||||
* @brief Entry point for the Stream-K Kernel with persistent DP.
|
||||
*
|
||||
* @par Overview
|
||||
* For the Persistent kernel, each workgroup will first compute their
|
||||
* assigned data-parallel tiles. Each data parallel tile will be computed
|
||||
* by calling `BaseGemm()`. Then the workgroups will proceed with the
|
||||
* Stream-K portion by calling `StreamKGemm()`, which calls `BaseGemm()`
|
||||
* in the Stream-K loop.
|
||||
*/
|
||||
template <bool U = PersistentDP>
|
||||
CK_TILE_DEVICE typename std::enable_if_t<U> operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
@@ -601,12 +647,14 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
@@ -647,10 +695,12 @@ struct StreamKKernel
|
||||
return num_cu;
|
||||
}
|
||||
|
||||
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
/// @return The occupancy
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
*/
|
||||
CK_TILE_HOST static int Occupancy()
|
||||
{
|
||||
int occupancy;
|
||||
@@ -665,402 +715,4 @@ struct StreamKKernel
|
||||
return max(occupancy, 1);
|
||||
}
|
||||
};
|
||||
} // namespace reboot
|
||||
|
||||
/// @brief The Stream K GEMM kernel host arguments.
|
||||
///
|
||||
/// @par Overview
|
||||
/// This structure is passed to @ref StreamKKernel "StreamKKernel" when creating the kernel
|
||||
/// arguments object. It contains all necessary information required to build proper kernel
|
||||
/// arguments and launch the kernel on GPU. This structure defines the GEMM problem
|
||||
/// configuration by stating all required information like M,N,K sizes and respective strides.
|
||||
struct StreamKHostArgs : public ck_tile::UniversalGemmHostArgs<>
|
||||
{
|
||||
CK_TILE_HOST explicit StreamKHostArgs(const void* a_ptr_,
|
||||
const void* b_ptr_,
|
||||
void* c_ptr_,
|
||||
index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
index_t stride_A_,
|
||||
index_t stride_B_,
|
||||
index_t stride_C_,
|
||||
StreamKReductionStrategy reduction_strategy_,
|
||||
uint32_t num_sk_blocks_ = 0xffffffff)
|
||||
: UniversalGemmHostArgs<>({a_ptr_},
|
||||
{b_ptr_},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr_,
|
||||
/*k_batch_ =*/1,
|
||||
M_,
|
||||
N_,
|
||||
K_,
|
||||
{stride_A_},
|
||||
{stride_B_},
|
||||
{/*stride_Ds_*/},
|
||||
stride_C_),
|
||||
reduction_strategy{reduction_strategy_},
|
||||
num_sk_blocks{num_sk_blocks_}
|
||||
{
|
||||
}
|
||||
|
||||
ck_tile::StreamKReductionStrategy reduction_strategy;
|
||||
uint32_t num_sk_blocks;
|
||||
};
|
||||
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct StreamKKernel
|
||||
{
|
||||
/// @brief Inject the UniversalGemmKernel base class to support execution of all necessary
|
||||
/// functions.
|
||||
using UniversalGemmKernel =
|
||||
UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
|
||||
|
||||
static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize;
|
||||
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
/// @brief Specify the layout configurations for A, B, and C
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
|
||||
/// @brief Specify the data type configurations for A, B, and C
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
/// @brief ALayout and ADataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, ALayout>::value &&
|
||||
!is_detected<is_tuple, ADataType>::value,
|
||||
"ALayout and ADataType must be scalars.");
|
||||
|
||||
/// @brief BLayout and BDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, BLayout>::value &&
|
||||
!is_detected<is_tuple, BDataType>::value,
|
||||
"BLayout and BDataType must be scalars.");
|
||||
|
||||
/// @brief CLayout and CDataType are expected to be scalars, not a tuple.
|
||||
static_assert(!is_detected<is_tuple, CLayout>::value &&
|
||||
!is_detected<is_tuple, CDataType>::value,
|
||||
"CLayout and CDataType must be scalars.");
|
||||
|
||||
/// @brief Kernel-side arguments for Stream-K: the universal GEMM kernel arguments extended with
/// the Stream-K specific fields below.
struct StreamKKernelArgs : ck_tile::UniversalGemmKernelArgs<>
|
||||
{
|
||||
/// @brief The strategy used by workgroups to compute the final results in the C tensor.
|
||||
StreamKReductionStrategy reduction_strategy;
|
||||
/// @brief The number of Stream-K blocks.
|
||||
uint32_t num_sk_blocks;
|
||||
/// @brief A pointer to a buffer in device memory for accumulating partial results via the
|
||||
/// reduction strategy.
|
||||
void* workspace_ptr;
|
||||
/// @brief An instance of the TilePartitioner class for assisting with mapping workgroups to
|
||||
/// the C tensor.
|
||||
TilePartitioner tile_partitioner;
|
||||
};
|
||||
|
||||
using KernelArgs = StreamKKernelArgs;
|
||||
using Kernel = StreamKKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;
|
||||
|
||||
/// @brief Builds the descriptive name of this kernel instance.
/// @return An underscore-joined string consisting of "streamk", the A/B precision string, the
///         block tile sizes, the warp tile sizes, the A/B/C vector sizes and the M/N/K padding
///         flags (each group joined with 'x').
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    using Pipeline = GemmPipeline;
    using WarpTile = typename Pipeline::BlockGemmShape::WarpTile;

    const auto block_tile_str = concat('x', Pipeline::MPerBlock, Pipeline::NPerBlock, Pipeline::KPerBlock);
    const auto warp_tile_str  = concat('x', WarpTile::at(number<0>{}), WarpTile::at(number<1>{}), WarpTile::at(number<2>{}));
    const auto vector_str     = concat('x', Pipeline::GetVectorSizeA(), Pipeline::GetVectorSizeB(), Pipeline::GetVectorSizeC());
    const auto padding_str    = concat('x', Pipeline::kPadM, Pipeline::kPadN, Pipeline::kPadK);

    return concat('_', "streamk", gemm_prec_str<ADataType, BDataType>(),
                  block_tile_str,
                  warp_tile_str,
                  vector_str,
                  padding_str);
    // clang-format on
}
|
||||
|
||||
/// @brief Computes the grid size for the Stream K kernel by delegating to the tile partitioner.
/// @param tile_partitioner Partitioner whose `GridSize()` result is returned unchanged.
|
||||
/// @return The grid size reported by `tile_partitioner`.
|
||||
CK_TILE_HOST static auto GridSize(const TilePartitioner& tile_partitioner) -> dim3
|
||||
{
|
||||
return tile_partitioner.GridSize();
|
||||
}
|
||||
|
||||
/// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
/// @param s Stream configuration forwarded to `UniversalGemmKernel::MaxOccupancyGridSize`.
|
||||
/// @return The maximum occupancy grid size.
|
||||
/// @note This function only forwards to `UniversalGemmKernel::MaxOccupancyGridSize`; the
|
||||
/// occupancy query itself (`hipOccupancyMaxActiveBlocksPerMultiprocessor`) is expected to
/// happen there and is not visible in this function.
|
||||
CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
|
||||
/// @brief Returns the workgroup (block) dimensions used to launch the kernel.
/// @return The value of `UniversalGemmKernel::BlockSize()`.
CK_TILE_HOST static constexpr auto BlockSize() -> dim3
|
||||
{
|
||||
return UniversalGemmKernel::BlockSize();
|
||||
}
|
||||
|
||||
/// @brief Constructs kernel arguments for the Stream-K kernel.
|
||||
/// @param host_args Stream-K host arguments.
|
||||
/// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device.
|
||||
/// The caller may select their own to assist with test reproducibility, etc.
|
||||
/// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may
|
||||
/// select their own to assist with test reproducibility, etc.
|
||||
/// @return The kernel arguments for Stream-K.
|
||||
CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args,
|
||||
int num_cu = NumCU(),
|
||||
int occupancy = Occupancy())
|
||||
{
|
||||
return StreamKKernelArgs{{host_args.as_ptr,
|
||||
host_args.bs_ptr,
|
||||
host_args.ds_ptr,
|
||||
host_args.e_ptr,
|
||||
host_args.M,
|
||||
host_args.N,
|
||||
host_args.K,
|
||||
host_args.stride_As,
|
||||
host_args.stride_Bs,
|
||||
host_args.stride_Ds,
|
||||
host_args.stride_E,
|
||||
host_args.k_batch},
|
||||
host_args.reduction_strategy,
|
||||
host_args.num_sk_blocks,
|
||||
// The workspace pointer is set to nullptr because we must first
|
||||
// instantiate the TilePartitioner to get the necessary size
|
||||
/*workspace_ptr =*/nullptr,
|
||||
TilePartitioner{static_cast<uint32_t>(host_args.M),
|
||||
static_cast<uint32_t>(host_args.N),
|
||||
static_cast<uint32_t>(host_args.K),
|
||||
static_cast<uint32_t>(num_cu),
|
||||
static_cast<uint32_t>(occupancy),
|
||||
host_args.num_sk_blocks}};
|
||||
}
|
||||
|
||||
template <bool UseDefaultScheduler = true>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemm(const std::array<const ADataType*, UniversalGemmKernel::NumATensor>& as_ptr,
|
||||
const std::array<const BDataType*, UniversalGemmKernel::NumBTensor>& bs_ptr,
|
||||
const std::array<const void*, UniversalGemmKernel::NumDTensor>& ds_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const typename UniversalGemmKernel::KernelArgs& kargs,
|
||||
const index_t num_loop,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n,
|
||||
const index_t k_size)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
UniversalGemmKernel::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size);
|
||||
|
||||
const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0);
|
||||
const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2);
|
||||
|
||||
// Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute
|
||||
// has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this
|
||||
// case, we call the GemmPipeline's operator() function that takes both has_hot_loop and
|
||||
// tail_num.
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0],
|
||||
bs_block_window[UniversalGemmKernel::I0],
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_num,
|
||||
smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Checks whether the given kernel arguments can be run by this kernel.
/// @param kargs Stream-K kernel arguments.
/// @return false when the reduction strategy is `Reduction` (an error is logged if
///         CK_TILE_LOGGING is enabled); otherwise the result of
///         `UniversalGemmKernel::IsSupportedArgument`.
CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs)
{
    const bool wants_reduction =
        kargs.reduction_strategy == StreamKReductionStrategy::Reduction;

    if(!wants_reduction)
    {
        return UniversalGemmKernel::IsSupportedArgument(kargs);
    }

    if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
    {
        CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy.");
    }
    return false;
}
|
||||
|
||||
/// @brief Computes the buffer size needed to store accumulation results for Stream K.
|
||||
/// @return The buffer size needed.
|
||||
CK_TILE_HOST static uint32_t GetWorkSpaceSize(const StreamKKernelArgs& kargs)
|
||||
{
|
||||
// For reduction, we need to determine the amount of device space for acculumation
|
||||
// results and semaphores.
|
||||
if(kargs.reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
return kargs.tile_partitioner.GetWorkSpaceSize(sizeof(CDataType));
|
||||
}
|
||||
|
||||
// Otherwise, no additional space is needed since blocks atomically store their results.
|
||||
return 0;
|
||||
}
|
||||
|
||||
/// @brief Sets the kargs' current workspace_ptr to the given workspace_ptr.
/// @param kargs Kernel arguments whose `workspace_ptr` member is overwritten.
/// @param workspace_ptr The pointer to store; it is stored as-is, without any validation.
|
||||
/// @note Assumes that the given workspace_ptr points to allocated device memory.
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(StreamKKernelArgs& kargs, void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
/// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop.
|
||||
CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const
|
||||
{
|
||||
// Allocate LDS
|
||||
__shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()];
|
||||
|
||||
uint32_t block_idx = ck_tile::get_block_1d_id();
|
||||
|
||||
bool is_padding_block =
|
||||
amd_wave_read_first_lane(block_idx >= kargs.tile_partitioner.sk_num_blocks &&
|
||||
block_idx < kargs.tile_partitioner.dp_start_block_idx);
|
||||
|
||||
// Padding blocks make it such that the DP blocks are aligned with the number of CUs; they
|
||||
// should not partake in the GEMM
|
||||
if(is_padding_block)
|
||||
return;
|
||||
|
||||
// Determine the K offset of the first and final macro tile in the A and B tensors along the
|
||||
// K dimension.
|
||||
uint32_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end);
|
||||
|
||||
// Main Stream-K loop
|
||||
while(true)
|
||||
{
|
||||
// Determine the number of macro tiles in A and B this WG is resposible for in the
|
||||
// current C macro tile.
|
||||
uint32_t current_iter_length = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end));
|
||||
|
||||
// Determine the 1D tile_idx and the iter_offset for this WG.
|
||||
// The tile_idx is the 1D macro tile index in the C tensor.
|
||||
// The iter_offset is the starting macro tile index in the K dimension for the WG in the
|
||||
// current iteration of the while loop.
|
||||
uint32_t tile_idx, iter_offset;
|
||||
kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset);
|
||||
|
||||
// Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx)
|
||||
auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx);
|
||||
|
||||
// Get the offsets in A, B, C tensors.
|
||||
index_t i_m = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I0] *
|
||||
TilePartitioner::MPerBlock);
|
||||
index_t i_n = static_cast<index_t>(spatial_idx[UniversalGemmKernel::I1] *
|
||||
TilePartitioner::NPerBlock);
|
||||
auto [i_k_a, i_k_b] = GetKOffsets<ALayout, BLayout>(
|
||||
static_cast<index_t>(iter_offset), kargs.stride_As[0], kargs.stride_Bs[0]);
|
||||
|
||||
// Determine the total size along the K dimension the WG is using in this iteration
|
||||
// (used to construct tensor views).
|
||||
index_t k_size = static_cast<index_t>(current_iter_length * TilePartitioner::KPerBlock);
|
||||
|
||||
// Update pointer offsets for A, B, and C.
|
||||
const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) + i_k_a;
|
||||
const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) + i_k_b;
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);
|
||||
|
||||
// Run the GEMM pipeline and Epilogue.
|
||||
RunGemm({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
current_iter_length,
|
||||
i_m,
|
||||
i_n,
|
||||
k_size);
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start += current_iter_length;
|
||||
if(iter_end <= iter_start)
|
||||
break;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
/// @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
/// the starting macro tile index in the K dimension for the workgroup.
|
||||
/// @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
/// of A and B.
|
||||
/// @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
/// major.
|
||||
template <typename ALayout, typename BLayout>
|
||||
CK_TILE_DEVICE static tuple<index_t, index_t>
|
||||
GetKOffsets(index_t iter_offset, index_t stride_a, index_t stride_b)
|
||||
{
|
||||
index_t stride_offset_a;
|
||||
index_t stride_offset_b;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
stride_offset_a = stride_a;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_a = 1;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
stride_offset_b = stride_b;
|
||||
}
|
||||
else
|
||||
{
|
||||
stride_offset_b = 1;
|
||||
}
|
||||
|
||||
index_t base_offset = iter_offset * TilePartitioner::KPerBlock;
|
||||
|
||||
return make_tuple(base_offset * stride_offset_a, base_offset * stride_offset_b);
|
||||
}
|
||||
|
||||
/// @brief Queries the number of compute units of the current device.
/// @return The `multiProcessorCount` field of the current device's properties.
CK_TILE_HOST static int NumCU()
{
    hipDevice_t current_device;
    hip_check_error(hipGetDevice(&current_device));

    hipDeviceProp_t properties;
    hip_check_error(hipGetDeviceProperties(&properties, current_device));

    return properties.multiProcessorCount;
}
|
||||
|
||||
/// @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
/// @return The occupancy
|
||||
/// @note This function queries the maximum occupancy of the kernel using
|
||||
/// `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
CK_TILE_HOST static int Occupancy()
|
||||
{
|
||||
int occupancy;
|
||||
|
||||
// Since occupancy of 1 is valid for stream k, we set min_num_block_per_cu to 1
|
||||
constexpr int min_block_per_cu = 1;
|
||||
const auto kernel = kentry<min_block_per_cu, Kernel, KernelArgs>;
|
||||
|
||||
hip_check_error(
|
||||
hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
|
||||
|
||||
return occupancy;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -226,7 +226,7 @@ struct StreamKTilePartitionerBase
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
struct StreamKTilePartitioner;
|
||||
|
||||
/**
|
||||
* @brief Persistent Stream-K tile partitioner derived struct.
|
||||
@@ -240,13 +240,13 @@ struct StreamKTilePartitioner_v2;
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
StreamKTilePartitioner(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = true;
|
||||
@@ -287,13 +287,13 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>
|
||||
struct StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
StreamKTilePartitioner(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
static constexpr bool PERSISTENT = false;
|
||||
|
||||
@@ -238,15 +238,12 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::estimate_
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
struct StreamKTilePartitioner;
|
||||
|
||||
// child class for Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::StreamKTilePartitioner(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
|
||||
@@ -255,8 +252,8 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::grid_size()
|
||||
const noexcept -> dim3
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::grid_size() const noexcept
|
||||
-> dim3
|
||||
{
|
||||
if(extra_dp_tiles_ == 0)
|
||||
{
|
||||
@@ -270,7 +267,7 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::grid
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_dp_tiles_per_cta()
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_dp_tiles_per_cta()
|
||||
const noexcept
|
||||
{
|
||||
return dp_tiles_per_cta_;
|
||||
@@ -278,7 +275,7 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_extra_dp_tiles()
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, true>::get_extra_dp_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return extra_dp_tiles_;
|
||||
@@ -286,11 +283,8 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_
|
||||
|
||||
// child class for Non-Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::StreamKTilePartitioner(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_ctas_ = this->dp_tiles_;
|
||||
@@ -300,15 +294,15 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::grid_size()
|
||||
const noexcept -> dim3
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::grid_size() const noexcept
|
||||
-> dim3
|
||||
{
|
||||
return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::get_dp_ctas()
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::get_dp_ctas()
|
||||
const noexcept
|
||||
{
|
||||
return dp_ctas_;
|
||||
@@ -316,16 +310,16 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::get
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_dp_start_block_idx() const noexcept
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::get_dp_start_block_idx()
|
||||
const noexcept
|
||||
{
|
||||
return dp_start_block_idx_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_sk_start_block_idx() const noexcept
|
||||
StreamKTilePartitioner<BlockGemmShapeType, ReductionStrategyType, false>::get_sk_start_block_idx()
|
||||
const noexcept
|
||||
{
|
||||
return sk_start_block_idx_;
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ set(REGRESSION_TESTS
|
||||
test_ck_tile_fmha_fwd_bf16
|
||||
test_ck_tile_fmha_fwd_fp16
|
||||
test_ck_tile_fmha_fwd_fp8
|
||||
test_ck_tile_streamk_reboot_extended
|
||||
test_ck_tile_streamk_extended
|
||||
)
|
||||
|
||||
function(add_test_executable TEST_NAME)
|
||||
|
||||
@@ -19,147 +19,29 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/bf8_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
)
|
||||
# TODO: enable extended tests after tolerances for atomic reductions are addressed.
|
||||
# add_gtest_executable(test_ck_tile_streamk_extended
|
||||
# # compv3 pipeline
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/f16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rrc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_rcc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_crc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccr_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv3/bf16_ccc_compv3_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
# # TODO: add compv4 pipeline
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/f16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rrc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_rcc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_crc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccr_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# # #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/compv4/bf16_ccc_compv4_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
|
||||
# # mem pipeline
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/f16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rrc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_rcc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_crc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk_smoke for current target")
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reboot_extended
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_reboot_bf16_nonpersistent.cpp
|
||||
test_gemm_streamk_reboot_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp8_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf8_nonpersistent.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_extended
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf16_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_fp8_nonpersistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/test_gemm_streamk_bf8_nonpersistent.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
target_compile_options(test_ck_tile_streamk_smoke PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_streamk_extended PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk unit tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRC_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRC_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRR_CompV3_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRR_CompV3_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRC_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRC_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRR_CompV4_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRR_CompV4_256x256x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CCR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_CRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RCR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS BF16_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CCR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_CRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RCR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRC_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -1,11 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
#define TEST_SUITE_PARAMS F16_RRR_Mem_128x128x32_2x2x1_32x32x16_NonPersistent
|
||||
#define TEST_SUITE_NAME MAKE_TEST_SUITE_NAME(TEST_SUITE_PARAMS)
|
||||
|
||||
DECLARE_STREAM_K_TEST(TEST_SUITE_NAME, TEST_SUITE_PARAMS);
|
||||
|
||||
#include "test_gemm_streamk_cases.inc"
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16NonPersistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf16Persistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf16Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf16Persistent, KernelTypesStreamKBf16Persistent);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8NonPersistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8NonPersistent, KernelTypesStreamKBf8NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKBf8Persistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKBf8Persistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKBf8Persistent, KernelTypesStreamKBf8Persistent);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16NonPersistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16NonPersistent
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
|
||||
|
||||
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKFp16Persistent.
// It derives from TestCkTileStreamK<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16Persistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKFp16Persistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKFp16Persistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKFp16Persistent, KernelTypesStreamKFp16Persistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKFp8NonPersistent.
// It derives from TestCkTileStreamK<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8NonPersistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKFp8NonPersistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKFp8NonPersistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKFp8NonPersistent, KernelTypesStreamKFp8NonPersistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKFp8Persistent.
// It derives from TestCkTileStreamK<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKFp8Persistent : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKFp8Persistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKFp8Persistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKFp8Persistent, KernelTypesStreamKFp8Persistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKRebootBf16NonPersistent.
// It derives from TestCkTileStreamKReboot<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootBf16NonPersistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16NonPersistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKBf16NonPersistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16NonPersistent, KernelTypesStreamKBf16NonPersistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKRebootBf16Persistent.
// It derives from TestCkTileStreamKReboot<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootBf16Persistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKRebootBf16Persistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKBf16Persistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKRebootBf16Persistent, KernelTypesStreamKBf16Persistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
@@ -1,19 +0,0 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_reboot_types.hpp"
|
||||
#include "test_gemm_streamk_reboot_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
// Typed-test fixture TestCkTileStreamKRebootFp16NonPersistent.
// It derives from TestCkTileStreamKReboot<Tuple> and adds no members of its own.
// NOTE(review): Tuple is presumably one entry of the kernel type list given to
// TYPED_TEST_SUITE -- confirm against the types header.
template <typename Tuple>
|
||||
class TestCkTileStreamKRebootFp16NonPersistent : public TestCkTileStreamKReboot<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// TEST_SUITE_NAME is defined here and undefined again after the shared cases
// file has been included.
#define TEST_SUITE_NAME TestCkTileStreamKRebootFp16NonPersistent
|
||||
|
||||
// Registers the fixture as a GoogleTest typed suite over the type list
// KernelTypesStreamKFp16NonPersistent (not declared in this file's visible lines).
TYPED_TEST_SUITE(TestCkTileStreamKRebootFp16NonPersistent, KernelTypesStreamKFp16NonPersistent);
|
||||
|
||||
// NOTE(review): presumably this shared .inc defines the test cases in terms of
// TEST_SUITE_NAME -- confirm against the .inc file.
#include "test_gemm_streamk_reboot_extended_cases.inc"
|
||||
|
||||
// Remove the macro once the shared cases have been included.
#undef TEST_SUITE_NAME
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user