mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Merge commit 'd7278cc664c20613e0b7c45f249f6e7613550ca2' into develop
This commit is contained in:
4
.github/workflows/therock-ci-linux.yml
vendored
4
.github/workflows/therock-ci-linux.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
permissions:
|
||||
id-token: write
|
||||
container:
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:044b113562629f4bd2ec5d2e64b32eee11562d48fb1a75d7493daec9dd8d8292
|
||||
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b
|
||||
options: -v /runner/config:/home/awsconfig/
|
||||
env:
|
||||
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||
with:
|
||||
repository: "ROCm/TheRock"
|
||||
ref: dc05d637054ad197c84b00e24b6262af0ec797c6 # 10-03-2025 commit
|
||||
ref: c2921b151b8285a1d29942aceb33cfe0fea77ac9 # 10-15-2025 commit
|
||||
path: "TheRock"
|
||||
|
||||
- name: Setup ccache
|
||||
|
||||
@@ -608,7 +608,7 @@ class KernelComponentFactory:
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and skip == "f":
|
||||
if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f":
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
|
||||
@@ -33,6 +33,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
|
||||
@@ -0,0 +1,327 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Stream-K tile partitioner base class.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
|
||||
struct StreamKTilePartitionerBase
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Calculates the total space needed for the partials buffer.
|
||||
*
|
||||
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
|
||||
* @return index_t The number of bytes needed for the partials buffer.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the total space needed for the flags buffer.
|
||||
*
|
||||
* @return index_t The number of bytes needed for the flags buffer.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
* @param iter_start Reference to an index_t; will be set to the starting iteration by the
|
||||
* function.
|
||||
* @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by
|
||||
* the function.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the 1D tile index in the C tensor for a workgroup.
|
||||
*
|
||||
* @param iter_start The starting iteration.
|
||||
* @return index_t The 1D tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the starting and ending tile boundaries for the given 1D tile index.
|
||||
*
|
||||
* @param tile_iter_start Reference to an index_t; will be set to the tile's start iteration by
|
||||
* the function.
|
||||
* @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end
|
||||
* iteration by the function.
|
||||
* @param tile_idx The 1D C tensor tile index for the workgroup.
|
||||
*/
|
||||
CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start,
|
||||
index_t& tile_iter_end,
|
||||
index_t tile_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's starting iteration that is local to a tile.
|
||||
*
|
||||
* @param iter_start The starting iteration.
|
||||
* @param tile_iter_start The starting iteration of the tile (i.e., the tile's starting
|
||||
* boundary).
|
||||
* @return index_t The local starting iteration. The value is in range [0, `iters_per_tile_`).
|
||||
* @note Assumes `iter_start` >= `tile_iter_start`.
|
||||
*/
|
||||
CK_TILE_DEVICE static index_t get_local_iter(index_t iter_start,
|
||||
index_t tile_iter_start) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's non-inclusive end iteration that is local to a tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param iter_end The non-inclusive end iteration.
|
||||
* @param tile_iter_end The non-inclusive end iteration of the tile.
|
||||
* @return index_t The local non-inclusive end iteration.
|
||||
* @note Assumes `iter_end` >= `tile_iter_start` and `tile_iter_end` >= `tile_iter_start`.
|
||||
*/
|
||||
CK_TILE_DEVICE static index_t
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
*/
|
||||
CK_TILE_DEVICE auto
|
||||
get_output_tile_index(index_t tile_idx) const noexcept -> tuple<index_t, index_t>;
|
||||
|
||||
/**
|
||||
* @brief Calculates the total space needed for the partials and flags buffers.
|
||||
*
|
||||
* @param acc_element_bytes The number of bytes for the accumulator data type used in the GEMM.
|
||||
* @return index_t The number of bytes needed for the partials and flags buffers.
|
||||
*/
|
||||
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of macro tiles in the C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the maximum number of active workgroups; this is assumed to be number of CUs *
|
||||
* occupancy.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
|
||||
* approach.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of tiles in the C tensor that will use the Stream-K approach.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of workgroups that will participate in Stream-K in the `sk_tiles_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of Stream-K iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of iterations per tile in the C tensor. In other words, this
|
||||
* is the total number of macro tiles along the K dimension of A and B.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of Stream-K iterations for each `sk_cta`. This is the lower
|
||||
* bound (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations).
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When
|
||||
* this is non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration
|
||||
* assigned to them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the n dimension for the GEMM problem.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t num_tiles_;
|
||||
index_t grid_;
|
||||
index_t dp_tiles_;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile
|
||||
* Stream-K.
|
||||
*/
|
||||
index_t full_tiles_ = 1;
|
||||
index_t sk_tiles_;
|
||||
index_t sk_ctas_;
|
||||
index_t total_sk_iters_;
|
||||
index_t iters_per_tile_;
|
||||
index_t iters_per_sk_cta_;
|
||||
index_t extra_iters_;
|
||||
index_t total_dp_iters_;
|
||||
index_t n_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Template for the Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm. This struct is derived from
|
||||
* StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>. Behavior of the
|
||||
* StreamKTilePartitioner based on persistency will be in the template specializations.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
* @tparam Persistent A bool that indicates whether to use a Persistent approach
|
||||
*/
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
/**
|
||||
* @brief Persistent Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm when using a Persistent approach where no extra workgroups
|
||||
* are allocated for data parallel.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Persistent
|
||||
* case, no extra workgroups are allocated for the data parallel section, making the grid
|
||||
* size num_cu * occupancy.
|
||||
*
|
||||
* @return dim_3 The launching grid size for the kernel.
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP tiles per workgroup.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
|
||||
* divisible by `grid_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t dp_tiles_per_cta_;
|
||||
index_t extra_dp_tiles_;
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Non-Persistent Stream-K tile partitioner derived struct.
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm when using a Non-Persistent approach where extra workgroups
|
||||
* are allocated for the data parallel section.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid);
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the launching grid size for the Stream-K kernel. In the Non-Persistent
|
||||
* case, extra workgroups are allocated for the data parallel section, making the grid
|
||||
* size the total number of Stream-K and data parallel workgroups.
|
||||
*
|
||||
* @return dim_3 The launching grid size for the kernel.
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns starting DP workgroup index. It is always zero.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief The index that starts the Stream-K workgroups. It is set to the number of `dp_tiles_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept;
|
||||
|
||||
protected:
|
||||
index_t dp_ctas_;
|
||||
index_t dp_start_block_idx_;
|
||||
index_t sk_start_block_idx_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "streamk_gemm_tile_partitioner_impl.hpp"
|
||||
@@ -0,0 +1,312 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
#pragma once
|
||||
#include "streamk_gemm_tile_partitioner.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
|
||||
index_t m, index_t n, index_t k, index_t grid)
|
||||
: grid_{grid}, n_{n}
|
||||
{
|
||||
iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
|
||||
num_tiles_ = integer_divide_ceil(m, MPerBlock) * integer_divide_ceil(n_, NPerBlock);
|
||||
|
||||
bool big_enough = num_tiles_ > grid_;
|
||||
index_t remainder_tiles = num_tiles_ % grid_;
|
||||
|
||||
if(remainder_tiles)
|
||||
{
|
||||
sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
|
||||
sk_tiles_ = min(num_tiles_, sk_tiles_);
|
||||
sk_ctas_ = grid_;
|
||||
total_sk_iters_ = sk_tiles_ * iters_per_tile_;
|
||||
|
||||
// If there still isn't enough work to saturate all CUs, then just revert to DP only.
|
||||
if(total_sk_iters_ < grid_)
|
||||
{
|
||||
sk_tiles_ = 0;
|
||||
sk_ctas_ = 0;
|
||||
total_sk_iters_ = 0;
|
||||
}
|
||||
}
|
||||
else // Full DP (i.e., no Stream-K)
|
||||
{
|
||||
sk_tiles_ = 0;
|
||||
sk_ctas_ = 0;
|
||||
total_sk_iters_ = 0;
|
||||
}
|
||||
|
||||
iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
|
||||
extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
|
||||
|
||||
dp_tiles_ = num_tiles_ - sk_tiles_;
|
||||
total_dp_iters_ = dp_tiles_ * iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
|
||||
const noexcept
|
||||
{
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_index(
|
||||
index_t iter) const noexcept
|
||||
{
|
||||
return iter / iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_boundaries(
|
||||
index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
|
||||
{
|
||||
tile_iter = tile_idx * iters_per_tile_;
|
||||
tile_iter_end = tile_iter + iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE /* static */ index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local_iter(
|
||||
index_t iter, index_t tile_iter) noexcept
|
||||
{
|
||||
return iter - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE /* static */ index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local_iter_end(
|
||||
index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
|
||||
{
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
index_t tile_idx) const noexcept -> tuple<index_t, index_t>
|
||||
{
|
||||
const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
|
||||
|
||||
const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
|
||||
const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles);
|
||||
return make_tuple(im, in);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
}
|
||||
else // ReductionStrategy is Atomics
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return num_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
|
||||
{
|
||||
return grid_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_dp_tiles() const noexcept
|
||||
{
|
||||
return dp_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_sk_tiles() const noexcept
|
||||
{
|
||||
return sk_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_sk_ctas() const noexcept
|
||||
{
|
||||
return sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_total_sk_iters()
|
||||
const noexcept
|
||||
{
|
||||
return total_sk_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iters_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
return iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iters_per_sk_cta()
|
||||
const noexcept
|
||||
{
|
||||
return iters_per_sk_cta_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_extra_iters()
|
||||
const noexcept
|
||||
{
|
||||
return extra_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_total_dp_iters()
|
||||
const noexcept
|
||||
{
|
||||
return total_dp_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() const noexcept
|
||||
{
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
// child class for Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
|
||||
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::grid_size()
|
||||
const noexcept -> dim3
|
||||
{
|
||||
if(extra_dp_tiles_ == 0)
|
||||
{
|
||||
return dim3(this->grid_, 1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(this->num_tiles_, 1, 1);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_dp_tiles_per_cta()
|
||||
const noexcept
|
||||
{
|
||||
return dp_tiles_per_cta_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_extra_dp_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return extra_dp_tiles_;
|
||||
}
|
||||
|
||||
// child class for Non-Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_ctas_ = this->dp_tiles_;
|
||||
dp_start_block_idx_ = 0;
|
||||
sk_start_block_idx_ = this->dp_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::grid_size()
|
||||
const noexcept -> dim3
|
||||
{
|
||||
return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::get_dp_ctas()
|
||||
const noexcept
|
||||
{
|
||||
return dp_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_dp_start_block_idx() const noexcept
|
||||
{
|
||||
return dp_start_block_idx_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_sk_start_block_idx() const noexcept
|
||||
{
|
||||
return sk_start_block_idx_;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -4,7 +4,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR})
|
||||
|
||||
#TODO: support all arches
|
||||
#TODO: current stream-k c-shuffle only supports C layout as R
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrr_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
#${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/f16_rrc_compv3_256x256x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
@@ -116,6 +116,7 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
# ${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccr_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# #${CMAKE_CURRENT_SOURCE_DIR}/extended_tests/mem/bf16_ccc_mem_128x128x32_2x2x1_32x32x16_NonPersistent.cpp
|
||||
# )
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
else()
|
||||
message(DEBUG "Skipping test_ck_tile_streamk tests for current target")
|
||||
endif()
|
||||
|
||||
498
test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp
Normal file
498
test/ck_tile/gemm_streamk/test_streamk_tile_partitioner.cpp
Normal file
@@ -0,0 +1,498 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_streamk_tile_partitioner_common.hpp"
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
2, 0, 3, 4, 1, 2, 1, 0, 2, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 6, 0, 0, 0, 2, 0, 12, 6, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
4, 3, 3, 8, 2, 2, 2, 6, 7, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerBaseExpected expected_values{
|
||||
0, 1, 0, 0, 0, 2, 0, 2, 1, Config::GRID, Config::N};
|
||||
validate_streamk_base_constructor<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, AtomicStrategy)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)), 0);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetWorkSpaceSize, ReductionStrategy)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Reduction>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
ck_tile::index_t expected_partials_size =
|
||||
sizeof(float) * Config::M_TILE * Config::N_TILE * Config::GRID;
|
||||
ck_tile::index_t expected_flags_size = sizeof(ck_tile::index_t) * Config::GRID;
|
||||
|
||||
EXPECT_EQ(tile_partitioner.get_workspace_size(sizeof(float)),
|
||||
expected_partials_size + expected_flags_size);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIter, GetLocalIter)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t iter_start = 3;
|
||||
ck_tile::index_t tile_iter_start = 2;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(iter_start,
|
||||
tile_iter_start,
|
||||
Config::UNUSED,
|
||||
local_iter_start_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate result
|
||||
ck_tile::index_t local_iter_start;
|
||||
local_iter_start_dev.FromDevice(&local_iter_start);
|
||||
EXPECT_EQ(local_iter_start, iter_start - tile_iter_start);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsTileIterEnd)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>;
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_iter_start = 6;
|
||||
ck_tile::index_t iter_end = 9;
|
||||
ck_tile::index_t tile_iter_end = 8;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
|
||||
iter_end,
|
||||
tile_iter_end,
|
||||
local_iter_end_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t local_iter_end;
|
||||
local_iter_end_dev.FromDevice(&local_iter_end);
|
||||
EXPECT_EQ(local_iter_end, tile_iter_end - tile_iter_start);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetLocalIterEnd, MinIsIterEnd)
|
||||
{
|
||||
// Types
|
||||
// Note: For this test, the Config is used for types only, the function get_local_iter_end is
|
||||
// static; thus, the test parameters are independent of the Config in this case.
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>;
|
||||
// Test parameters
|
||||
ck_tile::DeviceMem local_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_iter_start = 12;
|
||||
ck_tile::index_t iter_end = 13;
|
||||
ck_tile::index_t tile_iter_end = 14;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
|
||||
iter_end,
|
||||
tile_iter_end,
|
||||
local_iter_end_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
Config::UNUSED);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t local_iter_end;
|
||||
local_iter_end_dev.FromDevice(&local_iter_end);
|
||||
EXPECT_EQ(local_iter_end, iter_end - tile_iter_start);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileBoundaries, GetTileBoundaries)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem tile_iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t tile_idx = 1;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
tile_idx,
|
||||
tile_iter_start_dev.GetDeviceBuffer(),
|
||||
tile_iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_iter_start, tile_iter_end;
|
||||
tile_iter_start_dev.FromDevice(&tile_iter_start);
|
||||
tile_iter_end_dev.FromDevice(&tile_iter_end);
|
||||
// There are 2 iters per tile. Thus, for tile_idx 1, we expect 2 and 4 to be the start and end,
|
||||
// respectively.
|
||||
EXPECT_EQ(tile_iter_start, 2);
|
||||
EXPECT_EQ(tile_iter_end, 4);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileIndex, GetTileIndex)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel = KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_idx_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t iter_start = 8;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(iter_start,
|
||||
Config::UNUSED,
|
||||
Config::UNUSED,
|
||||
tile_idx_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_idx;
|
||||
tile_idx_dev.FromDevice(&tile_idx);
|
||||
// Since there are 2 iters per tile, iter 8 maps to tile_idx 4.
|
||||
EXPECT_EQ(tile_idx, 4);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, ZeroExtraItersBeforeMe)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 0;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_start_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter_start, iter_end;
|
||||
iter_start_dev.FromDevice(&iter_start);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter_start, 6);
|
||||
EXPECT_EQ(iter_end, 9);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, NonZeroExtraItersBeforeMe)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 1;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_start_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter_start, iter_end;
|
||||
iter_start_dev.FromDevice(&iter_start);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter_start, 9);
|
||||
EXPECT_EQ(iter_end, 12);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetIterBoundaries, MinIsExtraIters)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem iter_start_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem iter_end_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::index_t cta_idx = 2;
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(Config::PLACEHOLDER,
|
||||
Config::PLACEHOLDER,
|
||||
cta_idx,
|
||||
iter_start_dev.GetDeviceBuffer(),
|
||||
iter_end_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t iter_start, iter_end;
|
||||
iter_start_dev.FromDevice(&iter_start);
|
||||
iter_end_dev.FromDevice(&iter_end);
|
||||
EXPECT_EQ(iter_start, 12);
|
||||
EXPECT_EQ(iter_end, 14);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigLargerCTensor;
|
||||
ck_tile::index_t m_macro_tiles = Config::M / Config::M_TILE;
|
||||
ck_tile::index_t n_macro_tiles = Config::N / Config::N_TILE;
|
||||
ck_tile::index_t tile_idx = 0;
|
||||
|
||||
for(ck_tile::index_t row = 0; row < m_macro_tiles; ++row)
|
||||
{
|
||||
for(ck_tile::index_t col = 0; col < n_macro_tiles; ++col)
|
||||
{
|
||||
test_get_output_tile_index(tile_idx, ck_tile::make_tuple(row, col));
|
||||
++tile_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{2, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{1, 0, 3};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_PersistentConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2PersistentExpected expected_values{0, 1, 4};
|
||||
validate_streamk_v2_persistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, Config::GRID);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_Persistent, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
true>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 1);
|
||||
}
|
||||
|
||||
// Non-Persistent Tests
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, SKOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigSKOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{0, 0, 0, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DPOnly)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDPOnly;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{6, 0, 6, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{3, 0, 3, 3};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_NonPersistentConstructor, EdgeCase)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigEdgeCase;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
StreamKTilePartitionerV2NonPersistentExpected expected_values{1, 0, 1, 4};
|
||||
validate_streamk_v2_nonpersistent<Config::GemmShape>(expected_values, tile_partitioner);
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitioner_v2_GridSize_NonPersistent, DP2TileSK)
|
||||
{
|
||||
using Config = StreamKTilePartitionerBaseConfigDP2TileSK;
|
||||
|
||||
ck_tile::StreamKTilePartitioner_v2<typename Config::GemmShape,
|
||||
ck_tile::StreamKReductionStrategy::Atomic,
|
||||
false>
|
||||
tile_partitioner{Config::M, Config::N, Config::K, Config::GRID};
|
||||
|
||||
const auto g = tile_partitioner.grid_size();
|
||||
EXPECT_EQ(g.x, 6);
|
||||
}
|
||||
@@ -0,0 +1,339 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
enum StreamKTilePartitionerBaseMethodId
|
||||
{
|
||||
GET_LOCAL_ITER,
|
||||
GET_LOCAL_ITER_END,
|
||||
GET_TILE_BOUNDARIES,
|
||||
GET_TILE_INDEX,
|
||||
GET_ITER_BOUNDARIES,
|
||||
GET_OUTPUT_TILE_INDEX
|
||||
};
|
||||
|
||||
// Base kernel wrapper class to facilitate testing class device functions.
|
||||
template <typename T = ck_tile::index_t>
|
||||
struct KernelWrapper
|
||||
{
|
||||
static constexpr ck_tile::index_t kBlockSize = 1;
|
||||
|
||||
struct KernelArgs
|
||||
{
|
||||
ck_tile::index_t arg1;
|
||||
ck_tile::index_t arg2;
|
||||
ck_tile::index_t arg3;
|
||||
void* result1;
|
||||
void* result2;
|
||||
T tile_partitioner;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static KernelArgs MakeKernelArgs(ck_tile::index_t arg1,
|
||||
ck_tile::index_t arg2,
|
||||
ck_tile::index_t arg3,
|
||||
void* result1,
|
||||
void* result2,
|
||||
T tile_partitioner)
|
||||
{
|
||||
return KernelArgs{arg1, arg2, arg3, result1, result2, tile_partitioner};
|
||||
}
|
||||
};
|
||||
|
||||
// Specialized derived class to support unique operator() functions. There is one template
|
||||
// specialization per member in the StreamKTilePartitionerBaseMethodId enum.
|
||||
template <typename TilePartitioner, StreamKTilePartitionerBaseMethodId Id>
|
||||
struct KernelWrapperSpecialized;
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner, StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER>
|
||||
: public KernelWrapper<>
|
||||
{
|
||||
using Base = KernelWrapper<>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
TilePartitioner::get_local_iter(kargs.arg1, kargs.arg2);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_BOUNDARIES>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
kargs.tile_partitioner.get_tile_boundaries(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = kargs.arg1;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = kargs.arg2;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_ITER_BOUNDARIES>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
kargs.tile_partitioner.get_iter_boundaries(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = kargs.arg1;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = kargs.arg2;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_LOCAL_ITER_END>
|
||||
: public KernelWrapper<>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<>;
|
||||
CK_TILE_DEVICE void operator()(Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
TilePartitioner::get_local_iter_end(kargs.arg1, kargs.arg2, kargs.arg3);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner, StreamKTilePartitionerBaseMethodId::GET_TILE_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) =
|
||||
kargs.tile_partitioner.get_tile_index(kargs.arg1);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_OUTPUT_TILE_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
auto [im, in] = kargs.tile_partitioner.get_output_tile_index(kargs.arg1);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = im;
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result2)) = in;
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseExpected
|
||||
{
|
||||
ck_tile::index_t sk_tiles_;
|
||||
ck_tile::index_t dp_tiles_;
|
||||
ck_tile::index_t sk_ctas_;
|
||||
ck_tile::index_t total_sk_iters_;
|
||||
ck_tile::index_t iters_per_sk_cta_;
|
||||
ck_tile::index_t iters_per_tile_;
|
||||
ck_tile::index_t extra_iters_;
|
||||
ck_tile::index_t total_dp_iters_;
|
||||
ck_tile::index_t num_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
ck_tile::index_t n_;
|
||||
};
|
||||
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_base_constructor(
|
||||
StreamKTilePartitionerBaseExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitionerBase<GemmShape>& tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_sk_tiles(), expected_values.sk_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_dp_tiles(), expected_values.dp_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_sk_ctas(), expected_values.sk_ctas_);
|
||||
EXPECT_EQ(tile_partitioner.get_total_sk_iters(), expected_values.total_sk_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_iters_per_sk_cta(), expected_values.iters_per_sk_cta_);
|
||||
EXPECT_EQ(tile_partitioner.get_extra_iters(), expected_values.extra_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_iters_per_tile(), expected_values.iters_per_tile_);
|
||||
EXPECT_EQ(tile_partitioner.get_total_dp_iters(), expected_values.total_dp_iters_);
|
||||
EXPECT_EQ(tile_partitioner.get_num_tiles(), expected_values.num_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
EXPECT_EQ(tile_partitioner.get_n(), expected_values.n_);
|
||||
}
|
||||
|
||||
struct StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t PLACEHOLDER = -1;
|
||||
static constexpr ck_tile::index_t UNUSED = -1;
|
||||
};
|
||||
|
||||
// Note: for the configs below, we only use BlockTiles in the TileGemmShape. We do not use
|
||||
// BlockWarps or WarpTile.
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDP2TileSK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 28;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigDPOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 3;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
static constexpr ck_tile::index_t M = 4;
|
||||
static constexpr ck_tile::index_t N = 4;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigLargerCTensor : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
// This config has 3 macro tiles in the M dimension and 4 macro tiles in the N dimension.
|
||||
// This facilitates testing the get_output_tile_index method.
|
||||
|
||||
static constexpr ck_tile::index_t M = 12;
|
||||
static constexpr ck_tile::index_t N = 16;
|
||||
static constexpr ck_tile::index_t K = 16;
|
||||
static constexpr ck_tile::index_t GRID = 4;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 4;
|
||||
static constexpr ck_tile::index_t K_TILE = 8;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
ck_tile::tuple<ck_tile::index_t, ck_tile::index_t> expected_2d_idx)
|
||||
{
|
||||
// Types
|
||||
using Config = StreamKTilePartitionerBaseConfigLargerCTensor;
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_OUTPUT_TILE_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem im_dev(sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem in_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_idx,
|
||||
Config::UNUSED,
|
||||
Config::UNUSED,
|
||||
im_dev.GetDeviceBuffer(),
|
||||
in_dev.GetDeviceBuffer(),
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
const auto [im_expected, in_expected] = expected_2d_idx;
|
||||
ck_tile::index_t im, in;
|
||||
im_dev.FromDevice(&im);
|
||||
in_dev.FromDevice(&in);
|
||||
EXPECT_EQ(im, im_expected);
|
||||
EXPECT_EQ(in, in_expected);
|
||||
};
|
||||
|
||||
// Configs for TilePartitioner Child structs
|
||||
struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
ck_tile::index_t dp_tiles_per_cta_;
|
||||
ck_tile::index_t extra_dp_tiles_;
|
||||
ck_tile::index_t grid_;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerV2NonPersistentExpected
|
||||
{
|
||||
ck_tile::index_t dp_ctas_;
|
||||
ck_tile::index_t dp_start_block_idx_;
|
||||
ck_tile::index_t sk_start_block_idx_;
|
||||
ck_tile::index_t grid_;
|
||||
};
|
||||
|
||||
// Persistent
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_v2_persistent(
|
||||
StreamKTilePartitionerV2PersistentExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, true>&
|
||||
tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_dp_tiles_per_cta(), expected_values.dp_tiles_per_cta_);
|
||||
EXPECT_EQ(tile_partitioner.get_extra_dp_tiles(), expected_values.extra_dp_tiles_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
}
|
||||
|
||||
// Non-Persistent
|
||||
template <typename GemmShape>
|
||||
void validate_streamk_v2_nonpersistent(
|
||||
StreamKTilePartitionerV2NonPersistentExpected& expected_values,
|
||||
ck_tile::StreamKTilePartitioner_v2<GemmShape, ck_tile::StreamKReductionStrategy::Atomic, false>&
|
||||
tile_partitioner)
|
||||
{
|
||||
EXPECT_EQ(tile_partitioner.get_dp_ctas(), expected_values.dp_ctas_);
|
||||
EXPECT_EQ(tile_partitioner.get_dp_start_block_idx(), expected_values.dp_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_sk_start_block_idx(), expected_values.sk_start_block_idx_);
|
||||
EXPECT_EQ(tile_partitioner.get_grid(), expected_values.grid_);
|
||||
}
|
||||
Reference in New Issue
Block a user