mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)
* CK Tile Stream-K Tree Reduction
This change adds the first implementation of the Stream-K tree reduction
strategy into CK Tile. The tree reduction reduces the the number of
steps for accumulating results for a tile from O(N) to O(logN) where N
is the number of workgroups contributing to a C tile.
Additionally, in the original non-atomic reduction strategy, atomics
were used to set the flags buffer and to read from the flags buffer.
Howeover, through investigation with the tree reduciton, atomics with
default (relaxed) semantics were not enough to guarantee workgroups
would not read stale data, leading to incorrect results. Stronger
acquire/release memory orderings are too expensive. So, this change
also eliminates the use of atomics for setting the flags. Instead, we
leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby
avoiding the use of atomics.
Prelimiary tests were also added for the normal reduction and tree
reduction. More will be added in a future PR via tile engine.
* Move Stream-K kernel files to a subdirectory
* Cleanup Code Style & Handle Unsupported Reductions
This change makes the following small changes:
- Add an explicit else block for unimplemented reduction strategies
- Clarify type of sk_flags_ptr via auto*
- Add description for extra_iters_before_me variable
* Run new copyright script on new files
[ROCm/composable_kernel commit: 22b945e06e]
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
namespace ck_tile {
|
||||
enum StreamKReductionStrategy : uint32_t
|
||||
{
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
Atomic = 0u,
|
||||
Reduction = 1u,
|
||||
TreeReduction = 2u
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CompilerTarget, typename Enabler = void>
|
||||
struct StreamKCoherency
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::coherence_default;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX942,
|
||||
core::arch::amdgcn_target_id::GFX950>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::SYSTEM_NT0;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX908,
|
||||
core::arch::amdgcn_target_id::GFX90A>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::glc_slc;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "streamk_gemm_coherency.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -318,37 +319,58 @@ struct StreamKKernel
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
* @note This function utilizes a scalar store to write to the flags buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the approproriate
|
||||
// cache level(s) to ensure the write is visible to other workgroups. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_store_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
||||
:
|
||||
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
* @note This function utilizes a scalar load to read from the flags
|
||||
* buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t result;
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
do
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the
|
||||
// approproriate cache level(s) to avoid reading stale flags. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_load_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
||||
: "=s"(result)
|
||||
: "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
} while(result != 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates
|
||||
* the output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
@@ -370,8 +392,8 @@ struct StreamKKernel
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for
|
||||
* loading the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
@@ -405,8 +427,8 @@ struct StreamKKernel
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
@@ -420,7 +442,10 @@ struct StreamKKernel
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<
|
||||
address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
@@ -431,8 +456,11 @@ struct StreamKKernel
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -483,7 +511,8 @@ struct StreamKKernel
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
@@ -528,46 +557,107 @@ struct StreamKKernel
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile =
|
||||
kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta =
|
||||
kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
}
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
// It's this workgroup's turn to read from partials.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
// If this workgroup's partner is in the tile then it can read from
|
||||
// partials and accumulate results.
|
||||
if(partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs,
|
||||
partner_cta_idx,
|
||||
c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// Otherwise, it's this workgroup's turn to write to partials. All
|
||||
// workgroups, except the workgroup who starts the tile, will write to
|
||||
// partials.
|
||||
else
|
||||
{
|
||||
StorePartial(kargs, cta_idx, accum_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
// Once the workgroup writes to partials, it has no more work to do for
|
||||
// this tile.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
"An implementation does not exist for the chosen reduction strategy.");
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
@@ -640,10 +730,10 @@ struct StreamKKernel
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
|
||||
* is the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the
|
||||
* layouts of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
@@ -688,7 +778,8 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
|
||||
* kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start iteration for the given the cta_idx.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @return index_t The start iteration.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
* @brief Calculates the workgroup's local CTA idx within the given tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param cta_idx The Stream-K workgroup index.
|
||||
* @return index_t The tile local workgroup index in the tile.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
|
||||
index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
|
||||
index_t cta_idx) const noexcept
|
||||
{
|
||||
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
|
||||
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
|
||||
// it is extra_iters.
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter = get_start_iter(cta_idx);
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
|
||||
index_t tile_iter_start, index_t cta_idx) const noexcept
|
||||
{
|
||||
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
|
||||
|
||||
// Compute how many WGs fit before this tile starts assuming each WG does an
|
||||
// extra_iter
|
||||
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
|
||||
// Compute how many WGs fit before this tile starts excluding extra iters
|
||||
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
|
||||
// Compute the CTA idx for the CTA that starts this tile
|
||||
const index_t coop_group_start =
|
||||
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
|
||||
return cta_idx - coop_group_start;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reduction
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
|
||||
|
||||
#include "test_gemm_streamk_reduction_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16Reduction = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
|
||||
|
||||
using KernelTypesStreamKBf16Persistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
|
||||
@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
ck_tile::index_t num_accumulations_per_tile;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
|
||||
@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigSKOnlyLargeK has the following form:
|
||||
- tiles in the C tensor: 2
|
||||
- iters_per_tile: 5
|
||||
- grid: 5
|
||||
- dp_tiles: 0
|
||||
- sk_tiles: 2
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 0
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: __________0_________|_________1_________|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|
||||
| | | | | | | | | | |
|
||||
<---------------SK Tiles--------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 0 with local CTA index of 0 in tile 0
|
||||
- SK CTA 1: tile_iter_start is 0 with local CTA index of 1 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 0 with local CTA index of 2 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 5 with local CTA index of 0 in tile 1
|
||||
- SK CTA 3: tile_iter_start is 5 with local CTA index of 1 in tile 1
|
||||
- SK CTA 4: tile_iter_start is 5 with local CTA index of 2 in tile 1
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigDP2TileSK has the following form:
|
||||
- tiles in the C tensor: 7
|
||||
- iters_per_tile: 3
|
||||
- grid: 3
|
||||
- dp_tiles: 3
|
||||
- sk_tiles: 4
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 2
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: ____0___|___1___|___2___|___3___|___4___|____5____|____6____|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|
||||
| | | | | | | | | | | | | | |
|
||||
|<-------DP Tiles------>|<------------SK Tiles------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 6 with local CTA index of 0 in tile 3
|
||||
- SK CTA 0: tile_iter_start is 8 with local CTA index of 0 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 8 with local CTA index of 1 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 10 with local CTA index of 0 in tile 5
|
||||
- SK CTA 2: tile_iter_start is 12 with local CTA index of 0 in tile 6
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent
|
||||
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include <array>
|
||||
|
||||
enum StreamKTilePartitionerBaseMethodId
|
||||
{
|
||||
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
|
||||
GET_TILE_BOUNDARIES,
|
||||
GET_TILE_INDEX,
|
||||
GET_ITER_BOUNDARIES,
|
||||
GET_OUTPUT_TILE_INDEX
|
||||
GET_OUTPUT_TILE_INDEX,
|
||||
GET_TILE_LOCAL_CTA_INDEX
|
||||
};
|
||||
|
||||
// Base kernel wrapper class to facilitate testing class device functions.
|
||||
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
ck_tile::index_t tile_local_cta_index =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = tile_local_cta_index;
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseExpected
|
||||
{
|
||||
ck_tile::index_t sk_tiles_;
|
||||
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 8;
|
||||
static constexpr ck_tile::index_t N = 2;
|
||||
static constexpr ck_tile::index_t K = 10;
|
||||
static constexpr ck_tile::index_t GRID = 5;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 2;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
EXPECT_EQ(in, in_expected);
|
||||
};
|
||||
|
||||
template <typename Config>
|
||||
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
|
||||
ck_tile::index_t cta_idx,
|
||||
ck_tile::index_t expected_tile_local_cta_idx)
|
||||
{
|
||||
// Types
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
|
||||
cta_idx,
|
||||
Config::UNUSED,
|
||||
tile_local_cta_idx_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_local_cta_idx;
|
||||
tile_local_cta_idx_dev.FromDevice(&tile_local_cta_idx);
|
||||
EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
|
||||
}
|
||||
|
||||
// Configs for TilePartitioner Child structs
|
||||
struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user