mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Stream-K Tree Reduction and Cache Skipping Integration (#3371)
* CK Tile Stream-K Tree Reduction This change adds the first implementation of the Stream-K tree reduction strategy into CK Tile. The tree reduction reduces the the number of steps for accumulating results for a tile from O(N) to O(logN) where N is the number of workgroups contributing to a C tile. Additionally, in the original non-atomic reduction strategy, atomics were used to set the flags buffer and to read from the flags buffer. Howeover, through investigation with the tree reduciton, atomics with default (relaxed) semantics were not enough to guarantee workgroups would not read stale data, leading to incorrect results. Stronger acquire/release memory orderings are too expensive. So, this change also eliminates the use of atomics for setting the flags. Instead, we leverage cache modifiers (e.g., GLC) to avoid writing to cache, thereby avoiding the use of atomics. Prelimiary tests were also added for the normal reduction and tree reduction. More will be added in a future PR via tile engine. * Move Stream-K kernel files to a subdirectory * Cleanup Code Style & Handle Unsupported Reductions This change makes the following small changes: - Add an explicit else block for unimplemented reduction strategies - Clarify type of sk_flags_ptr via auto* - Add description for extra_iters_before_me variable * Run new copyright script on new files
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
namespace ck_tile {
|
||||
enum StreamKReductionStrategy : uint32_t
|
||||
{
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
Atomic = 0u,
|
||||
Reduction = 1u,
|
||||
TreeReduction = 2u
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CompilerTarget, typename Enabler = void>
|
||||
struct StreamKCoherency
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::coherence_default;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX942,
|
||||
core::arch::amdgcn_target_id::GFX950>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::SYSTEM_NT0;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX908,
|
||||
core::arch::amdgcn_target_id::GFX90A>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::glc_slc;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "streamk_gemm_coherency.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -318,37 +319,58 @@ struct StreamKKernel
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
* @note This function utilizes a scalar store to write to the flags buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the approproriate
|
||||
// cache level(s) to ensure the write is visible to other workgroups. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_store_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
||||
:
|
||||
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
* @note This function utilizes a scalar load to read from the flags
|
||||
* buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t result;
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
do
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the
|
||||
// approproriate cache level(s) to avoid reading stale flags. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_load_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
||||
: "=s"(result)
|
||||
: "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
} while(result != 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates
|
||||
* the output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
@@ -370,8 +392,8 @@ struct StreamKKernel
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for
|
||||
* loading the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
@@ -405,8 +427,8 @@ struct StreamKKernel
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
@@ -420,7 +442,10 @@ struct StreamKKernel
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<
|
||||
address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
@@ -431,8 +456,11 @@ struct StreamKKernel
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -483,7 +511,8 @@ struct StreamKKernel
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
@@ -528,46 +557,107 @@ struct StreamKKernel
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile =
|
||||
kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta =
|
||||
kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
}
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
// It's this workgroup's turn to read from partials.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
// If this workgroup's partner is in the tile then it can read from
|
||||
// partials and accumulate results.
|
||||
if(partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs,
|
||||
partner_cta_idx,
|
||||
c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// Otherwise, it's this workgroup's turn to write to partials. All
|
||||
// workgroups, except the workgroup who starts the tile, will write to
|
||||
// partials.
|
||||
else
|
||||
{
|
||||
StorePartial(kargs, cta_idx, accum_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
// Once the workgroup writes to partials, it has no more work to do for
|
||||
// this tile.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
"An implementation does not exist for the chosen reduction strategy.");
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
@@ -640,10 +730,10 @@ struct StreamKKernel
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
|
||||
* is the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the
|
||||
* layouts of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
@@ -688,7 +778,8 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
|
||||
* kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start iteration for the given the cta_idx.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @return index_t The start iteration.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
* @brief Calculates the workgroup's local CTA idx within the given tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param cta_idx The Stream-K workgroup index.
|
||||
* @return index_t The tile local workgroup index in the tile.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
|
||||
index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
|
||||
index_t cta_idx) const noexcept
|
||||
{
|
||||
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
|
||||
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
|
||||
// it is extra_iters.
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter = get_start_iter(cta_idx);
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
|
||||
index_t tile_iter_start, index_t cta_idx) const noexcept
|
||||
{
|
||||
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
|
||||
|
||||
// Compute how many WGs fit before this tile starts assuming each WG does an
|
||||
// extra_iter
|
||||
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
|
||||
// Compute how many WGs fit before this tile starts excluding extra iters
|
||||
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
|
||||
// Compute the CTA idx for the CTA that starts this tile
|
||||
const index_t coop_group_start =
|
||||
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
|
||||
return cta_idx - coop_group_start;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
Reference in New Issue
Block a user