mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Style updates and cleanup
The following changes were made: - Renamed iter to iter_start - Renamed tile_iter to tile_iter_start - Moved documentation from member variables to getters - Removed double underscore from extra_iters__before_me variable (now extra_iters_before_me) - Included parent header in impl file - Removed unused includes
This commit is contained in:
committed by
Emily Martins
parent
8f75d7cea6
commit
cb83d52301
@@ -10,8 +10,6 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include <format>
|
||||
#include <iostream>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -812,5 +810,4 @@ struct StreamKTilePartitioner
|
||||
uint32_t M_, N_, K_;
|
||||
uint32_t num_tile_m_, num_tile_n_, num_tile_k_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -14,20 +14,20 @@ namespace ck_tile {
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategy = StreamKReductionStrategy::Atomic>
|
||||
StreamKReductionStrategy ReductionStrategyType = StreamKReductionStrategy::Atomic>
|
||||
struct StreamKTilePartitionerBase
|
||||
{
|
||||
using BlockGemmShape = BlockGemmShapeType;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr StreamKReductionStrategy StreamKReductionStrategy = ReductionStrategy;
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType;
|
||||
|
||||
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid);
|
||||
|
||||
@@ -51,59 +51,62 @@ struct StreamKTilePartitionerBase
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
* @param iter Reference to an index_t; will be set to the starting iteration by the
|
||||
* @param iter_start Reference to an index_t; will be set to the starting iteration by the
|
||||
* function.
|
||||
* @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by
|
||||
* @param iter_end Reference to an index_t; will be set to the non-inclusive end iteration by
|
||||
* the function.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE void
|
||||
get_iter_boundaries(index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept;
|
||||
get_iter_boundaries(index_t& iter_start, index_t& iter_end, index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the 1D tile index in the C tensor for a workgroup.
|
||||
*
|
||||
* @param iter The starting iteration.
|
||||
* @return index_t The 1D tile index.
|
||||
* @param iter_start The starting iteration.
|
||||
* @return index_t The 1D tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_index(index_t iter) const noexcept;
|
||||
CK_TILE_DEVICE index_t get_tile_index(index_t iter_start) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the starting and ending tile boundaries for the given 1D tile index.
|
||||
*
|
||||
* @param tile_iter Reference to an index_t; will be set to the tile's start iteration by
|
||||
* @param tile_iter_start Reference to an index_t; will be set to the tile's start iteration by
|
||||
* the function.
|
||||
* @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end
|
||||
* @param tile_iter_end Reference to an index_t; will be set to the non-inclusive tile's end
|
||||
* iteration by the function.
|
||||
* @param tile_idx The 1D C tensor tile index for the workgroup.
|
||||
*/
|
||||
CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter,
|
||||
CK_TILE_DEVICE void get_tile_boundaries(index_t& tile_iter_start,
|
||||
index_t& tile_iter_end,
|
||||
index_t tile_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's starting iteration that is local to a tile.
|
||||
*
|
||||
* @param iter The starting iteration.
|
||||
* @param iter_start The starting iteration.
|
||||
* @param tile_iter_start The starting iteration of the tile (i.e., the tile's starting
|
||||
* boundary).
|
||||
* @return index_t The local starting iteration. The value is in range [0, `iters_per_tile_`).
|
||||
* @note Assumes `iter` >= `tile_iter`.
|
||||
* @note Assumes `iter_start` >= `tile_iter_start`.
|
||||
*/
|
||||
CK_TILE_DEVICE static index_t get_local_iter(index_t iter, index_t tile_iter) noexcept;
|
||||
CK_TILE_DEVICE static index_t get_local_iter(index_t iter_start,
|
||||
index_t tile_iter_start) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's non-inclusive end iteration that is local to a tile.
|
||||
*
|
||||
* @param tile_iter The starting tile iteration.
|
||||
* @param iter_end The non-inclusive end iteration.
|
||||
* @param tile_iter_end The non-inclusive end iteration of the tile.
|
||||
* @return index_t The local non-inclusive end iteration.
|
||||
* @note Assumes `iter_end` >= `tile_iter` and `tile_iter_end` >= `tile_iter`.
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param iter_end The non-inclusive end iteration.
|
||||
* @param tile_iter_end The non-inclusive end iteration of the tile.
|
||||
* @return index_t The local non-inclusive end iteration.
|
||||
* @note Assumes `iter_end` >= `tile_iter_start` and `tile_iter_end` >= `tile_iter_start`.
|
||||
*/
|
||||
CK_TILE_DEVICE static index_t
|
||||
get_local_iter_end(index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
@@ -122,83 +125,85 @@ struct StreamKTilePartitionerBase
|
||||
*/
|
||||
CK_TILE_HOST index_t get_workspace_size(index_t acc_element_bytes) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of macro tiles in the C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_num_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the maximum number of active workgroups; this is assumed to be the number of
|
||||
* CUs multiplied by the occupancy.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_grid() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of tiles in the C tensor that will use the data-parallel (DP)
|
||||
* approach.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of tiles in the C tensor that will use the Stream-K approach.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_tiles() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the number of workgroups that will participate in Stream-K in the `sk_tiles_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_ctas() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of Stream-K iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_total_sk_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of iterations per tile in the C tensor. In other words, this
|
||||
* is the total number of macro tiles along the K dimension of A and B.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_iters_per_tile() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of Stream-K iterations for each `sk_cta`. This is the lower
|
||||
* bound (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations).
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_iters_per_sk_cta() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When
|
||||
* this is non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration
|
||||
* assigned to them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP iterations.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_total_dp_iters() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the n dimension for the GEMM problem.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_n() const noexcept;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The number of macro tiles in the C tensor.
|
||||
*/
|
||||
index_t num_tiles_;
|
||||
/**
|
||||
* @brief The maximum number of active workgroups; this is assumed to be number of CUs *
|
||||
* occupancy.
|
||||
*/
|
||||
index_t grid_;
|
||||
/**
|
||||
* @brief The number of tiles in the C tensor that will use the data-parallel (DP) approach.
|
||||
*/
|
||||
index_t dp_tiles_;
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile SK.
|
||||
* @brief The number of full tiles assigned to each `sk_cta` when performing DP + 2 Tile
|
||||
* Stream-K.
|
||||
*/
|
||||
index_t full_tiles_ = 1;
|
||||
/**
|
||||
* @brief The number of tiles in the C tensor that will use the Stream-K approach.
|
||||
*/
|
||||
index_t sk_tiles_;
|
||||
/**
|
||||
* @brief The number of workgroups that will participate in Stream-K in the `sk_tiles_`.
|
||||
*/
|
||||
index_t sk_ctas_;
|
||||
/**
|
||||
* @brief The total number of Stream-K iterations.
|
||||
*/
|
||||
index_t total_sk_iters_;
|
||||
/**
|
||||
* @brief The total number of iterations per tile in the C tensor. In other words, this is the
|
||||
* total number of macro tiles along the K dimension of A and B.
|
||||
*/
|
||||
index_t iters_per_tile_;
|
||||
/**
|
||||
* @brief The total number of Stream-K iterations for each `sk_cta`. This is the lower bound
|
||||
* (i.e., all `sk_ctas_` are guaranteed to perform at least this many iterations).
|
||||
*/
|
||||
index_t iters_per_sk_cta_;
|
||||
/**
|
||||
* @brief The remainder resulting from `total_sk_iters_` divided by `sk_ctas_`. When this is
|
||||
* non-zero, the first `extra_iters_` `sk_ctas_` will get one additional iteration assigned to
|
||||
* them; such work groups will perform (`iters_per_sk_cta_` + 1) iterations.
|
||||
*/
|
||||
index_t extra_iters_;
|
||||
/**
|
||||
* @brief The total number of DP iterations.
|
||||
*/
|
||||
index_t total_dp_iters_;
|
||||
/**
|
||||
* @brief The n dimension for the GEMM problem.
|
||||
*/
|
||||
index_t n_;
|
||||
};
|
||||
|
||||
@@ -207,15 +212,17 @@ struct StreamKTilePartitionerBase
|
||||
*
|
||||
* This partitioner is responsible for mapping workgroups to tiles in the C tensor
|
||||
* for the Stream-K algorithm. This struct is derived from
|
||||
* StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>. Behavior of the
|
||||
* StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>. Behavior of the
|
||||
* StreamKTilePartitioner based on persistency will be in the template specializations.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
* @tparam Persistent A bool that indicates whether to use a Persistent approach
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy, bool Persistent>
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
/**
|
||||
@@ -225,13 +232,13 @@ struct StreamKTilePartitioner_v2;
|
||||
* for the Stream-K algorithm when using a Persistent approach where no extra workgroups
|
||||
* are allocated for data parallel.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
@@ -248,19 +255,20 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP tiles per workgroup.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_tiles_per_cta() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP tiles left over when `dp_tiles_` is not evenly
|
||||
* divisible by `grid_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_extra_dp_tiles() const noexcept;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The total number of DP tiles per workgroup.
|
||||
*/
|
||||
int dp_tiles_per_cta_;
|
||||
|
||||
/**
|
||||
* @brief The total number of DP tiles left over when dp_tiles is not evenly divisible by grid.
|
||||
*/
|
||||
int extra_dp_tiles_;
|
||||
index_t dp_tiles_per_cta_;
|
||||
index_t extra_dp_tiles_;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -271,12 +279,12 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>
|
||||
* are allocated for the data parallel section.
|
||||
*
|
||||
* @tparam BlockGemmShapeType A class providing basic GEMM parameters.
|
||||
* @tparam ReductionStrategy An enum that defines the reduction strategy for the results in the C
|
||||
* Tensor.
|
||||
* @tparam ReductionStrategyType An enum that defines the reduction strategy for the results in
|
||||
* the C Tensor.
|
||||
*/
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>
|
||||
{
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
@@ -292,25 +300,26 @@ struct StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>
|
||||
* @return dim3 The launching grid size for the kernel.
|
||||
*/
|
||||
CK_TILE_HOST auto grid_size() const noexcept -> dim3;
|
||||
|
||||
/**
|
||||
* @brief Returns the total number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_ctas() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the starting DP workgroup index. It is always zero.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_dp_start_block_idx() const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Returns the index that starts the Stream-K workgroups. It is set to the number of `dp_tiles_`.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE index_t get_sk_start_block_idx() const noexcept;
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief The total number of DP workgroups.
|
||||
*/
|
||||
int dp_ctas_;
|
||||
|
||||
/**
|
||||
* @brief The index that starts the DP workgroups, always 0 in our implementation.
|
||||
*/
|
||||
int dp_start_block_idx_;
|
||||
|
||||
/**
|
||||
* @brief The index that starts the Stream-K workgroups, set to the number of dp_tiles.
|
||||
*/
|
||||
int sk_start_block_idx_;
|
||||
index_t dp_ctas_;
|
||||
index_t dp_start_block_idx_;
|
||||
index_t sk_start_block_idx_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "streamk_gemm_tile_partitioner.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::StreamKTilePartitionerBase(
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::StreamKTilePartitionerBase(
|
||||
index_t m, index_t n, index_t k, index_t grid)
|
||||
: grid_{grid}, n_{n}
|
||||
{
|
||||
@@ -43,68 +44,68 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::StreamKTilePa
|
||||
total_dp_iters_ = dp_tiles_ * iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_partials_buffer_size(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_partials_buffer_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_flags_buffer_size()
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags_buffer_size()
|
||||
const noexcept
|
||||
{
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iter_boundaries(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters__before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters__before_me;
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_tile_index(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_index(
|
||||
index_t iter) const noexcept
|
||||
{
|
||||
return iter / iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_tile_boundaries(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_boundaries(
|
||||
index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
|
||||
{
|
||||
tile_iter = tile_idx * iters_per_tile_;
|
||||
tile_iter_end = tile_iter + iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE /* static */ index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_local_iter(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local_iter(
|
||||
index_t iter, index_t tile_iter) noexcept
|
||||
{
|
||||
return iter - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE /* static */ index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_local_iter_end(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local_iter_end(
|
||||
index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
|
||||
{
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_output_tile_index(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
index_t tile_idx) const noexcept -> tuple<index_t, index_t>
|
||||
{
|
||||
const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
|
||||
@@ -114,12 +115,12 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_output_ti
|
||||
return make_tuple(im, in);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_workspace_size(
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(StreamKReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
@@ -130,104 +131,111 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_workspace
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_num_tiles() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_num_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return num_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_grid() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_grid() const noexcept
|
||||
{
|
||||
return grid_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_dp_tiles() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_dp_tiles() const noexcept
|
||||
{
|
||||
return dp_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_sk_tiles() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_sk_tiles() const noexcept
|
||||
{
|
||||
return sk_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_sk_ctas() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_sk_ctas() const noexcept
|
||||
{
|
||||
return sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_total_sk_iters()
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_total_sk_iters()
|
||||
const noexcept
|
||||
{
|
||||
return total_sk_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iters_per_tile()
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iters_per_tile()
|
||||
const noexcept
|
||||
{
|
||||
return iters_per_tile_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_iters_per_sk_cta()
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iters_per_sk_cta()
|
||||
const noexcept
|
||||
{
|
||||
return iters_per_sk_cta_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_extra_iters() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_extra_iters()
|
||||
const noexcept
|
||||
{
|
||||
return extra_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_total_dp_iters()
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_total_dp_iters()
|
||||
const noexcept
|
||||
{
|
||||
return total_dp_iters_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>::get_n() const noexcept
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_n() const noexcept
|
||||
{
|
||||
return n_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy, bool Persistent>
|
||||
template <typename BlockGemmShapeType,
|
||||
StreamKReductionStrategy ReductionStrategyType,
|
||||
bool Persistent>
|
||||
struct StreamKTilePartitioner_v2;
|
||||
|
||||
// child class for Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::StreamKTilePartitioner_v2(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>(m, n, k, grid)
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
|
||||
extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::grid_size() const noexcept
|
||||
-> dim3
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::grid_size()
|
||||
const noexcept -> dim3
|
||||
{
|
||||
if(extra_dp_tiles_ == 0)
|
||||
{
|
||||
@@ -239,61 +247,64 @@ StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::grid_siz
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::get_dp_tiles_per_cta()
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_dp_tiles_per_cta()
|
||||
const noexcept
|
||||
{
|
||||
return dp_tiles_per_cta_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, true>::get_extra_dp_tiles()
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, true>::get_extra_dp_tiles()
|
||||
const noexcept
|
||||
{
|
||||
return extra_dp_tiles_;
|
||||
}
|
||||
|
||||
// child class for Non-Persistent Tile Partitioner
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::StreamKTilePartitioner_v2(
|
||||
ck_tile::index_t m, ck_tile::index_t n, ck_tile::index_t k, ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategy>(m, n, k, grid)
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
StreamKTilePartitioner_v2(ck_tile::index_t m,
|
||||
ck_tile::index_t n,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t grid)
|
||||
: StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
|
||||
{ // inherit from base constructor
|
||||
dp_ctas_ = this->dp_tiles_;
|
||||
dp_start_block_idx_ = 0;
|
||||
sk_start_block_idx_ = this->dp_tiles_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST auto
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::grid_size() const noexcept
|
||||
-> dim3
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::grid_size()
|
||||
const noexcept -> dim3
|
||||
{
|
||||
return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_dp_ctas()
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::get_dp_ctas()
|
||||
const noexcept
|
||||
{
|
||||
return dp_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_dp_start_block_idx()
|
||||
const noexcept
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_dp_start_block_idx() const noexcept
|
||||
{
|
||||
return dp_start_block_idx_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategy>
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategy, false>::get_sk_start_block_idx()
|
||||
const noexcept
|
||||
StreamKTilePartitioner_v2<BlockGemmShapeType, ReductionStrategyType, false>::
|
||||
get_sk_start_block_idx() const noexcept
|
||||
{
|
||||
return sk_start_block_idx_;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user