mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Merge commit '22b945e06ea4b4de188d7ff4ec7ae4bf127be9f9' into develop
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
namespace ck_tile {
|
||||
enum StreamKReductionStrategy : uint32_t
|
||||
{
|
||||
Atomic = 0u,
|
||||
Reduction = 1u
|
||||
Atomic = 0u,
|
||||
Reduction = 1u,
|
||||
TreeReduction = 2u
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,9 +33,10 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner_impl.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp"
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename CompilerTarget, typename Enabler = void>
|
||||
struct StreamKCoherency
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::coherence_default;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX942,
|
||||
core::arch::amdgcn_target_id::GFX950>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::SYSTEM_NT0;
|
||||
};
|
||||
|
||||
template <typename CompilerTarget>
|
||||
struct StreamKCoherency<CompilerTarget,
|
||||
core::arch::enable_if_target_id_t<CompilerTarget,
|
||||
core::arch::amdgcn_target_id::GFX908,
|
||||
core::arch::amdgcn_target_id::GFX90A>>
|
||||
{
|
||||
static constexpr amd_buffer_coherence_enum BUFFER_COHERENCE =
|
||||
amd_buffer_coherence_enum::glc_slc;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "streamk_gemm_coherency.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -318,37 +319,58 @@ struct StreamKKernel
|
||||
* results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the current thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to set a synchronization flag for the given
|
||||
* CTA index.
|
||||
* @note This function utilizes a scalar store to write to the flags buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void SignalStorePartialDone(const StreamKKernelArgs& kargs,
|
||||
index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_set(0, 1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the approproriate
|
||||
// cache level(s) to ensure the write is visible to other workgroups. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_store_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
||||
:
|
||||
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @note This function utilizes a workgroup barrier to wait for the synchronization flag to be
|
||||
* set by the given CTA index.
|
||||
* @note This function utilizes a scalar load to read from the flags
|
||||
* buffer.
|
||||
*/
|
||||
CK_TILE_DEVICE void WaitStorePartialDone(const StreamKKernelArgs& kargs, index_t cta_idx) const
|
||||
{
|
||||
auto sk_flags_ptr = static_cast<uint32_t*>(kargs.workspace_ptr);
|
||||
workgroup_barrier sk_flags(sk_flags_ptr);
|
||||
sk_flags.wait_eq(1, cta_idx);
|
||||
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
||||
index_t result;
|
||||
index_t offset = cta_idx * sizeof(index_t);
|
||||
|
||||
do
|
||||
{
|
||||
asm volatile("s_mov_b32 m0, %2\n\t"
|
||||
// Depending on the architecture, the GLC flag will bypass the
|
||||
// approproriate cache level(s) to avoid reading stale flags. See the
|
||||
// appropriate ISA for details about the GLC modifier.
|
||||
"s_load_dword %0, %1, %2 glc\n\t"
|
||||
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
||||
: "=s"(result)
|
||||
: "s"(sk_flags_ptr), "s"(offset)
|
||||
: "memory");
|
||||
} while(result != 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds the values of a block tile to an output block tile.
|
||||
* @param in_out_block_tile The output block tile to which values are added.
|
||||
* @param in_block_tile The input block tile whose values are added.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates the
|
||||
* output block tile with accumulated values.
|
||||
* @note This function iterates over the distributed spans of the block tiles and updates
|
||||
* the output block tile with accumulated values.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
||||
@@ -370,8 +392,8 @@ struct StreamKKernel
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile_dist The tile distribution for the block.
|
||||
* @return The loaded partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for loading
|
||||
* the partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile distribution for
|
||||
* loading the partial block tile.
|
||||
*/
|
||||
template <typename DataType, typename OAccTileDist>
|
||||
CK_TILE_DEVICE auto LoadPartial(const StreamKKernelArgs& kargs,
|
||||
@@ -405,8 +427,8 @@ struct StreamKKernel
|
||||
* @param kargs Kernel arguments, including the workspace pointer.
|
||||
* @param cta_idx The index of the thread block (CTA).
|
||||
* @param c_block_tile The block tile to be stored.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing the
|
||||
* partial block tile.
|
||||
* @note This function calculates the buffer pointer and uses the tile window for storing
|
||||
* the partial block tile.
|
||||
*/
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void StorePartial(const StreamKKernelArgs& kargs,
|
||||
@@ -420,7 +442,10 @@ struct StreamKKernel
|
||||
kargs.tile_partitioner.get_flags_buffer_size() +
|
||||
cta_idx * c_block_tile_buffer_size;
|
||||
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
const auto& partial_tensor_view = make_naive_tensor_view<
|
||||
address_space_enum::global,
|
||||
memory_operation_enum::set,
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
||||
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
make_tuple(TilePartitioner::NPerBlock, 1),
|
||||
@@ -431,8 +456,11 @@ struct StreamKKernel
|
||||
partial_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{0, 0});
|
||||
|
||||
store_tile(partial_tile_window, c_block_tile);
|
||||
// Wait for all vector stores for this wavefront to complete
|
||||
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -483,7 +511,8 @@ struct StreamKKernel
|
||||
{
|
||||
BaseGemm(kargs, tile_idx, num_loop_sk, i_k_a, i_k_b, k_size, smem_ptr_0);
|
||||
}
|
||||
else
|
||||
else if(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
const auto c_macro_tile_idx =
|
||||
kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
@@ -528,46 +557,107 @@ struct StreamKKernel
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
if(!tile_started)
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy ==
|
||||
StreamKReductionStrategy::Reduction)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
// Ensure device-wide visibility of partial results stored in global memory
|
||||
// before signaling completion. __threadfence() guarantees that all global
|
||||
// memory writes by this thread are visible to other threads on the device.
|
||||
__threadfence(); // send signal when the store is done
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
if(!tile_started)
|
||||
{
|
||||
StorePartial(kargs, cta_idx, c_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile =
|
||||
kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta =
|
||||
kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
else // Tree Reduction
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
index_t tile_local_cta_idx =
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, cta_idx);
|
||||
|
||||
for(index_t stride = 1;; stride <<= 1)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = cta_idx + 1;
|
||||
const index_t partner_cta_idx = cta_idx + stride;
|
||||
const index_t partner_start_iter =
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx);
|
||||
bool partner_in_tile = partner_start_iter < tile_iter_end;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
// If the partner of the workgroup who started the tile is not in this tile,
|
||||
// then the work for this tile is done and results can be stored in the C
|
||||
// tensor.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, next_cta);
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
break;
|
||||
}
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(
|
||||
accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
// It's this workgroup's turn to read from partials.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
// If this workgroup's partner is in the tile then it can read from
|
||||
// partials and accumulate results.
|
||||
if(partner_in_tile)
|
||||
{
|
||||
WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
AddBlockTile(accum_block_tile,
|
||||
LoadPartial<typename BlockType::DataType>(
|
||||
kargs,
|
||||
partner_cta_idx,
|
||||
c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// Otherwise, it's this workgroup's turn to write to partials. All
|
||||
// workgroups, except the workgroup who starts the tile, will write to
|
||||
// partials.
|
||||
else
|
||||
{
|
||||
StorePartial(kargs, cta_idx, accum_block_tile);
|
||||
SignalStorePartialDone(kargs, cta_idx);
|
||||
// Once the workgroup writes to partials, it has no more work to do for
|
||||
// this tile.
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window, accum_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
"An implementation does not exist for the chosen reduction strategy.");
|
||||
}
|
||||
|
||||
// Prepare for next Stream-K loop iteration.
|
||||
iter_start = tile_iter_end;
|
||||
@@ -640,10 +730,10 @@ struct StreamKKernel
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset is
|
||||
* the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the layouts
|
||||
* of A and B.
|
||||
* @brief Computes the K offsets in the A and B tensors given iter_offset, where iter_offset
|
||||
* is the starting macro tile index in the K dimension for the workgroup.
|
||||
* @return A tuple containing the offsets into the A and B tensors accounting for the
|
||||
* layouts of A and B.
|
||||
* @note The default case is that A is assumed to be row major and B is assumed to be column
|
||||
* major.
|
||||
*/
|
||||
@@ -688,7 +778,8 @@ struct StreamKKernel
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the kernel
|
||||
* @brief Computes the occupancy (i.e. maximum number of active blocks per CU) for the
|
||||
* kernel
|
||||
* @return The occupancy
|
||||
* @note This function queries the maximum occupancy of the kernel using
|
||||
* `hipOccupancyMaxActiveBlocksPerMultiprocessor`.
|
||||
@@ -46,6 +46,16 @@ struct StreamKTilePartitionerBase
|
||||
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Calculates the start iteration for the given the cta_idx.
|
||||
* @param cta_idx The current Stream-K workgroup's index.
|
||||
* @return index_t The start iteration.
|
||||
* @note It is assumed that the first Stream-K workgroup has a `cta_idx` of zero. If a
|
||||
* non-persistent DP section is used, then a Stream-K workgroup's `cta_idx` should be something
|
||||
* like `blockIdx.x` minus number of DP workgroups.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_start_iter(index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the start and end iteration given the cta_idx.
|
||||
*
|
||||
@@ -107,7 +117,17 @@ struct StreamKTilePartitionerBase
|
||||
get_local_iter_end(index_t tile_iter_start, index_t iter_end, index_t tile_iter_end) noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroups 2D tile index in the C tensor given the 1D tile index.
|
||||
* @brief Calculates the workgroup's local CTA idx within the given tile.
|
||||
*
|
||||
* @param tile_iter_start The starting tile iteration.
|
||||
* @param cta_idx The Stream-K workgroup index.
|
||||
* @return index_t The tile local workgroup index in the tile.
|
||||
*/
|
||||
CK_TILE_DEVICE index_t get_tile_local_cta_index(index_t tile_iter_start,
|
||||
index_t cta_idx) const noexcept;
|
||||
|
||||
/**
|
||||
* @brief Calculates the workgroup's 2D tile index in the C tensor given the 1D tile index.
|
||||
*
|
||||
* @param tile_idx The 1D tile index in the C tensor for the workgroup.
|
||||
* @return index_t The corresponding 2D tile index in the C tensor for the workgroup.
|
||||
@@ -61,13 +61,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_flags
|
||||
return sizeof(index_t) * sk_ctas_;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_start_iter(
|
||||
index_t cta_idx) const noexcept
|
||||
{
|
||||
// Compute the number of extra iterations done before this CTA. If the cta_idx is less than
|
||||
// extra_iters, the number of extra iterations before the CTA is exactly the cta_idx. Otherwise,
|
||||
// it is extra_iters.
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
return total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE void
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_iter_boundaries(
|
||||
index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
|
||||
{
|
||||
index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
|
||||
iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
|
||||
iter = get_start_iter(cta_idx);
|
||||
iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
|
||||
}
|
||||
|
||||
@@ -104,6 +115,24 @@ StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_local
|
||||
return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_tile_local_cta_index(
|
||||
index_t tile_iter_start, index_t cta_idx) const noexcept
|
||||
{
|
||||
tile_iter_start = tile_iter_start - (dp_tiles_ * iters_per_tile_);
|
||||
|
||||
// Compute how many WGs fit before this tile starts assuming each WG does an
|
||||
// extra_iter
|
||||
const index_t num_extra_iter_ctas = tile_iter_start / (iters_per_sk_cta_ + 1);
|
||||
// Compute how many WGs fit before this tile starts excluding extra iters
|
||||
const index_t num_non_extra_iter_ctas = (tile_iter_start - extra_iters_) / iters_per_sk_cta_;
|
||||
// Compute the CTA idx for the CTA that starts this tile
|
||||
const index_t coop_group_start =
|
||||
num_extra_iter_ctas < extra_iters_ ? num_extra_iter_ctas : num_non_extra_iter_ctas;
|
||||
return cta_idx - coop_group_start;
|
||||
}
|
||||
|
||||
template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
|
||||
CK_TILE_DEVICE auto
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_output_tile_index(
|
||||
@@ -121,7 +150,8 @@ CK_TILE_HOST_DEVICE index_t
|
||||
StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>::get_workspace_size(
|
||||
index_t acc_element_bytes) const noexcept
|
||||
{
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
|
||||
if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction ||
|
||||
ReductionStrategy == StreamKReductionStrategy::TreeReduction)
|
||||
{
|
||||
|
||||
return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
|
||||
Reference in New Issue
Block a user