mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 02:27:57 +00:00
Merge commit '22b945e06ea4b4de188d7ff4ec7ae4bf127be9f9' into develop
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
namespace ck_tile {
|
||||
enum StreamKReductionStrategy : uint32_t
|
||||
{
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
Atomic = 0u,
|
||||
Reduction = 1u,
|
||||
TreeReduction = 2u
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CompilerTarget, typename Enabler = void>
|
||||
struct StreamKCoherency
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::coherence_default;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX942,
|
||||
core::arch::amdgcn_target_id::GFX950>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::SYSTEM_NT0;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX908,
|
||||
core::arch::amdgcn_target_id::GFX90A>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::glc_slc;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "streamk_gemm_coherency.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -318,37 +319,58 @@ struct StreamKKernel
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
* @note This function utilizes a scalar store to write to the flags buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the approproriate
|
||||
// cache level(s) to ensure the write is visible to other workgroups. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_store_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
||||
:
|
||||
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
* @note This function utilizes a scalar load to read from the flags
|
||||
* buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t result;
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
do
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the
|
||||
// approproriate cache level(s) to avoid reading stale flags. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_load_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
||||
: "=s"(result)
|
||||
: "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
} while(result != 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates
|
||||
* the output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
@@ -370,8 +392,8 @@ struct StreamKKernel
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for
|
||||
* loading the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
@@ -405,8 +427,8 @@ struct StreamKKernel
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
@@ -420,7 +442,10 @@ struct StreamKKernel
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<
|
||||
address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
@@ -431,8 +456,11 @@ struct StreamKKernel
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -483,7 +511,8 @@ struct StreamKKernel
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
@@ -528,46 +557,107 @@ struct StreamKKernel
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile =
|
||||
kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta =
|
||||
kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
}
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
// It's this workgroup's turn to read from partials.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
// If this workgroup's partner is in the tile then it can read from
|
||||
// partials and accumulate results.
|
||||
if(partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs,
|
||||
partner_cta_idx,
|
||||
c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// Otherwise, it's this workgroup's turn to write to partials. All
|
||||
// workgroups, except the workgroup who starts the tile, will write to
|
||||
// partials.
|
||||
else
|
||||
{
|
||||
StorePartial(kargs, cta_idx, accum_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
// Once the workgroup writes to partials, it has no more work to do for
|
||||
// this tile.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
"An implementation does not exist for the chosen reduction strategy.");
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
@@ -640,10 +730,10 @@ struct StreamKKernel
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
|
||||
* is the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the
|
||||
* layouts of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
@@ -688,7 +778,8 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
|
||||
* kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start iteration for the given the cta_idx.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @return index_t The start iteration.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
* @brief Calculates the workgroup's local CTA idx within the given tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param cta_idx The Stream-K workgroup index.
|
||||
* @return index_t The tile local workgroup index in the tile.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
|
||||
index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
|
||||
index_t cta_idx) const noexcept
|
||||
{
|
||||
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
|
||||
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
|
||||
// it is extra_iters.
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter = get_start_iter(cta_idx);
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
|
||||
index_t tile_iter_start, index_t cta_idx) const noexcept
|
||||
{
|
||||
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
|
||||
|
||||
// Compute how many WGs fit before this tile starts assuming each WG does an
|
||||
// extra_iter
|
||||
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
|
||||
// Compute how many WGs fit before this tile starts excluding extra iters
|
||||
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
|
||||
// Compute the CTA idx for the CTA that starts this tile
|
||||
const index_t coop_group_start =
|
||||
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
|
||||
return cta_idx - coop_group_start;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
@@ -23,6 +23,9 @@ if(GPU_TARGETS MATCHES "gfx90a|gfx942|gfx950")
|
||||
#TODO: support all arches
|
||||
#TODO: current c-shuffle only supports C layout as R
|
||||
add_gtest_executable(test_ck_tile_streamk_tile_partitioner test_streamk_tile_partitioner.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_reduction
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_reduction.cpp
|
||||
test_gemm_streamk_util.cpp)
|
||||
add_gtest_executable(test_ck_tile_streamk_smoke
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_fp16_persistent.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/smoke_tests/test_gemm_streamk_bf16_persistent.cpp
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_streamk_common_includes.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileStreamKFp16Reduction : public TestCkTileStreamK<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
#define TEST_SUITE_NAME TestCkTileStreamKFp16Reduction
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileStreamKFp16Reduction, KernelTypesStreamKFp16Reduction);
|
||||
|
||||
#include "test_gemm_streamk_reduction_cases.inc"
|
||||
|
||||
#undef TEST_SUITE_NAME
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_OneTile)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu;
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_4Tiles_Reduction)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 4;
|
||||
ck_tile::index_t N = N_Tile;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (25 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles_Tree)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::TreeReduction);
|
||||
}
|
||||
|
||||
TYPED_TEST(TEST_SUITE_NAME, StreamK_SKOnly_21Tiles)
|
||||
{
|
||||
const ck_tile::index_t num_cu = get_cu_count();
|
||||
constexpr ck_tile::index_t M_Tile = std::tuple_element_t<7, TypeParam>::value;
|
||||
constexpr ck_tile::index_t N_Tile = std::tuple_element_t<8, TypeParam>::value;
|
||||
constexpr ck_tile::index_t K_Tile = std::tuple_element_t<9, TypeParam>::value;
|
||||
|
||||
ck_tile::index_t M = M_Tile * 3;
|
||||
ck_tile::index_t N = N_Tile * 7;
|
||||
ck_tile::index_t K = K_Tile * num_cu + (30 * K_Tile);
|
||||
|
||||
this->Run(M, N, K, ck_tile::StreamKReductionStrategy::Reduction);
|
||||
}
|
||||
@@ -33,6 +33,14 @@ using KernelTypesStreamKFp16Persistent = ::testing::Types<
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>
|
||||
>;
|
||||
|
||||
using KernelTypesStreamKFp16Reduction = ::testing::Types<
|
||||
// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType M_MacroTile N_MacroTile K_MacroTile Persistent
|
||||
std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I32, NonPersistent>>;
|
||||
|
||||
using KernelTypesStreamKBf16Persistent = ::testing::Types<
|
||||
std::tuple< Row, Row, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, I256, I256, I32, Persistent>,
|
||||
|
||||
@@ -144,7 +144,11 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using Kernel = ck_tile::StreamKKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const auto workspace_size = Kernel::GetWorkSpaceSize(kargs);
|
||||
ck_tile::DeviceMem workspace_data(workspace_size);
|
||||
workspace_data.SetZero();
|
||||
kargs.workspace_ptr = workspace_data.GetDeviceBuffer();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
@@ -184,11 +188,6 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
throw std::runtime_error("Reduction Strategy is current unsupported!\n");
|
||||
}
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
@@ -252,9 +251,25 @@ class TestCkTileStreamK : public ::testing::Test
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
ck_tile::index_t num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
ck_tile::index_t num_accumulations_per_tile;
|
||||
|
||||
if(reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic)
|
||||
{
|
||||
num_accumulations_per_tile = invoke_streamk<ck_tile::StreamKReductionStrategy::Atomic>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::Reduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
else
|
||||
{
|
||||
num_accumulations_per_tile =
|
||||
invoke_streamk<ck_tile::StreamKReductionStrategy::TreeReduction>(
|
||||
args, ck_tile::stream_config{nullptr, false, 0, 0, 1});
|
||||
}
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
|
||||
|
||||
@@ -372,6 +372,85 @@ TEST(StreamKTilePartitionerBaseGetOutputTileIndex, TestAllMappings)
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, SKOnlyLargeK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigSKOnlyLargeK has the following form:
|
||||
- tiles in the C tensor: 2
|
||||
- iters_per_tile: 5
|
||||
- grid: 5
|
||||
- dp_tiles: 0
|
||||
- sk_tiles: 2
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 0
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: __________0_________|_________1_________|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 |
|
||||
| | | | | | | | | | |
|
||||
<---------------SK Tiles--------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 0 with local CTA index of 0 in tile 0
|
||||
- SK CTA 1: tile_iter_start is 0 with local CTA index of 1 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 0 with local CTA index of 2 in tile 0
|
||||
- SK CTA 2: tile_iter_start is 5 with local CTA index of 0 in tile 1
|
||||
- SK CTA 3: tile_iter_start is 5 with local CTA index of 1 in tile 1
|
||||
- SK CTA 4: tile_iter_start is 5 with local CTA index of 2 in tile 1
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{0, 0, 0}, {0, 1, 1}, {0, 2, 2}, {5, 2, 0}, {5, 3, 1}, {5, 4, 2}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigSKOnlyLargeK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
TEST(StreamKTilePartitionerBaseGetTileLocalCtaIndex, DP2TileSK)
|
||||
{
|
||||
/*
|
||||
The StreamKTilePartitionerBaseConfigDP2TileSK has the following form:
|
||||
- tiles in the C tensor: 7
|
||||
- iters_per_tile: 3
|
||||
- grid: 3
|
||||
- dp_tiles: 3
|
||||
- sk_tiles: 4
|
||||
- iters_per_sk_cta: 2
|
||||
- extra_iters: 2
|
||||
|
||||
The tiles with iters are as follows:
|
||||
|
||||
tile_idx: ____0___|___1___|___2___|___3___|___4___|____5____|____6____|
|
||||
tile_iter:| 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|
||||
| | | | | | | | | | | | | | |
|
||||
|<-------DP Tiles------>|<------------SK Tiles------------->|
|
||||
|
||||
From the above configuration, we get the following:
|
||||
- SK CTA 0: tile_iter_start is 6 with local CTA index of 0 in tile 3
|
||||
- SK CTA 0: tile_iter_start is 8 with local CTA index of 0 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 8 with local CTA index of 1 in tile 4
|
||||
- SK CTA 1: tile_iter_start is 10 with local CTA index of 0 in tile 5
|
||||
- SK CTA 2: tile_iter_start is 12 with local CTA index of 0 in tile 6
|
||||
*/
|
||||
|
||||
// Now we create a vector of triplets (tile_iter_start, cta_idx, tile_local_cta_idx) to test
|
||||
std::vector<std::array<ck_tile::index_t, 3>> sk_triplets{
|
||||
{6, 0, 0}, {8, 0, 0}, {8, 1, 1}, {10, 1, 0}, {12, 2, 0}};
|
||||
|
||||
for(const auto& triplet : sk_triplets)
|
||||
{
|
||||
const auto& [tile_iter_start, cta_idx, tile_local_cta_idx] = triplet;
|
||||
test_get_tile_local_cta_idx<StreamKTilePartitionerBaseConfigDP2TileSK>(
|
||||
tile_iter_start, cta_idx, tile_local_cta_idx);
|
||||
}
|
||||
}
|
||||
|
||||
// Persistent
|
||||
TEST(StreamKTilePartitioner_PersistentConstructor, SKOnly)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include <array>
|
||||
|
||||
enum StreamKTilePartitionerBaseMethodId
|
||||
{
|
||||
@@ -12,7 +13,8 @@ enum StreamKTilePartitionerBaseMethodId
|
||||
GET_TILE_BOUNDARIES,
|
||||
GET_TILE_INDEX,
|
||||
GET_ITER_BOUNDARIES,
|
||||
GET_OUTPUT_TILE_INDEX
|
||||
GET_OUTPUT_TILE_INDEX,
|
||||
GET_TILE_LOCAL_CTA_INDEX
|
||||
};
|
||||
|
||||
// Base kernel wrapper class to facilitate testing class device functions.
|
||||
@@ -136,6 +138,22 @@ struct KernelWrapperSpecialized<TilePartitioner,
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TilePartitioner>
|
||||
struct KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>
|
||||
: public KernelWrapper<TilePartitioner>
|
||||
{
|
||||
|
||||
using Base = KernelWrapper<TilePartitioner>;
|
||||
|
||||
CK_TILE_DEVICE void operator()(typename Base::KernelArgs kargs)
|
||||
{
|
||||
ck_tile::index_t tile_local_cta_index =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(kargs.arg1, kargs.arg2);
|
||||
*(static_cast<ck_tile::index_t*>(kargs.result1)) = tile_local_cta_index;
|
||||
}
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseExpected
|
||||
{
|
||||
ck_tile::index_t sk_tiles_;
|
||||
@@ -243,6 +261,22 @@ struct StreamKTilePartitionerBaseConfigSKOnly : public StreamKTilePartitionerBas
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigSKOnlyLargeK : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
static constexpr ck_tile::index_t M = 8;
|
||||
static constexpr ck_tile::index_t N = 2;
|
||||
static constexpr ck_tile::index_t K = 10;
|
||||
static constexpr ck_tile::index_t GRID = 5;
|
||||
|
||||
static constexpr ck_tile::index_t M_TILE = 4;
|
||||
static constexpr ck_tile::index_t N_TILE = 2;
|
||||
static constexpr ck_tile::index_t K_TILE = 2;
|
||||
|
||||
using GemmShape = ck_tile::TileGemmShape<ck_tile::sequence<M_TILE, N_TILE, K_TILE>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>,
|
||||
ck_tile::sequence<UNUSED, UNUSED, UNUSED>>;
|
||||
};
|
||||
|
||||
struct StreamKTilePartitionerBaseConfigEdgeCase : public StreamKTilePartitionerBaseConfig
|
||||
{
|
||||
|
||||
@@ -314,6 +348,38 @@ void test_get_output_tile_index(ck_tile::index_t tile_idx,
|
||||
EXPECT_EQ(in, in_expected);
|
||||
};
|
||||
|
||||
template <typename Config>
|
||||
void test_get_tile_local_cta_idx(ck_tile::index_t tile_iter_start,
|
||||
ck_tile::index_t cta_idx,
|
||||
ck_tile::index_t expected_tile_local_cta_idx)
|
||||
{
|
||||
// Types
|
||||
using TilePartitioner = ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape>;
|
||||
using Kernel =
|
||||
KernelWrapperSpecialized<TilePartitioner,
|
||||
StreamKTilePartitionerBaseMethodId::GET_TILE_LOCAL_CTA_INDEX>;
|
||||
|
||||
// Test parameters
|
||||
ck_tile::StreamKTilePartitionerBase<typename Config::GemmShape> tile_partitioner{
|
||||
Config::M, Config::N, Config::K, Config::GRID};
|
||||
ck_tile::DeviceMem tile_local_cta_idx_dev(sizeof(ck_tile::index_t));
|
||||
|
||||
// Launch kernel
|
||||
auto kargs = Kernel::MakeKernelArgs(tile_iter_start,
|
||||
cta_idx,
|
||||
Config::UNUSED,
|
||||
tile_local_cta_idx_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
tile_partitioner);
|
||||
ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0, 0, 1},
|
||||
ck_tile::make_kernel<1>(Kernel{}, 1, 1, 0, kargs));
|
||||
|
||||
// Validate results
|
||||
ck_tile::index_t tile_local_cta_idx;
|
||||
tile_local_cta_idx_dev.FromDevice(&tile_local_cta_idx);
|
||||
EXPECT_EQ(tile_local_cta_idx, expected_tile_local_cta_idx);
|
||||
}
|
||||
|
||||
// Configs for TilePartitioner Child structs
|
||||
struct StreamKTilePartitionerV2PersistentExpected
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user