mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[rocm-libraries] ROCm/rocm-libraries#5393 (commit d51b649)
[CK Tile] StreamK support for Bwd Weight grouped convolutions (#5393) ## Motivation Add StreamK work distribution to the CK Tile grouped convolution backward weight kernel. Split-K divides the K-dimension uniformly across a fixed `k_batch`, which causes load imbalance when the number of output tiles doesn't evenly fill the GPU. StreamK distributes total K-iterations evenly across workgroups, improving utilization on these shapes. ## Technical Details StreamK is added as an `if constexpr` branch in the existing kernel, selected by the `TilePartitioner_` template parameter. Two reduction strategies are supported: - **Linear**: tile-starter sequentially accumulates partials from contributing CTAs - **Tree**: pairwise binary tree reduction (O(log n) depth, faster for many contributors) Both persistent and non-persistent data-parallel (DP) sections are supported. Key changes: - `grouped_convolution_backward_weight_kernel.hpp`: StreamK execution path with `RunStreamK`/`RunStreamKLoop`, partial store/load via workspace, flag-based cross-CTA synchronization, `GridSize`/`MakeKernelArgs`/`GetWorkSpaceSize` extensions - `streamk_common.hpp`: Shared `StreamKReductionOps` (reduction helpers) and `StreamKDispatch` (persistent/non-persistent DP dispatch), used by both GEMM and Conv StreamK kernels - `streamk_gemm_kernel.hpp`: Refactored to use shared helpers - Merged split-K and StreamK example invokers via `PartitionerPolicy` template parameter - StreamK example binary with `--streamk_reduction=linear|tree` and `--streamk_persistent=0|1` - CK Builder integration: `SpecifiesStreamK` concept, `TilePartitionerType` factory helper, `InstanceTraits` with StreamK fields - 30 tests: host-side, GPU end-to-end (Linear + Tree + Persistent DP), negative, builder regression ### Performance (MI355X, gfx950) Speedup relative to best split-K (sweep over k_batch={1,2,4,8,16,32}): | Shape | 16x64 tiles | | 128x128 tiles | | |---|---|---|---|---| | | Split-K | StreamK | Split-K | StreamK | | 1x1 128x128 N=32 28x28 | 1.00x | 0.54x | 1.00x | 0.81x | | 3x3 128x128 N=32 14x14 | 1.00x | 0.59x | 1.00x | 0.62x | | 1x1 256x64 N=32 56x56 | 1.00x | 0.83x | 1.00x | 1.83x | | 3x3 512x512 N=2 7x7 | 1.00x | 1.12x | 1.00x | 0.62x | | 1x1 1024x1024 N=4 7x7 | 1.00x | 1.09x | 1.00x | 0.60x | | 3x3 128x128 N=32 28x28 | 1.00x | 0.44x | 1.00x | 0.96x | | 3x3 256x256 N=32 14x14 | 1.00x | 0.67x | 1.00x | 0.93x | | 3x3 512x512 N=32 7x7 | 1.00x | 0.98x | 1.00x | 1.16x | StreamK's value depends on tile config: with larger tiles (fewer output tiles), StreamK delivers up to 1.83x speedup on bottleneck shapes and up to 1.16x on typical large-channel convolutions. Tree reduction consistently outperforms Linear when multiple CTAs contribute to the same tile (up to 2.87x faster), due to O(log n) reduction depth vs O(n) sequential accumulation. The table reports the best of Linear and Tree for each shape. ## Test Plan ```bash ninja -C build test_ck_tile_grouped_conv_bwd_weight_streamk ./build/bin/test_ck_tile_grouped_conv_bwd_weight_streamk # Builder tests (requires CK_EXPERIMENTAL_BUILDER=ON) ninja -C build check-builder ``` 30 tests covering: - Host-side: type traits, kernel args construction, grid size, workspace size - GPU end-to-end (Linear + Tree): small/medium shapes, multi-group, stride>1, pure-DP degeneration, single-tile all-SK, large GemmK, higher occupancy - Persistent DP: Linear + Tree with persistent data-parallel dispatch - Negative: `IsSupportedArgument` rejects unaligned K and C - Builder: Create (instance string validation) + Execution (reference comparison) + instance string regression ## Test Result All 30 conv StreamK tests pass on MI355X (gfx950). 64/64 GEMM StreamK tests pass. Full `check-builder` suite passes. Tolerances computed dynamically using `calculate_rtol_atol` pattern (fp16 ULP-aware). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
36f2ec23f5
commit
58475d3f45
@@ -16,12 +16,25 @@
|
||||
|
||||
#include "ck_tile/ops/grouped_convolution/utils/split_k_utils.hpp"
|
||||
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_coherency.hpp"
|
||||
#include "ck_tile/ops/common/streamk_common.hpp"
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BUILDER
|
||||
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
struct is_streamk_partitioner : std::false_type
|
||||
{
|
||||
};
|
||||
template <typename Shape, StreamKReductionStrategy S, bool P>
|
||||
struct is_streamk_partitioner<StreamKTilePartitioner<Shape, S, P>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST void LogInfo(Args&&... args) noexcept
|
||||
{
|
||||
@@ -32,7 +45,7 @@ CK_TILE_HOST void LogInfo(Args&&... args) noexcept
|
||||
}
|
||||
|
||||
/// @brief The Grouped Convolution kernel device arguments.
|
||||
template <typename GroupedConvTraitsType_>
|
||||
template <typename GroupedConvTraitsType_, typename TilePartitioner_ = void>
|
||||
struct GroupedConvBwdWeightKernelArgs
|
||||
{
|
||||
|
||||
@@ -354,6 +367,23 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
long_index_t group_stride_a;
|
||||
long_index_t group_stride_b;
|
||||
long_index_t group_stride_c;
|
||||
|
||||
void* workspace_ptr = nullptr;
|
||||
|
||||
// StreamK tile partitioner — stored directly when TilePartitioner_ is a real type,
|
||||
// empty struct when void (Split-K path). Constructed with dummy values here;
|
||||
// properly initialized in MakeKernelArgs before device-side use.
|
||||
struct EmptyPartitioner
|
||||
{
|
||||
};
|
||||
using PartitionerType =
|
||||
std::conditional_t<std::is_void_v<TilePartitioner_>, EmptyPartitioner, TilePartitioner_>;
|
||||
PartitionerType tile_partitioner = []() {
|
||||
if constexpr(std::is_void_v<TilePartitioner_>)
|
||||
return EmptyPartitioner{};
|
||||
else
|
||||
return TilePartitioner_(1, 1, 1, 1);
|
||||
}();
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Backward Weight kernel template.
|
||||
@@ -424,10 +454,15 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
|
||||
using WeiDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
static constexpr bool IsSplitKSupported = true;
|
||||
static constexpr bool IsStreamK = is_streamk_partitioner<TilePartitioner>::value;
|
||||
|
||||
using GroupedConvBwdWeightKernelArgsSpecialized =
|
||||
std::conditional_t<IsStreamK,
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_, TilePartitioner>,
|
||||
GroupedConvBwdWeightKernelArgs<GroupedConvTraitsType_>>;
|
||||
|
||||
using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
@@ -442,6 +477,32 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
|
||||
GroupedConvTraitsType_::NumGroupsToMerge == 1,
|
||||
"Not supported!");
|
||||
static_assert(!IsStreamK || NumDTensor == 0,
|
||||
"D tensor per-group offsets not implemented for StreamK path");
|
||||
|
||||
// StreamK reduction helpers (partial store/load, flag signaling, tile accumulation).
|
||||
// Shared with the StreamK GEMM kernel via StreamKReductionOps in streamk_common.hpp.
|
||||
using StreamKOps = StreamKReductionOps<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GroupedConvBwdWeightKernelArgsSpecialized>;
|
||||
|
||||
CK_TILE_HOST static index_t
|
||||
GetWorkSpaceSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
if constexpr(IsStreamK)
|
||||
return kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType)) * kargs.GemmBatch;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Post-construction setter: workspace is allocated by the caller after
|
||||
// GetWorkSpaceSize() and must outlive the kernel launch. Can't be moved into
|
||||
// the constructor because kargs is a POD value type copied to GPU constant memory.
|
||||
CK_TILE_HOST static void SetWorkSpacePointer(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
void* workspace_ptr)
|
||||
{
|
||||
kargs.workspace_ptr = workspace_ptr;
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
@@ -463,7 +524,8 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
"SplitImage",
|
||||
EnableSplitImage,
|
||||
"ExplicitGemm",
|
||||
GroupedConvTraitsType_::ExplicitGemm
|
||||
GroupedConvTraitsType_::ExplicitGemm,
|
||||
IsStreamK ? "StreamK" : "SplitK"
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
@@ -483,11 +545,17 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
CK_TILE_HOST static auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized& kargs)
|
||||
{
|
||||
return dim3(
|
||||
TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
|
||||
if constexpr(IsStreamK)
|
||||
{
|
||||
auto sk_grid = kargs.tile_partitioner.grid_size();
|
||||
return dim3(sk_grid.x, kargs.GemmBatch, 1);
|
||||
}
|
||||
else
|
||||
return dim3(TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN),
|
||||
kargs.GemmBatch,
|
||||
kargs.k_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
@@ -495,8 +563,10 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs)
|
||||
CK_TILE_HOST static GroupedConvBwdWeightKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvBwdWeightHostArgs& hostArgs,
|
||||
[[maybe_unused]] int num_cu = 0,
|
||||
[[maybe_unused]] int occupancy = 0)
|
||||
{
|
||||
LogInfo("MPerBlock: ",
|
||||
number<TilePartitioner::MPerBlock>{},
|
||||
@@ -507,18 +577,42 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
auto kernel_args = GroupedConvBwdWeightKernelArgsSpecialized(hostArgs);
|
||||
|
||||
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
|
||||
TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_>;
|
||||
|
||||
// Negative k_batch value: split-K autodeduction.
|
||||
if(kernel_args.k_batch < 0)
|
||||
if constexpr(IsStreamK)
|
||||
{
|
||||
const auto optimal_split_k =
|
||||
calculate_optimal_k_batch<GemmPipeline_::BlockSize, KernelImpl, TilePartitioner_>(
|
||||
kernel_args);
|
||||
kernel_args.k_batch = optimal_split_k;
|
||||
// StreamK: construct tile partitioner and embed it in the args.
|
||||
// Use provided num_cu/occupancy, or query HW.
|
||||
if(num_cu == 0)
|
||||
{
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
ck_tile::hip_check_error(hipGetDevice(&dev));
|
||||
ck_tile::hip_check_error(hipGetDeviceProperties(&dev_prop, dev));
|
||||
num_cu = dev_prop.multiProcessorCount;
|
||||
}
|
||||
if(occupancy == 0)
|
||||
occupancy = 1; // conservative default; caller may use hipOccupancy API
|
||||
|
||||
const index_t grid = num_cu * occupancy;
|
||||
kernel_args.tile_partitioner =
|
||||
TilePartitioner(kernel_args.GemmM, kernel_args.GemmN, kernel_args.GemmK, grid);
|
||||
kernel_args.k_batch = 1; // StreamK does its own K distribution
|
||||
}
|
||||
else
|
||||
{
|
||||
using KernelImpl = GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType_,
|
||||
TilePartitioner_,
|
||||
GemmPipeline_,
|
||||
EpiloguePipeline_>;
|
||||
|
||||
// Negative k_batch value: split-K autodeduction.
|
||||
if(kernel_args.k_batch < 0)
|
||||
{
|
||||
const auto optimal_split_k =
|
||||
calculate_optimal_k_batch<GemmPipeline_::BlockSize,
|
||||
KernelImpl,
|
||||
TilePartitioner_>(kernel_args);
|
||||
kernel_args.k_batch = optimal_split_k;
|
||||
}
|
||||
}
|
||||
|
||||
return kernel_args;
|
||||
@@ -539,6 +633,21 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
return false;
|
||||
}
|
||||
}
|
||||
// Runtime arch check — complements the static_assert in operator().
|
||||
// Both are needed: this check runs on the host (where get_compiler_target()
|
||||
// isn't available since HIP's host pass doesn't define __gfx*__ macros),
|
||||
// while the static_assert in operator() catches misuse at device compile time.
|
||||
if constexpr(IsStreamK)
|
||||
{
|
||||
const auto name = get_device_name();
|
||||
if(name != "gfx90a" && name != "gfx942" && name != "gfx950")
|
||||
{
|
||||
LogInfo("StreamK requires cross-CU buffer coherence. "
|
||||
"Supported: gfx90a, gfx942, gfx950. Got: ",
|
||||
name);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if(kargs.k_batch < 1)
|
||||
{
|
||||
LogInfo("k_batch must be at least one. Ensure argument is created via MakeKernelArgs.");
|
||||
@@ -574,17 +683,20 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
}
|
||||
}
|
||||
|
||||
if(integer_divide_ceil(kargs.GemmK,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
|
||||
kargs.k_batch)
|
||||
if constexpr(!IsStreamK)
|
||||
{
|
||||
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
|
||||
kargs.GemmK,
|
||||
", BlockGemmShape K: ",
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
|
||||
", k_batch: ",
|
||||
kargs.k_batch);
|
||||
return false;
|
||||
if(integer_divide_ceil(kargs.GemmK,
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})) <
|
||||
kargs.k_batch)
|
||||
{
|
||||
LogInfo("KBatch is too large, part of GPU wouldn't be utilized! GemmK: ",
|
||||
kargs.GemmK,
|
||||
", BlockGemmShape K: ",
|
||||
TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}),
|
||||
", k_batch: ",
|
||||
kargs.k_batch);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
@@ -929,9 +1041,239 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void RunStreamK(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
// Device-side compile-time arch check — complements the runtime check in
|
||||
// IsSupportedArgument(). Both are needed: the runtime check runs on the host
|
||||
// (where get_compiler_target() isn't available since HIP's host pass doesn't
|
||||
// define __gfx*__ macros), while this catches misuse at device compile time.
|
||||
static_assert(
|
||||
StreamKCoherency<decltype(core::arch::get_compiler_target())>::BUFFER_COHERENCE !=
|
||||
amd_buffer_coherence_enum::coherence_default,
|
||||
"StreamK requires cross-CU buffer coherence (StreamKCoherency specialization). "
|
||||
"Currently supported: gfx90a, gfx942, gfx950.");
|
||||
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// Group offset (blockIdx.y = group batch index)
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
|
||||
const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
|
||||
const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
|
||||
WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
|
||||
|
||||
// Offset workspace per group so groups don't interfere.
|
||||
// Safe to mutate kargs: on GPU each workgroup operates on its own
|
||||
// register-local copy of the kernel arguments.
|
||||
const auto per_group_ws_size =
|
||||
kargs.tile_partitioner.get_workspace_size(sizeof(AccDataType));
|
||||
kargs.workspace_ptr =
|
||||
static_cast<char*>(kargs.workspace_ptr) + blockIdY * per_group_ws_size;
|
||||
|
||||
const index_t dp_num_loop = kargs.tile_partitioner.get_iters_per_tile();
|
||||
|
||||
StreamKDispatch(
|
||||
kargs.tile_partitioner,
|
||||
[&](index_t tile_idx) {
|
||||
// Data-parallel workgroup: process one full tile
|
||||
const auto tile_mn = kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
const index_t i_m =
|
||||
amd_wave_read_first_lane(tile_mn[I0] * TilePartitioner::MPerBlock);
|
||||
const index_t i_n =
|
||||
amd_wave_read_first_lane(tile_mn[I1] * TilePartitioner::NPerBlock);
|
||||
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr,
|
||||
kargs,
|
||||
dp_num_loop,
|
||||
i_m,
|
||||
i_n,
|
||||
/*block_idx_k=*/0);
|
||||
},
|
||||
[&](index_t sk_cta_idx) {
|
||||
RunStreamKLoop(kargs, sk_cta_idx, a_ptr, b_ptr, c_ptr, smem_ptr);
|
||||
});
|
||||
}
|
||||
|
||||
/// @brief Stream-K loop: iterate over assigned K-iterations, run GEMM pipeline,
|
||||
/// and perform Linear or Tree reduction to accumulate partial results.
|
||||
CK_TILE_DEVICE void RunStreamKLoop(GroupedConvBwdWeightKernelArgsSpecialized& kargs,
|
||||
index_t sk_cta_idx,
|
||||
const OutDataType* a_ptr,
|
||||
const InDataType* b_ptr,
|
||||
WeiDataType* c_ptr,
|
||||
char* smem_ptr) const
|
||||
{
|
||||
const StreamKOps sk_ops{};
|
||||
|
||||
index_t iter_start, iter_end;
|
||||
kargs.tile_partitioner.get_iter_boundaries(iter_start, iter_end, sk_cta_idx);
|
||||
|
||||
while(iter_start < iter_end)
|
||||
{
|
||||
index_t tile_idx =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_tile_index(iter_start));
|
||||
|
||||
index_t tile_iter_start, tile_iter_end;
|
||||
kargs.tile_partitioner.get_tile_boundaries(tile_iter_start, tile_iter_end, tile_idx);
|
||||
|
||||
index_t local_iter_start = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_local_iter(iter_start, tile_iter_start));
|
||||
index_t local_iter_end =
|
||||
amd_wave_read_first_lane(kargs.tile_partitioner.get_local_iter_end(
|
||||
tile_iter_start, iter_end, tile_iter_end));
|
||||
|
||||
index_t num_loop_sk = local_iter_end - local_iter_start;
|
||||
|
||||
// Compute M/N tile indices from 1D tile index
|
||||
const auto c_macro_tile_idx = kargs.tile_partitioner.get_output_tile_index(tile_idx);
|
||||
const index_t i_m =
|
||||
amd_wave_read_first_lane(c_macro_tile_idx[I0] * TilePartitioner::MPerBlock);
|
||||
const index_t i_n =
|
||||
amd_wave_read_first_lane(c_macro_tile_idx[I1] * TilePartitioner::NPerBlock);
|
||||
|
||||
// K offset = local_iter_start * KPerBlock
|
||||
const index_t i_k =
|
||||
amd_wave_read_first_lane(local_iter_start * TilePartitioner::KPerBlock);
|
||||
|
||||
// Create block windows and run pipeline
|
||||
const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, i_m, i_k);
|
||||
const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, i_n, i_k);
|
||||
const auto& d_block_window = MakeDBlockWindows(kargs.ds_ptr, kargs, i_m, i_n);
|
||||
|
||||
const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk);
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
a_block_window, b_block_window, num_loop_sk, has_hot_loop, tail_num, smem_ptr);
|
||||
|
||||
auto tile_started = iter_start == tile_iter_start;
|
||||
auto tile_ended = iter_end >= tile_iter_end;
|
||||
|
||||
if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear)
|
||||
{
|
||||
// Linear Reduction: tile-starter sequentially accumulates all
|
||||
// partials from subsequent CTAs in order.
|
||||
if(!tile_started)
|
||||
{
|
||||
sk_ops.StorePartial(kargs, sk_cta_idx, c_block_tile);
|
||||
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto accum_block_tile = c_block_tile;
|
||||
if(!tile_ended)
|
||||
{
|
||||
const index_t iter_per_tile = kargs.tile_partitioner.get_iters_per_tile();
|
||||
const index_t iter_per_cta = kargs.tile_partitioner.get_iters_per_sk_cta();
|
||||
const index_t extra_iters = kargs.tile_partitioner.get_extra_iters();
|
||||
int accum_iters = local_iter_end - local_iter_start;
|
||||
int next_cta = sk_cta_idx + 1;
|
||||
|
||||
while(accum_iters < iter_per_tile)
|
||||
{
|
||||
sk_ops.WaitStorePartialDone(kargs, next_cta);
|
||||
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
sk_ops.AddBlockTile(
|
||||
accum_block_tile,
|
||||
sk_ops.template LoadPartial<typename BlockType::DataType>(
|
||||
kargs, next_cta, c_block_tile.get_tile_distribution()));
|
||||
|
||||
accum_iters += iter_per_cta + (next_cta < extra_iters);
|
||||
++next_cta;
|
||||
}
|
||||
}
|
||||
|
||||
auto c_block_window_out =
|
||||
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
|
||||
}
|
||||
}
|
||||
else if constexpr(TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree)
|
||||
{
|
||||
// Tree Reduction: pairwise reduction with stride doubling.
|
||||
// At each round, half the CTAs store their accumulated partial
|
||||
// and exit; the other half load and accumulate from their partner.
|
||||
// The tile-starter writes the final result.
|
||||
auto accum_block_tile = c_block_tile;
|
||||
index_t tile_local_cta_idx = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_tile_local_cta_index(tile_iter_start, sk_cta_idx));
|
||||
|
||||
index_t stride = amd_wave_read_first_lane(1);
|
||||
|
||||
for(;; stride <<= 1)
|
||||
{
|
||||
// Partner index is a *global* SK CTA index. This works because
|
||||
// CTAs contributing to the same tile always have contiguous global
|
||||
// SK CTA indices (guaranteed by the partitioner's iteration assignment).
|
||||
const index_t partner_cta_idx = amd_wave_read_first_lane(sk_cta_idx + stride);
|
||||
const index_t partner_start_iter = amd_wave_read_first_lane(
|
||||
kargs.tile_partitioner.get_start_iter(partner_cta_idx));
|
||||
bool partner_in_tile =
|
||||
amd_wave_read_first_lane(partner_start_iter < tile_iter_end);
|
||||
|
||||
// If the partner of the tile-starter is not in this tile,
|
||||
// then all partials are accumulated — write final result.
|
||||
if(tile_started && !partner_in_tile)
|
||||
{
|
||||
auto c_block_window_out =
|
||||
MakeCBlockWindow<memory_operation_enum::set>(c_ptr, kargs, i_m, i_n);
|
||||
EpiloguePipeline{}(
|
||||
c_block_window_out, accum_block_tile, d_block_window, smem_ptr);
|
||||
break;
|
||||
}
|
||||
|
||||
// This CTA's turn to read from its partner and accumulate.
|
||||
if(tile_local_cta_idx % (stride << 1) == 0)
|
||||
{
|
||||
if(partner_in_tile)
|
||||
{
|
||||
sk_ops.WaitStorePartialDone(kargs, partner_cta_idx);
|
||||
using BlockType = remove_cvref_t<decltype(c_block_tile)>;
|
||||
sk_ops.AddBlockTile(
|
||||
accum_block_tile,
|
||||
sk_ops.template LoadPartial<typename BlockType::DataType>(
|
||||
kargs, partner_cta_idx, c_block_tile.get_tile_distribution()));
|
||||
}
|
||||
}
|
||||
// This CTA's turn to write its partial and exit.
|
||||
else
|
||||
{
|
||||
sk_ops.StorePartial(kargs, sk_cta_idx, accum_block_tile);
|
||||
sk_ops.SignalStorePartialDone(kargs, sk_cta_idx);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Linear ||
|
||||
TilePartitioner::ReductionStrategy == StreamKReductionStrategy::Tree,
|
||||
"Unsupported StreamK reduction strategy for conv bwd weight.");
|
||||
}
|
||||
|
||||
// Advance to next tile
|
||||
iter_start = tile_iter_end;
|
||||
block_sync_lds();
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const
|
||||
{
|
||||
if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
if constexpr(IsStreamK)
|
||||
{
|
||||
RunStreamK(kargs);
|
||||
}
|
||||
else if constexpr(GroupedConvTraitsType_::ExplicitGemm)
|
||||
{
|
||||
CallExplicitGemm(kargs);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user