mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
[CK Tile] Stream-K RDNA Support ## Motivation Currently, CK Tile Stream-K only supports CDNA architectures. This change adds Stream-K support on RDNA3/3.5 and RDNA4 architectures. ## Technical Details Stream-K currently has 3 reduction strategies: 1) atomics, 2) linear, and 3) tree. The linear and tree reductions require inter-workgroup communication through a global flags buffer and a global partials buffer. To ensure cache coherency, we use cache modifiers to skip cache levels that are not visible to all workgroups. On CDNA architectures, scalar load and scalar store instructions are available, which we use to read from and write to the flags buffer with appropriate cache-skipping modifiers. However, RDNA architectures do not support scalar store instructions, so workgroups must use a buffer store instruction to write to flags. Additionally, cache modifiers differ between CDNA and RDNA; they also differ between RDNA3 and RDNA4. Given this information, the main changes are as follows: - Added RDNA flag signaling: Use buffer store instructions for writing to the global flags buffer - Added appropriate cache modifiers for reading and writing to flags and partials: - RDNA3 (gfx11): Use `glc | dlc` coherence flags - RDNA4 (gfx12): Use `DEVICE` coherence scope - SFINAE-guarded overloads: Added compile-time dispatch for `SignalStorePartialDone()` and `WaitStorePartialDone()` based on target architecture - RDNA alignment requirements: Increased flags buffer alignment from 128B to 256B due to RDNA cache line size **A note about the `amd_buffer_coherence_enum`:** - **Problem:** The `amd_buffer_coherence_enum` uses preprocessor conditionals (`#if defined(__gfx12__)`) to define architecture-specific values. Template specializations reference enum values from different architectures (e.g., `glc_dlc` for GFX11). 
Due to C++ two-phase name lookup, non-dependent names are resolved during template parsing regardless of which architecture is being compiled, causing compilation failures when referenced values do not exist in the active preprocessor branch. - **Temporary Solution**: Added compatibility enum values to each architecture block. For example, I added `glc_dlc` in the `__gfx12__` block. I will create a ticket to refactor this enum with a design that has better scalability and tries to avoid the use of preprocessor conditionals. ## Test Plan ### Summary gtests were added to test WMMA variants of Stream-K. These tests were stress-tested locally on gfx11 and gfx12. ### More details This PR makes the following changes/additions to the Stream-K gtests: - Split tests into MFMA (CDNA) and WMMA (RDNA) variants - Added 16 WMMA kernel types: FP16/BF16/FP8/BF8 × Linear/Tree reduction - WMMA uses 16×16×16 wave tiles for RDNA (this is the only tile size supported on RDNA) - Fixed RDNA WGP mode: multiply multiProcessorCount by 2 for actual CU count - As described in [HIP documentation](https://rocm.docs.amd.com/projects/HIP/en/docs-7.2.0/doxygen/html/group___global_defs.html#ggacc0acd7b9bda126c6bb3dfd6e2796d7ca3ac50041beb59111a5c76edf03da0898), when in Workgroup Processor (WGP) mode, the value of `hipDeviceAttributeMultiprocessorCount` is half the number of CUs, because a single WGP contains two CUs. The default mode on RDNA is WGP mode, so when creating (M, N, K) instances for gtests using the CU count, we need to multiply the reported multiprocessor count by 2 to get the correct value. This is not needed in the kernel host code, because the occupancy ensures that the overall `max_active_wgs` is correct. ## Test Result All tests pass locally. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
403 lines
19 KiB
C++
403 lines
19 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
|
|
|
namespace ck_tile {
|
|
/// @brief Selects how Stream-K combines the partial results of a tile.
/// @note Per the change description accompanying this file, the Linear and Tree strategies
///       communicate between workgroups through a global flags buffer and a global partials
///       buffer.
enum StreamKReductionStrategy : uint32_t
{
    Atomic = 0u, ///< Reduction via atomics.
    Linear = 1u, ///< Linear reduction.
    Tree   = 2u  ///< Tree reduction.
};
|
|
|
|
/// @brief StreamK reduction helpers: partial store/load, flag signaling, and tile accumulation.
|
|
/// Shared by StreamK GEMM and StreamK conv bwd weight kernels.
|
|
template <typename TilePartitioner_, typename GemmPipeline_, typename KernelArgs_>
|
|
struct StreamKReductionOps
|
|
{
|
|
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
|
using BlockGemm = typename GemmPipeline_::BlockGemm;
|
|
using WarpGemm = typename BlockGemm::WarpGemm;
|
|
using BlockGemmShape = typename GemmPipeline_::BlockGemmShape;
|
|
using CompilerTarget = decltype(core::arch::get_compiler_target());
|
|
|
|
// Helper: all supported architectures for specialized versions
|
|
template <typename CompilerTarget_>
|
|
static constexpr bool IsStreamKReductionSupportedArch()
|
|
{
|
|
return core::arch::is_target_id_any_of<CompilerTarget_,
|
|
core::arch::amdgcn_target_id::GFX90A,
|
|
core::arch::amdgcn_target_id::GFX942,
|
|
core::arch::amdgcn_target_id::GFX950,
|
|
core::arch::amdgcn_target_id::GFX1200,
|
|
core::arch::amdgcn_target_id::GFX1201,
|
|
core::arch::amdgcn_target_id::GFX12_GENERIC>() ||
|
|
core::arch::is_target_family_gfx11<CompilerTarget_>();
|
|
}
|
|
|
|
/**
|
|
*@brief Signals that the current thread block(CTA) has completed storing its partial
|
|
* results.
|
|
* @param kargs Kernel arguments, including the workspace pointer.
|
|
* @param cta_idx The index of the current thread block (CTA).
|
|
* @note This function utilizes a scalar store to write to the flags buffer.
|
|
*/
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE core::arch::enable_if_target_id_t<CompilerTarget_,
|
|
core::arch::amdgcn_target_id::GFX90A,
|
|
core::arch::amdgcn_target_id::GFX942,
|
|
core::arch::amdgcn_target_id::GFX950>
|
|
SignalStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
|
|
{
|
|
// s_store_dword needs a wave-uniform (SGPR) address; kargs-by-ref can leave the pointer
|
|
// in a VGPR (instantiation-dependent), which the assembler rejects. Pull the (uniform)
|
|
// operands into SGPRs first.
|
|
auto* sk_flags_ptr = reinterpret_cast<index_t*>(
|
|
amd_wave_read_first_lane(reinterpret_cast<uintptr_t>(kargs.workspace_ptr)));
|
|
index_t offset = amd_wave_read_first_lane(cta_idx) * sizeof(index_t);
|
|
|
|
// Depending on the architecture, the GLC flag will bypass the appropriate
|
|
// cache level(s) to ensure the write is visible to other workgroups. See the
|
|
// appropriate ISA for details about the GLC modifier.
|
|
asm volatile("s_store_dword %0, %1, %2 glc\n\t"
|
|
"s_waitcnt lgkmcnt(0)" // Wait for the store to complete
|
|
:
|
|
: "s"(1), "s"(sk_flags_ptr), "s"(offset)
|
|
: "memory");
|
|
}
|
|
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE std::enable_if_t<
|
|
core::arch::is_target_id_any_of<CompilerTarget_,
|
|
core::arch::amdgcn_target_id::GFX1200,
|
|
core::arch::amdgcn_target_id::GFX1201,
|
|
core::arch::amdgcn_target_id::GFX12_GENERIC>() ||
|
|
core::arch::is_target_family_gfx11<CompilerTarget_>()>
|
|
SignalStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
|
|
{
|
|
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
|
index_t offset = cta_idx * sizeof(index_t);
|
|
__amdgpu_buffer_rsrc_t buffer_rsc = make_builtin_buffer_resource(
|
|
sk_flags_ptr, sizeof(index_t) * kargs.tile_partitioner.get_sk_ctas());
|
|
|
|
if(threadIdx.x == 0)
|
|
{
|
|
__builtin_amdgcn_raw_buffer_store_b32(
|
|
1,
|
|
buffer_rsc,
|
|
offset,
|
|
0,
|
|
static_cast<int>(StreamKCoherency<CompilerTarget_>::BUFFER_COHERENCE));
|
|
}
|
|
}
|
|
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE std::enable_if_t<!IsStreamKReductionSupportedArch<CompilerTarget_>()>
|
|
SignalStorePartialDone([[maybe_unused]] const KernelArgs_& kargs,
|
|
[[maybe_unused]] index_t cta_idx) const
|
|
{
|
|
static_assert(IsStreamKReductionSupportedArch<CompilerTarget_>(),
|
|
"SignalStorePartialDone not implemented for this architecture.");
|
|
}
|
|
|
|
/**
|
|
* @brief Waits for the thread block (cta_idx) to complete storing its partial results.
|
|
* @param kargs Kernel arguments, including the workspace pointer.
|
|
* @param cta_idx The index of the thread block (CTA).
|
|
* @note This function utilizes a scalar load to read from the flags
|
|
* buffer.
|
|
*/
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE core::arch::enable_if_target_id_t<CompilerTarget_,
|
|
core::arch::amdgcn_target_id::GFX90A,
|
|
core::arch::amdgcn_target_id::GFX942,
|
|
core::arch::amdgcn_target_id::GFX950>
|
|
WaitStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
|
|
{
|
|
// s_load_dword needs a wave-uniform (SGPR) address (see SignalStorePartialDone).
|
|
auto* sk_flags_ptr = reinterpret_cast<index_t*>(
|
|
amd_wave_read_first_lane(reinterpret_cast<uintptr_t>(kargs.workspace_ptr)));
|
|
index_t result;
|
|
index_t offset = amd_wave_read_first_lane(cta_idx) * sizeof(index_t);
|
|
|
|
do
|
|
{
|
|
// Depending on the architecture, the GLC flag will bypass the
|
|
// appropriate cache level(s) to avoid reading stale flags. See the
|
|
// appropriate ISA for details about the GLC modifier.
|
|
asm volatile("s_load_dword %0, %1, %2 glc\n\t"
|
|
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
|
: "=s"(result)
|
|
: "s"(sk_flags_ptr), "s"(offset)
|
|
: "memory");
|
|
} while(result != 1);
|
|
}
|
|
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE core::arch::enable_if_target_id_t<CompilerTarget_,
|
|
core::arch::amdgcn_target_id::GFX1200,
|
|
core::arch::amdgcn_target_id::GFX1201,
|
|
core::arch::amdgcn_target_id::GFX12_GENERIC>
|
|
WaitStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
|
|
{
|
|
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
|
index_t result;
|
|
index_t offset = cta_idx * sizeof(index_t);
|
|
do
|
|
{
|
|
asm volatile("s_load_b32 %0, %1, %2 scope:SCOPE_DEV\n\t"
|
|
"s_wait_kmcnt 0" // Wait for the load to complete
|
|
: "=s"(result)
|
|
: "s"(sk_flags_ptr), "s"(offset)
|
|
: "memory");
|
|
} while(result != 1);
|
|
}
|
|
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE
|
|
core::arch::enable_if_target_family_id_t<CompilerTarget_,
|
|
core::arch::amdgcn_target_family_id::GFX11>
|
|
WaitStorePartialDone(const KernelArgs_& kargs, index_t cta_idx) const
|
|
{
|
|
auto* sk_flags_ptr = static_cast<index_t*>(kargs.workspace_ptr);
|
|
index_t result;
|
|
index_t offset = cta_idx * sizeof(index_t);
|
|
do
|
|
{
|
|
asm volatile("s_load_b32 %0, %1, %2 glc dlc\n\t"
|
|
"s_waitcnt lgkmcnt(0)" // Wait for the load to complete
|
|
: "=s"(result)
|
|
: "s"(sk_flags_ptr), "s"(offset)
|
|
: "memory");
|
|
} while(result != 1);
|
|
}
|
|
|
|
template <typename CompilerTarget_ = CompilerTarget>
|
|
CK_TILE_DEVICE std::enable_if_t<!IsStreamKReductionSupportedArch<CompilerTarget_>()>
|
|
WaitStorePartialDone([[maybe_unused]] const KernelArgs_& kargs,
|
|
[[maybe_unused]] index_t cta_idx) const
|
|
{
|
|
static_assert(IsStreamKReductionSupportedArch<CompilerTarget_>(),
|
|
"WaitStorePartialDone not implemented for this architecture.");
|
|
}
|
|
|
|
/**
|
|
* @brief Adds the values of a block tile to an output block tile.
|
|
* @param in_out_block_tile The output block tile to which values are added.
|
|
* @param in_block_tile The input block tile whose values are added.
|
|
* @note This function iterates over the distributed spans of the block tiles and updates
|
|
* the output block tile with accumulated values.
|
|
*/
|
|
template <typename OAccTile>
|
|
CK_TILE_DEVICE void AddBlockTile(OAccTile& in_out_block_tile,
|
|
const OAccTile& in_block_tile) const
|
|
{
|
|
using BlockType = remove_cvref_t<decltype(in_out_block_tile)>;
|
|
constexpr auto o_spans = BlockType::get_distributed_spans();
|
|
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
|
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
|
|
constexpr auto idx = make_tuple(idx0, idx1);
|
|
in_out_block_tile(idx) = in_out_block_tile[idx] + in_block_tile[idx];
|
|
});
|
|
});
|
|
}
|
|
|
|
/**
|
|
* @brief Loads a partial block tile from the workspace buffer.
|
|
* @param kargs Kernel arguments, including the workspace pointer.
|
|
* @param cta_idx The index of the thread block (CTA).
|
|
* @param c_block_tile_dist The tile distribution for the block.
|
|
* @return The loaded partial block tile.
|
|
* @note This function calculates the buffer pointer and uses the tile distribution for
|
|
* loading the partial block tile.
|
|
*/
|
|
template <typename DataType, typename OAccTileDist>
|
|
CK_TILE_DEVICE auto LoadPartial(const KernelArgs_& kargs,
|
|
index_t cta_idx,
|
|
const OAccTileDist& c_block_tile_dist) const
|
|
{
|
|
const auto c_block_tile_buffer_size =
|
|
TilePartitioner::MPerBlock * TilePartitioner::NPerBlock * sizeof(DataType);
|
|
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
|
|
kargs.tile_partitioner.get_flags_buffer_size() +
|
|
cta_idx * c_block_tile_buffer_size;
|
|
|
|
const auto& partial_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
|
static_cast<DataType*>(partial_buffer_ptr),
|
|
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
|
make_tuple(TilePartitioner::NPerBlock, 1),
|
|
number<GetVectorSizePartials()>{},
|
|
number<1>{});
|
|
|
|
auto partial_tile_window = make_tile_window(
|
|
partial_tensor_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
|
{0, 0},
|
|
MakePartialsDistribution());
|
|
|
|
auto partials_tile = load_tile(partial_tile_window);
|
|
|
|
// Since the partials distribution is not the same as the C block distribution, we must
|
|
// describe the contents in the partials tile with the C block distribution.
|
|
// Note: The data assigned to threads does not change between distributions.
|
|
auto partials_tile_with_c_distr = make_static_distributed_tensor<DataType>(
|
|
c_block_tile_dist, partials_tile.get_thread_buffer());
|
|
|
|
return partials_tile_with_c_distr;
|
|
}
|
|
|
|
/**
|
|
* @brief Returns the vector size to be used for reading from and writing to partials.
|
|
* @return The vector size
|
|
*/
|
|
CK_TILE_DEVICE static constexpr index_t GetVectorSizePartials()
|
|
{
|
|
// We use kCM1PerLane from the C register layout of the warp GEMM which corresponds to the
|
|
// maximum vector width
|
|
return WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane;
|
|
}
|
|
|
|
/**
|
|
* @brief Returns distribution used for reading from and writing to partials.
|
|
* @return The distribution.
|
|
* @note This will result in optimized reads from and writes to partials when C is row major.
|
|
* Additional functionality should be added to ensure optimized accesses to partials when C is
|
|
* column major. Since the C-Shuffle epilogue only supports C as row major, this is not a
|
|
* current limitation.
|
|
*/
|
|
CK_TILE_DEVICE static constexpr auto MakePartialsDistribution()
|
|
{
|
|
// Create the encoding to describe waves within a block
|
|
constexpr index_t m_warp = BlockGemmShape::BlockWarps::at(number<0>{});
|
|
constexpr index_t n_warp = BlockGemmShape::BlockWarps::at(number<1>{});
|
|
|
|
constexpr index_t m_iter_per_warp = TilePartitioner::MPerBlock / (m_warp * WarpGemm::kM);
|
|
constexpr index_t n_iter_per_warp = TilePartitioner::NPerBlock / (n_warp * WarpGemm::kN);
|
|
|
|
constexpr auto partials_outer_dstr_encoding = tile_distribution_encoding<
|
|
sequence<>,
|
|
tuple<sequence<m_iter_per_warp, m_warp>, sequence<n_iter_per_warp, n_warp>>,
|
|
tuple<sequence<1, 2>>,
|
|
tuple<sequence<1, 1>>,
|
|
sequence<1, 2>,
|
|
sequence<0, 0>>{};
|
|
|
|
// Create the encoding to describe threads within a wave
|
|
constexpr index_t vector_size = GetVectorSizePartials();
|
|
constexpr index_t m_warp_repeat = WarpGemm::WarpGemmAttribute::Impl::kCM0PerLane;
|
|
constexpr index_t warp_tile_n_threads = WarpGemm::kN / vector_size;
|
|
constexpr index_t warp_tile_m_threads = get_warp_size() / warp_tile_n_threads;
|
|
|
|
// This inner encoding ensures that contiguous threads perform vectorized writes along the
|
|
// same row in C.
|
|
constexpr auto partials_inner_dstr_encoding =
|
|
tile_distribution_encoding<sequence<>,
|
|
tuple<sequence<m_warp_repeat, warp_tile_m_threads>,
|
|
sequence<warp_tile_n_threads, vector_size>>,
|
|
tuple<sequence<1, 2>>,
|
|
tuple<sequence<1, 0>>,
|
|
sequence<1, 2>,
|
|
sequence<0, 1>>{};
|
|
|
|
// Combine the outer and inner encoding
|
|
constexpr auto partials_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
|
partials_outer_dstr_encoding, partials_inner_dstr_encoding);
|
|
|
|
return make_static_tile_distribution(partials_dstr_encode);
|
|
}
|
|
|
|
/**
|
|
* @brief Stores a partial block tile to the workspace buffer.
|
|
* @param kargs Kernel arguments, including the workspace pointer.
|
|
* @param cta_idx The index of the thread block (CTA).
|
|
* @param c_block_tile The block tile to be stored.
|
|
* @note This function calculates the buffer pointer and uses the tile window for storing
|
|
* the partial block tile.
|
|
*/
|
|
template <typename OAccTile>
|
|
CK_TILE_DEVICE void
|
|
StorePartial(const KernelArgs_& kargs, index_t cta_idx, const OAccTile& c_block_tile) const
|
|
{
|
|
const auto c_block_tile_buffer_size = TilePartitioner::MPerBlock *
|
|
TilePartitioner::NPerBlock *
|
|
sizeof(typename OAccTile::DataType);
|
|
void* partial_buffer_ptr = static_cast<char*>(kargs.workspace_ptr) +
|
|
kargs.tile_partitioner.get_flags_buffer_size() +
|
|
cta_idx * c_block_tile_buffer_size;
|
|
|
|
const auto& partial_tensor_view = make_naive_tensor_view<
|
|
address_space_enum::global,
|
|
memory_operation_enum::set,
|
|
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE>(
|
|
static_cast<typename OAccTile::DataType*>(partial_buffer_ptr),
|
|
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
|
make_tuple(TilePartitioner::NPerBlock, 1),
|
|
number<GetVectorSizePartials()>{},
|
|
number<1>{});
|
|
|
|
auto partial_tile_window = make_tile_window(
|
|
partial_tensor_view,
|
|
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
|
{0, 0},
|
|
MakePartialsDistribution());
|
|
|
|
// Since the C block distribution is not the same as the partials distribution, we must
|
|
// describe the contents in the c_block_tile with the partials distribution.
|
|
// Note: The data assigned to threads does not change between distributions.
|
|
auto c_with_partials_dist = make_static_distributed_tensor<typename OAccTile::DataType>(
|
|
MakePartialsDistribution(), c_block_tile.get_thread_buffer());
|
|
|
|
store_tile(partial_tile_window, c_with_partials_dist);
|
|
// Wait for all vector stores for this wavefront to complete
|
|
s_waitcnt</*vmcnt*/ 0, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
|
// Wait for all wavefronts in this workgroup to arrive here before continuing
|
|
__builtin_amdgcn_s_barrier();
|
|
}
|
|
};
|
|
|
|
/// @brief StreamK data-parallel (DP) dispatch: handles persistent vs non-persistent DP,
|
|
/// then delegates to the Stream-K loop. Shared by GEMM and Conv StreamK kernels.
|
|
///
|
|
/// Non-persistent: launches dp_ctas + sk_ctas workgroups. DP workgroups each process
|
|
/// one full tile; SK workgroups share the remaining tiles' K-iterations.
|
|
/// Persistent: launches num_cu * occupancy workgroups. Each loops over DP tiles
|
|
/// (round-robin), then proceeds to SK work.
|
|
///
|
|
/// @tparam TilePartitioner_ Partitioner type (persistent or non-persistent specialization).
|
|
/// @param tile_partitioner The partitioner instance from kernel args.
|
|
/// @param dp_tile_func Callable(index_t tile_idx) - processes one full DP tile.
|
|
/// @param sk_func Callable(index_t sk_cta_idx) - runs the StreamK loop for this CTA.
|
|
template <typename TilePartitioner_, typename DPTileFunc, typename SKFunc>
|
|
CK_TILE_DEVICE void StreamKDispatch(const TilePartitioner_& tile_partitioner,
|
|
DPTileFunc dp_tile_func,
|
|
SKFunc sk_func,
|
|
index_t block_idx)
|
|
{
|
|
if constexpr(TilePartitioner_::PERSISTENT)
|
|
{
|
|
// Persistent: each workgroup loops over multiple DP tiles, then does SK work
|
|
for(index_t tile_idx = block_idx; tile_idx < tile_partitioner.get_dp_tiles();
|
|
tile_idx += tile_partitioner.get_max_active_wgs())
|
|
{
|
|
dp_tile_func(tile_idx);
|
|
block_sync_lds();
|
|
}
|
|
sk_func(block_idx);
|
|
}
|
|
else
|
|
{
|
|
// Non-persistent: dedicated DP workgroups, then dedicated SK workgroups
|
|
const index_t dp_ctas = tile_partitioner.get_dp_ctas();
|
|
if(block_idx < dp_ctas)
|
|
dp_tile_func(block_idx);
|
|
else
|
|
sk_func(block_idx - dp_ctas);
|
|
}
|
|
}
|
|
|
|
} // namespace ck_tile
|