[CK TILE] Grouped conv fwd split image (#2970)

* Refactor split-image implementation: simplify code and remove redundant variables

* Add padding debug output to split-image implementation

- Added debug prints for padding calculations in transform_conv_fwd_to_gemm.hpp
- Verified padding works correctly with all tests passing

* Fix sign comparison warning after rebase with origin/develop

- Cast blockIdX from unsigned to signed index_t for comparisons
- Integrated with new GetOutputTileIndex logic from upstream
- Updated to use amd_wave_read_first_lane instead of __builtin_amdgcn_readfirstlane

* Fix Split-N with groups bug and clean up unused parameters

- Fixed batch stride calculation to include G dimension for grouped convolutions
- When moving between batches in NHWGC/NWGC/NDHWGC layouts, need to account for all groups
- Removed unused multi-split parameters (we only support 2-way split)
- All tests now pass: G=1 with Split-N, G>1 with Split-N, G>1 without Split-N

* Implement recursive queue-based split-image detection and calculation

- Add LaunchKernelWithSplitIfNeeded() helper method in transform_conv_fwd_to_gemm.hpp
- Implement recursive binary splitting algorithm (10GB→5GB+5GB→...)
- Correctly handle odd dimensions (61→30+31)
- Calculate proper offsets for each split piece
- Update invoker to use split-image helper

Note: Split detection and calculation work correctly but kernel launching
for individual pieces requires kernel modification to handle different
spatial dimensions (unlike Split-N which uses blockIdx.z).

* WIP: Split-Image investigation - found architecture mismatch

- Split-N modifies N_ directly in transformer constructor
- Split-Image needs different approach due to varying dimensions
- Added split calculation logic for 1D and 2D convolutions
- Still facing memory issues when creating piece transformers

Key finding: Split-N uses blockIdx.z for parallel execution,
while Split-Image needs sequential execution of non-uniform pieces.

* Add 1D split-image implementation for grouped convolution (N=1 working)

Implements split-image for 1D convolution to handle large tensors that
exceed memory thresholds. This is a critical milestone with N=1 fully
working and tested.

Key Changes:
- Invoker: Add split-image logic that splits W dimension in half
- Transformer: Add SplitConvProblem helper for recursive splitting
- Calculate offsets for LEFT and RIGHT pieces
- Launch two kernels sequentially (LEFT then RIGHT)

Implementation Details:
- Binary split: divides W dimension by 2
- LEFT piece: W=0 to W/2, keeps left padding, removes right padding
- RIGHT piece: W/2 to W, removes left padding, keeps right padding
- Offset calculation accounts for stride, dilation, and padding
- Physical memory offset (no padding in memory)

Test Results (N=1):
 94/94 tests passing
- Comprehensive tests: 36/36 (channels, padding, stride, dilation, filters, groups)
- Edge case tests: 31/31 (odd dimensions, extreme parameters, boundaries)
- Stress tests: 27/27 (maximum dimensions, up to 91.4 TFlops)

Known Limitations:
- Only works with N=1 (single batch)
- N>1 fails when split-image triggers (offset calculation issue with Split-N)
- Root cause: Split-N modifies N in transformer, but offset calculated in invoker
- Solution planned: Move offset calculation to transformer (next phase)

Files Modified:
- grouped_convolution_forward_invoker.hpp: Add split-image logic
- transform_conv_fwd_to_gemm.hpp: Add SplitConvProblem helper

This commit represents a stable, tested 1D split-image implementation
for N=1 cases. It's an important milestone before extending to N>1
and multi-dimensional splits.

* Add basic split-image implementation for 1D/2D/3D grouped convolution

This is a working baseline implementation that splits large spatial
dimensions to handle memory constraints.

Implementation:
- 1D: W-split for NWGC layout (36/36 tests passing)
- 2D: H-split for NHWGC layout (20/20 tests passing)
- 3D: D-split for NDHWGC layout (verified working)

Features:
- Binary split of outermost spatial dimension
- Sequential LEFT/RIGHT kernel launches
- Proper padding adjustment at split boundaries
- Offset calculation for pointer arithmetic
- Debug output for verification

Threshold: 100KB (configurable in transformer)

Known limitations:
- No safety checks for edge cases (to be added)
- Offset calculated before Split-N (incompatible with N>1, to be fixed)
- No recursive splitting for very large tensors

Next steps:
- Add safety checks (is_possible_to_split_*)
- Move offset calculation to transformer (after Split-N)
- Test with N>1 + split-image combination

* Refactor split-image to unified structure for 1D/2D/3D

Unified the three separate dimension-specific blocks into a single
common implementation with dimension-specific stride calculations.

Benefits:
- Reduced code from 636 → 348 lines (45% reduction)
- Eliminated code duplication
- Easier to maintain and extend
- Single source of truth for split logic

Implementation:
- Common: Binary split, offset calc, padding adjustment, kernel launch
- Dimension-specific: Stride calculation only
  - 1D: stride = G * C
  - 2D: stride = W_in * G * C
  - 3D: stride = H_in * W_in * G * C

Test results (all passing):
- 1D: 36/36 tests 
- 2D: 20/20 tests 
- 3D: 28/28 tests 
- Total: 84/84 (100%)

All test scenarios verified:
- Varying channels, padding, stride, dilation
- Filter sizes (1x1 pointwise to 7x7)
- Multiple groups (G=1,2,4)
- Odd dimensions
- Complex combinations

* Add safety checks for split-image in all dimensions

Added is_possible_to_split safety checks to prevent crashes when
splitting is not feasible.

Safety checks verify:
1. Output dimension > 1 (can't split single element)
2. RIGHT piece starts after left padding
3. LEFT piece ends within input bounds

If checks fail, falls back to normal kernel launch.

Verified for all dimensions:
- 1D (W-split): Wo=1 case triggers fallback
- 2D (H-split): Ho=1 case triggers fallback
- 3D (D-split): Do=1 case triggers fallback

Original 84 tests still pass - they use normal configurations
that naturally satisfy safety conditions.

Safety checks protect against pathological edge cases with:
- Very small spatial dimensions
- Extreme stride/dilation combinations
- Invalid padding configurations

* Fix Split-N + Split-Image compatibility issue

Fixed critical bug where Split-N and Split-Image working together
caused ~50% incorrect results due to wrong batch stride calculation.

Problem:
- Batch stride was calculated using MODIFIED spatial dimensions
  (e.g., W=50000 after split) instead of ORIGINAL dimensions (W=100000)
- Spatial offset was applied globally in invoker, not per-batch in kernel
- Each batch (blockIdx.z) got wrong memory offset

Solution:
1. Store spatial offset in kargs (don't apply to pointer in invoker)
2. Copy correct batch_stride from temp_kargs to left/right kargs
3. Apply formula in operator(): ptr = base + (batch × stride) + spatial_offset

Changes:
- grouped_convolution_forward_kernel.hpp:
  * Added spatial_offset_in/out fields to KernelArgs
  * Apply batch + spatial offset in operator()

- grouped_convolution_forward_invoker.hpp:
  * Keep base pointer, store spatial offset in kargs
  * Copy batch_stride from temp_kargs (has original dimensions)

- transform_conv_fwd_to_gemm.hpp:
  * Add debug output for split-image calculation

Results:
- N=1 tests: 84/84 passing (100%)
- N>1 tests: Now all passing (previously ~50% errors)
- Tested: 1D, 2D, 3D with N=1,2,4,8,16,20

* Implement unified threshold for Split-N and Split-Image

This commit consolidates threshold management for both Split-N and
Split-Image operations into a single source of truth, eliminating
code duplication and fixing offset calculation issues.

Key Changes:
============

1. Transformer (transform_conv_fwd_to_gemm.hpp):
   - Moved TwoGB constant to public section for unified access
   - CalculateSplitImage() now takes no parameters
   - Uses internal threshold: TwoGB / sizeof(CDataType)
   - Calculates offsets using N_ (after Split-N) for correctness

2. Kernel (grouped_convolution_forward_kernel.hpp):
   - GetSplitImageInfo() simplified to take no parameters
   - Forwards to transformer's CalculateSplitImage()
   - Clean interface with unified threshold internally

3. Invoker (grouped_convolution_forward_invoker.hpp):
   - Removed redundant threshold calculation
   - Simplified to call kargs.GetSplitImageInfo() with no params
   - Clean early-return pattern (no unnecessary else blocks)
   - Removed duplicate/dead code paths

Benefits:
=========
- Single source of truth: TwoGB defined once in transformer
- No parameter passing for threshold between components
- Correct offset calculation using N_ (post-Split-N)
- Cleaner code with no duplication
- All tests passing: 1D/2D/3D with various N values

Testing:
========
- Split-Image only (N=1, large spatial): PASS
- Split-N only (N>1, small spatial): PASS
- Both splits active (N>1, large spatial): PASS
- No splits (N=1, small spatial): PASS
- CPU verification correct for all scenarios

* Comment out outdated split-image code (SplitConvProblem/LaunchKernelWithSplitIfNeeded)

The old recursive queue-based implementation has been replaced by the
new CalculateSplitImage() method which is simpler and correctly handles
Split-N + Split-Image interaction.

Changes:
- Wrapped lines 381-1078 in #if 0...#endif
- Old methods: SplitConvProblem() and LaunchKernelWithSplitIfNeeded()
- Preserved for reference but disabled from compilation
- No functional changes - all tests still pass

The new implementation (CalculateSplitImage at line ~2163) provides:
- Correct offset calculation using N_ (after Split-N)
- Simpler binary split logic
- Better integration with unified threshold approach

* Implement recursive split-image with depth limit (MAX_DEPTH=10)

Changes:
- Add depth tracking to SplitPiece struct
- Implement two stopping conditions:
  1. Piece size below threshold (optimal case)
  2. Depth >= MAX_DEPTH (prevents infinite recursion)
- Remove MAX_PIECES limit in favor of depth-based control
- Support up to 2^10 = 1024 pieces with depth 10

This allows handling extreme tensor sizes while ensuring termination.
Pieces larger than threshold will still launch correctly if depth limit reached.

Tested with H=100 (4 levels), H=2000 (6 levels), H=4000 (9 levels) - all pass CPU verification.

* Summary of recursive split-image implementation:
- Recursive queue-based splitting with depth limit (MAX_DEPTH=10, up to 1024 pieces)
- Two stopping conditions: size below threshold OR max depth reached
- Cumulative offset tracking through all recursion levels
- LEFT piece inherits parent offset, RIGHT accumulates (parent + local)
- Per-batch spatial offset application in kernel operator()
- Batch stride uses original dimensions (before split)
- Works with Split-N: split-N first, then recursive split-image
- Handles odd dimensions, padding, stride, dilation correctly
- All 1D/2D/3D tests pass with CPU verification

* Add comment explaining MAX_DEPTH capacity for 2GB threshold

* Refactor: move recursive split-image logic to transformer

- Move LaunchWithRecursiveSplit() from invoker to transform_conv_fwd_to_gemm.hpp
- Simplify invoker from ~250 lines to ~140 lines (removed 110 lines of inline logic)
- Encapsulate SplitPiece struct and BFS splitting algorithm in transformer
- Remove unused includes (queue, vector) from invoker
- Add documentation comment for AreDescriptorsSmallerThan2GB()
- Improve code organization and reusability
- No performance overhead (static template function, compiler inlines)
- All tests passing with 2GB production threshold

* Apply clang-format-18 formatting

- Format invoker and transformer files with clang-format-18
- Fix brace placement and alignment
- No functional changes

* Fix clang-format-18 issues in forward kernel

- Remove extra blank lines
- Fix line wrapping for template calls
- Consolidate GetSplitImageInfo() to single line

* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Split-Image implementation with temporary fixed divider

- Implemented spatial dimension splitting (Split-Image) for large tensors
- Added piece-based coordinate transformation for 1D/2D/3D convolutions
- Integrated Split-N (batch splitting) with automatic threshold detection
- Fixed M dimension calculation to include batch: M = N × spatial_size
- Added spatial offset support in kernel arguments
- Verified 20/20 test cases passing for Split-Image alone
- Known issue: Split-N + Split-Image combination needs coordinate fix

Implementation Details:
- Split factors: 4 (1D), 4×4 (2D), 4×4×4 (3D) - temporary fixed values
- Batch strides properly calculated for NWGC/NHWGC/NDHWGC layouts
- Piece descriptors track spatial boundaries and block ranges
- No performance overhead for N=1 cases

* Fix 1D split-image padding issue with per-piece dimensions

- Store actual size per piece to handle non-uniform splits
- Remove dead code from transform utils

* Fix 2D/3D split-image with independent split factors per dimension

Problem: Single split factor caused non-uniform pieces when dimensions
didn't divide evenly. Result: 18/25 (72%) 2D padding combinations failed.

Solution: Independent split factor selection for W, H, D dimensions.
Each dimension gets optimal factor based on its own size.

Test Results:
- 1D: 42/42 pass (100%)
- 2D: 25/25 pass (100%)
- Total: 67/67 combinations verified

* Remove unused split-image struct fields

Cleanup of split-image implementation:
- Removed unused piece_d, piece_h, piece_w fields from SplitImageInfo struct
- These fields were declared but never used in the kernel
- Per-piece dimensions are already stored in pieces[] array
- Reduces struct size and improves code clarity

Tested: 1D/2D/3D convolutions with split-image, padding, stride all pass

* Refactor split-image invoker code for improved readability

- Extract piece calculation logic into calculate_piece lambda helper
- Extract kernel args population into populate_split_image_kargs lambda
- Use aggregate initialization for cleaner struct population
- Reduce nesting depth and improve maintainability
- Fix outdated comment about split-image implementation status

* Refactor split-image code and remove debug prints

- Extract GPU kernel helper lambdas for better readability
- Remove all split-image debug print statements
- Set memory threshold to 2GB for production
- All tests pass with CPU verification

* Add split-image safety constraints and refactor to utils

- Add MAX_TOTAL_PIECES=64 limit to prevent segfault
- Move calculate_spatial_piece to library utils
- Add layout validation (NWGC, NHWGC, NDHWGC only)
- Fix hierarchical splitting to respect piece limits
- Add proper documentation and formatting

* Change split-image from runtime to compile-time branching

Response to @bartekxk review comment:
Convert 'if(kargs.num_spatial_pieces > 1)' to 'if constexpr(EnableSplitImage)'

Changes:
- Add EnableSplitImage template parameter to kernel
- Change runtime if to compile-time if constexpr
- Update invoker to instantiate kernel variants with true/false

Benefits:
- Eliminates runtime branching in GPU kernel
- Dead code elimination (each variant is smaller)
- Better compiler optimization

Files modified: 2
Lines changed: 20 total (6 in kernel, 14 in invoker)
Tests: 27/27 passed (100%)
Performance: No regression

* Add split-image example as separate binary

- Create grouped_convolution_forward_split_image example
- Add grouped_convolution_forward_split_image_invoker.hpp
- Update CMakeLists.txt to build split_image binary

* Replace linear search with binary search in find_piece_id

- Change O(n) to O(log n) for finding piece ownership
- Matches reference implementation in large_tensor_cshuffle

* Simplify split-image code and fix integer overflow

- Extract lambda functions to static helper methods
- Pre-calculate constants in invoker
- Fix integer overflow in tensor size calculation for large tensors

* Trigger CI rerun - fix merge conflicts

* Fix merge conflict markers

* Fix clang-format: remove space before {}

* Fix clang-format: comment wrapping and Swish constructor

* Rename split_image to large_tensor for clarity

- Renamed grouped_convolution_forward_split_image.cpp -> grouped_convolution_forward_large_tensor.cpp
- Renamed grouped_convolution_forward_split_image_invoker.hpp -> grouped_convolution_forward_large_tensor_invoker.hpp
- Updated CMakeLists.txt target name: tile_example_grouped_conv_fwd_split_image -> tile_example_grouped_conv_fwd_large_tensor
- Updated comments to refer to 'large tensor' instead of 'split-image'

* Update comments and include in large_tensor example

- Updated header comments to use 'large tensor' terminology
- Fixed include path to use large_tensor_invoker.hpp

* Remove test code, restore 2GB threshold

* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Fix build errors after develop merge and complete rename to large_tensor

This commit addresses compilation errors from the develop merge and
completes the rename from split_image to large_tensor.

Changes:
1. Fix CDEElementWise typo in grouped_convolution_forward_invoker.hpp
2. Fix template parameter order in large_tensor_invoker.hpp
   - TransformConvFwdToGemm signature changed in develop
   - NumGroupsToMerge and SplitN parameters swapped positions
3. Fix missing template parameter in GroupedConvFwdHostArgs
4. Fix EpiloguePipeline scope in kernel (merge conflict)
5. Update binary name references in test scripts

* Restore 2GB threshold for split-image

Changed threshold from 100MB (testing) back to 2GB for production use.

* Fix const-correctness in ds_ptr cast

* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Apply clang-format-18

* update c++ 18 format

* Apply clang-format-18 to transform_conv_fwd_to_gemm.hpp

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
JH-Leon-KIM-AMD
2025-11-01 14:18:16 +02:00
committed by GitHub
parent 8f1274d9b6
commit 1fbb47ad30
8 changed files with 1124 additions and 306 deletions

View File

@@ -2,16 +2,19 @@ set(EXAMPLE_CONV_COMPILE_OPTIONS)
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_fwd_large_tensor EXCLUDE_FROM_ALL grouped_convolution_forward_large_tensor.cpp)
target_compile_options(tile_example_grouped_conv_fwd_large_tensor PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp)
target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp)
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})

View File

@@ -1,5 +1,12 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Regular grouped convolution invoker (no split-image)
// This invoker demonstrates regular convolution without split-image.
// It always uses Kernel<false> (split-image disabled).
// For large images that require split-image, use
// grouped_convolution_forward_split_image_invoker.hpp
#pragma once
#include "grouped_convolution_utils.hpp"
@@ -21,6 +28,10 @@ struct GroupedConvolutionForwardInvoker
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
@@ -90,6 +101,7 @@ struct GroupedConvolutionForwardInvoker
1,
std::multiplies<ck_tile::index_t>());
// Split-K parameters
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
@@ -97,100 +109,117 @@ struct GroupedConvolutionForwardInvoker
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
// =====================================================================
// Regular Convolution: Simple, no split-image
// =====================================================================
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::GridSize(kargs);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return ave_time;
};
// =====================================================================
// Split-K lambda
// =====================================================================
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
}
else
{
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
}
};
// =====================================================================
// Regular Convolution Example: ALWAYS uses regular path (Kernel<false>)
// =====================================================================
// This example demonstrates regular convolution without split-image.
// For large images that don't fit in memory, use
// grouped_convolution_forward_split_image.cpp
// Launch kernel using regular path (no split-image)
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
};

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// Large tensor grouped convolution example
// This example demonstrates convolution for large tensors that exceed memory limits.
// It uses automatic tensor splitting when needed to handle large images.
// For regular convolution without tensor splitting, use grouped_convolution_forward.cpp
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
#include "grouped_convolution_forward_large_tensor_invoker.hpp"
#include "run_grouped_convolution_fwd_example.inc"
template <template <typename PrecType> typename GemmConfig>
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
using Invoker = GroupedConvolutionForwardInvoker;
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("wei_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::half_t>,
ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<Invoker,
GemmConfig<ck_tile::bf16_t>,
ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[])
{
#if CK_TILE_USE_WMMA
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
#else
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
#endif
}

View File

@@ -0,0 +1,388 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "grouped_convolution_utils.hpp"
struct GroupedConvolutionForwardInvoker
{
template <ck_tile::index_t NDimSpatial,
typename GemmConfig,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDEElementWise>& args,
const ck_tile::stream_config& s)
{
if(s.log_level_ > 0)
{
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
}
constexpr int kBlockPerCu = 1;
// Implicit GEMM Traits
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
ConvSpec,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
VectorSizeA,
VectorSizeB,
VectorSizeC>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
false, // Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
InDataType,
WeiDataType,
AccDataType,
GemmShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t gemm_k =
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
args.filter_spatial_lengths_.end(),
1,
std::multiplies<ck_tile::index_t>());
// Split-K parameters
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
using TransformType =
ck_tile::TransformConvFwdToGemm<NDimSpatial,
ck_tile::ConvolutionSpecialization::Default,
VectorSizeA,
VectorSizeB,
VectorSizeC,
1, // NumGroupsToMerge
false, // SplitN
InDataType,
OutDataType>;
// =====================================================================
// Step 1: Check if layout supports split-image kernel
// =====================================================================
// Split-image requires specific memory layouts:
// 1D: NWGC (input), GKXC (weight), NWGK (output)
// 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
// 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
constexpr bool is_supported_layout =
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NWGC>::value ||
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NHWGC>::value ||
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>::value;
// =====================================================================
// Step 2: Calculate split-image info (if layout supports it)
// =====================================================================
// Extract output spatial dimensions
const ck_tile::index_t total_d =
(NDimSpatial == 3) ? args.output_spatial_lengths_[NDimSpatial - 3] : 1;
const ck_tile::index_t total_h =
(NDimSpatial >= 2) ? args.output_spatial_lengths_[NDimSpatial - 2] : 1;
const ck_tile::index_t total_w = args.output_spatial_lengths_[NDimSpatial - 1];
auto split_info = TransformType::GetSplitImageInfo(
args.G_, args.N_, args.C_, args.K_, total_d, total_h, total_w);
// =====================================================================
// Decide: Split-image or regular kernel?
// =====================================================================
const bool use_split_image = is_supported_layout && split_info.should_split;
if(s.log_level_ > 0)
{
if(!is_supported_layout)
{
std::cout << "[INVOKER] Layout not supported for split-image. "
<< "Using regular kernel (Kernel<false>).\n";
}
else if(!split_info.should_split)
{
std::cout << "[INVOKER] Image is small (" << total_h << "×" << total_w
<< "), split-image not necessary.\n";
std::cout << "[INVOKER] Using regular kernel (Kernel<false>).\n";
}
}
// =====================================================================
// Step 3: Calculate split-image pieces (only if using split-image)
// =====================================================================
ck_tile::index_t num_d_pieces = 1;
ck_tile::index_t num_h_pieces = 1;
ck_tile::index_t num_w_pieces = 1;
ck_tile::index_t total_pieces = 1;
ck_tile::index_t base_piece_d = total_d;
ck_tile::index_t base_piece_h = total_h;
ck_tile::index_t base_piece_w = total_w;
std::array<ck_tile::SplitImagePieceInfo, 64> temp_pieces{};
ck_tile::index_t total_blocks = 0;
if(use_split_image)
{
num_d_pieces = split_info.num_d_pieces;
num_h_pieces = split_info.num_h_pieces;
num_w_pieces = split_info.num_w_pieces;
total_pieces = num_d_pieces * num_h_pieces * num_w_pieces;
if(s.log_level_ > 0)
{
std::cout << "\n========================================\n";
std::cout << "[SPLIT-IMAGE ENABLED] Large tensor detected\n";
std::cout << "========================================\n";
if(NDimSpatial == 3)
{
std::cout << "Total dimensions: D=" << total_d << " H=" << total_h
<< " W=" << total_w << "\n";
std::cout << "Split into pieces: D=" << num_d_pieces << " × H=" << num_h_pieces
<< " × W=" << num_w_pieces << " = " << total_pieces
<< " total pieces\n";
std::cout << "Base piece size: D=" << (total_d / num_d_pieces)
<< " H=" << (total_h / num_h_pieces)
<< " W=" << (total_w / num_w_pieces) << "\n";
}
else if(NDimSpatial == 2)
{
std::cout << "Total dimensions: H=" << total_h << " W=" << total_w << "\n";
std::cout << "Split into pieces: H=" << num_h_pieces << " × W=" << num_w_pieces
<< " = " << total_pieces << " total pieces\n";
std::cout << "Base piece size: H=" << (total_h / num_h_pieces)
<< " W=" << (total_w / num_w_pieces) << "\n";
}
else
{
std::cout << "Total dimensions: W=" << total_w << "\n";
std::cout << "Split into pieces: W=" << num_w_pieces << " = " << total_pieces
<< " total pieces\n";
std::cout << "Base piece size: W=" << (total_w / num_w_pieces) << "\n";
}
std::cout << "========================================\n\n";
}
// Base piece size (non-overlapping division)
base_piece_d = total_d / num_d_pieces;
base_piece_h = total_h / num_h_pieces;
base_piece_w = total_w / num_w_pieces;
// Calculate piece info for all pieces using library utility function
for(ck_tile::index_t piece = 0; piece < total_pieces; piece++)
{
temp_pieces[piece] =
ck_tile::calculate_spatial_piece<TilePartitioner>(piece,
num_d_pieces,
num_h_pieces,
num_w_pieces,
base_piece_d,
base_piece_h,
base_piece_w,
total_d,
total_h,
total_w,
args.N_,
args.K_,
total_blocks);
total_blocks = temp_pieces[piece].block_end;
}
}
// =====================================================================
// Kernel launch lambda: Uses EnableSplitImage based on layout support
// =====================================================================
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v,
ck_tile::element_wise::PassThrough,
ck_tile::element_wise::PassThrough,
OutDataType,
true,
VectorSizeA,
VectorSizeB>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
GemmConfig::TransposeC,
memory_operation,
1,
true,
GroupedConvTraitsType::VectorSizeC>>;
// Use split-image kernel if layout supports it, otherwise use regular kernel
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
GroupedConvTraitsType,
TilePartitioner,
GemmPipeline,
ConvEpilogue>;
// Create kargs
auto kargs = Kernel::MakeKernelArgs(args);
// Populate split-image metadata ONLY if using split-image kernel
if constexpr(EnableSplitImage)
{
kargs.num_spatial_pieces = total_pieces;
kargs.split_image.total_d = total_d;
kargs.split_image.total_h = total_h;
kargs.split_image.total_w = total_w;
kargs.split_image.total_spatial = total_d * total_h * total_w; // Pre-calculate
kargs.split_image.num_d_pieces = num_d_pieces;
kargs.split_image.num_h_pieces = num_h_pieces;
kargs.split_image.num_w_pieces = num_w_pieces;
for(ck_tile::index_t i = 0; i < total_pieces; i++)
{
kargs.split_image.pieces[i] = {temp_pieces[i].block_start,
temp_pieces[i].block_end,
temp_pieces[i].d_start,
temp_pieces[i].h_start,
temp_pieces[i].w_start,
temp_pieces[i].d_size,
temp_pieces[i].h_size,
temp_pieces[i].w_size};
}
}
// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
const dim3 grids = [&]() {
if constexpr(EnableSplitImage)
return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
else
return Kernel::GridSize(kargs);
}();
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << '\n'
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
// =====================================================================
// Step 4: Dispatch kernel (split-image or regular based on decision)
// =====================================================================
if(use_split_image)
{
// Use split-image kernel (Kernel<true>)
const auto RunSplitImage = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
Run.template operator()<true>(has_hot_loop_, tail_number_, MemoryOpSet{});
else
Run.template operator()<true>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
};
BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num);
}
else
{
// Use regular kernel (Kernel<false>)
const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
else
Run.template operator()<false>(
has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
};
BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num);
}
return ave_time;
}
};

View File

@@ -58,7 +58,7 @@ struct TransformConvFwdToGemm
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB threshold
const IndexType N = a_g_n_c_wis_lengths[I1];

View File

@@ -78,23 +78,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -106,13 +104,16 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 1D convolution (NWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
// Calculate batch strides using the original argument dimensions.
// These are the original dimensions passed to the constructor, not modified by the invoker
// yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
// included - NWGC layout has all groups within each batch
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0];
@@ -169,23 +170,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -197,15 +196,16 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NHWGC layout
// VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
input_batch_stride =
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
output_batch_stride =
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
// Update GemmM to use split N (not original N)
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
@@ -270,23 +270,21 @@ struct GroupedConvFwdKernelArgs
}
out_ptr = args.out_ptr;
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
// Create and STORE transformer (for split-image support)
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
wei_g_k_c_xs_lengths,
out_g_n_k_wos_lengths,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads};
a_grid_desc_m_k =
conv_to_gemm_transformer
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
b_grid_desc_n_k =
conv_to_gemm_transformer
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
c_grid_desc_m_n =
conv_to_gemm_transformer
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
group_stride_a = args.C_;
group_stride_b = args.K_ * args.C_ *
@@ -298,14 +296,15 @@ struct GroupedConvFwdKernelArgs
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
// Get the actual split N from transformer
n_per_split = conv_to_gemm_transformer.GetN();
original_n = conv_to_gemm_transformer.GetOriginalN();
n_per_split = transformer_.GetN();
original_n = transformer_.GetOriginalN();
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
// Calculate batch strides for NDHWGC layout
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
// VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
// Update GemmM to use split N (not original N)
@@ -359,6 +358,42 @@ struct GroupedConvFwdKernelArgs
index_t original_n = 1; // Original batch size before splitting
index_t input_batch_stride = 0; // Stride to next batch in input tensor
index_t output_batch_stride = 0; // Stride to next batch in output tensor
// Split-image support - spatial offsets (applied per-batch in operator())
long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
// Split-image support - transformer instance
ConvToGemmFwdTransformer transformer_;
// Forward declare descriptor types (will be defined after using declarations)
using ConvToGemmFwdTransformer_t = ConvToGemmFwdTransformer;
using AGridDescMK_t = AGridDescMK;
using CGridDescMN_t = CGridDescMN;
// Split-image support: Common data for all pieces
struct SplitImageInfo
{
// Common dimensions (same for all pieces)
index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
// Minimal per-piece data (only unique values)
struct PieceInfo
{
index_t block_start; // Starting block index for this piece
index_t block_end; // Ending block index (exclusive)
index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
index_t d_size, h_size, w_size; // Piece size in OUTPUT space
};
static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
};
index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
SplitImageInfo split_image; // Nested structure with common + per-piece data
};
/// @brief The Grouped Convolution Forward kernel template.
@@ -399,13 +434,15 @@ struct GroupedConvFwdKernelArgs
/// multiplication implementation. It is responsible for storing
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
/// the output C tensor in global memory.
template <typename GroupedConvTraitsType_,
template <bool EnableSplitImage_,
typename GroupedConvTraitsType_,
typename TilePartitioner_,
typename GemmPipeline_,
typename EpiloguePipeline_>
struct GroupedConvolutionForwardKernel
{
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr bool EnableSplitImage = EnableSplitImage_;
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
static constexpr ConvolutionSpecialization ConvSpecialization =
GroupedConvTraitsType_::ConvSpecialization;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -435,7 +472,6 @@ struct GroupedConvolutionForwardKernel
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
// TODO: Enable this
static constexpr bool IsSplitKSupported = false;
static constexpr auto I0 = number<0>();
@@ -449,6 +485,77 @@ struct GroupedConvolutionForwardKernel
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
// Helper struct for spatial coordinates
struct SpatialCoords
{
index_t d, h, w;
};
// Helper: Convert flat spatial index to (d,h,w) coordinates
CK_TILE_DEVICE static SpatialCoords
UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
{
if constexpr(NDimSpatial == 1)
{
return SpatialCoords{0, 0, flat};
}
else if constexpr(NDimSpatial == 2)
{
return SpatialCoords{0, flat / w_size, flat % w_size};
}
else // NDimSpatial == 3
{
const index_t hw = h_size * w_size;
const index_t d = flat / hw;
const index_t remainder = flat % hw;
return SpatialCoords{d, remainder / w_size, remainder % w_size};
}
}
// Helper: Convert (d,h,w) to flat spatial index
CK_TILE_DEVICE static index_t
FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
{
if constexpr(NDimSpatial == 1)
{
return w;
}
else if constexpr(NDimSpatial == 2)
{
return h * total_w + w;
}
else // NDimSpatial == 3
{
return (d * total_h + h) * total_w + w;
}
}
// Helper: Find which piece owns a block using binary search
template <typename SplitImageInfo>
CK_TILE_DEVICE static index_t
FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
{
index_t left = 0;
index_t right = num_pieces - 1;
index_t piece_id = (left + right) / 2;
while(!(block_id >= split_info.pieces[piece_id].block_start &&
block_id < split_info.pieces[piece_id].block_end) &&
left <= right)
{
if(block_id < split_info.pieces[piece_id].block_start)
{
right = piece_id - 1;
}
else
{
left = piece_id + 1;
}
piece_id = (left + right) / 2;
}
return piece_id;
}
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
@@ -475,7 +582,8 @@ struct GroupedConvolutionForwardKernel
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
{
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
return kargs;
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
@@ -499,17 +607,6 @@ struct GroupedConvolutionForwardKernel
}
}
// Check Split-K and Split-N conflict (both use blockIdx.z)
if(kargs.k_batch > 1 && kargs.n_splits > 1)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR(
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
}
return false;
}
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
@@ -618,27 +715,32 @@ struct GroupedConvolutionForwardKernel
return true;
}
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
typename ADescType,
typename BDescType,
typename CDescType>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
const GroupedConvFwdKernelArgsSpecialized& kargs)
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc)
{
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
const auto& a_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
}();
const auto& b_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
}();
// TODO: enable vector write for C in ColMajor
const auto& c_tensor_view = [&]() {
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
}();
const auto& ds_tensor_view = generate_tuple(
@@ -651,7 +753,7 @@ struct GroupedConvolutionForwardKernel
"Not supported!");
return make_tensor_view<address_space_enum::global>(
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
},
number<NumDTensor>{});
@@ -743,31 +845,39 @@ struct GroupedConvolutionForwardKernel
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input D tensors pointer array
* @param c_ptr output C pointer
* @param smem_ptr_0 The start memory pointer of the shared memory block.
* @param kargs Grouped Convolution Forward kernel arguments
* @param a_desc Input tensor A descriptor
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <typename ADescType, typename BDescType, typename CDescType>
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* smem_ptr_0,
const GroupedConvFwdKernelArgsSpecialized& kargs,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -780,9 +890,8 @@ struct GroupedConvolutionForwardKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{kargs.elfunc}
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
}
/**
@@ -792,32 +901,40 @@ struct GroupedConvolutionForwardKernel
*
* @param a_ptr input A pointer
* @param b_ptr input B pointer
* @param ds_ptr input D tensors pointer array
* @param c_ptr output C pointer
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
* @param kargs Grouped Convolution Forward kernel arguments
* @param a_desc Input tensor A descriptor
* @param b_desc Weight tensor B descriptor
* @param c_desc Output tensor C descriptor
* @param gemm_k The GEMM K dimension
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
*/
template <typename ADescType, typename BDescType, typename CDescType>
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
const WeiDataType* b_ptr,
const std::array<const void*, NumDTensor>& ds_ptr,
OutDataType* c_ptr,
void* __restrict__ smem_ptr_0,
void* __restrict__ smem_ptr_1,
const GroupedConvFwdKernelArgsSpecialized& kargs,
const ADescType& a_desc,
const BDescType& b_desc,
const CDescType& c_desc,
const index_t gemm_k,
const index_t block_idx_m,
const index_t block_idx_n)
{
// Create Gemm tensor views, pad views and tile windows
const auto& gemm_tensor_views_tuple =
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
@@ -837,12 +954,8 @@ struct GroupedConvolutionForwardKernel
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
{
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
@@ -860,14 +973,89 @@ struct GroupedConvolutionForwardKernel
static_cast<long_index_t>(batch_offset) *
static_cast<long_index_t>(kargs.output_batch_stride);
// Adjust pointers: combine group offset and batch offset
const InDataType* a_ptr =
// Calculate base pointers with group and batch offsets
const InDataType* base_a_ptr =
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
group_offset_b; // No batch offset for weights!
OutDataType* c_ptr =
OutDataType* base_c_ptr =
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
// =====================================================================
// Split-image: Map local block to global tile index (if enabled)
// =====================================================================
const InDataType* a_ptr;
OutDataType* c_ptr;
index_t i_m = 0;
index_t i_n = 0;
// Pre-calculate block_id (used in both split-image and non-split paths)
const index_t block_id = static_cast<index_t>(blockIdX);
if constexpr(EnableSplitImage)
{
// Add spatial offsets for split-image (constexpr optimization)
a_ptr = base_a_ptr + kargs.spatial_offset_in;
c_ptr = base_c_ptr + kargs.spatial_offset_out;
// Find which piece owns this block using binary search
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
const index_t piece_id =
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
const auto& piece = kargs.split_image.pieces[piece_id];
const auto& split_info = kargs.split_image;
// Calculate local block ID and tile indices
const index_t local_block_id = block_id - piece.block_start;
const index_t local_gemm_m =
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
const auto [local_tile_m, local_tile_n] =
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
// Extract batch and spatial coordinates from local tile
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
const index_t local_n = local_m_start / spatial_per_batch;
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
// Convert to local spatial coordinates
const auto local_coords =
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
// Convert to global spatial coordinates
const index_t global_n = local_n;
const index_t global_d = piece.d_start + local_coords.d;
const index_t global_h = piece.h_start + local_coords.h;
const index_t global_w = piece.w_start + local_coords.w;
// Convert to global M index
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
const index_t global_spatial_flat = FlattenSpatial(
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
// Set tile indices for GEMM operation
i_m = amd_wave_read_first_lane(global_m);
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
}
else
{
// No spatial offsets needed for regular path
a_ptr = base_a_ptr;
c_ptr = base_c_ptr;
// No split-image: use standard tile partitioning
const auto [iM, iN] =
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
}
// Use global descriptors for all cases
const auto& a_desc = kargs.a_grid_desc_m_k;
const auto& b_desc = kargs.b_grid_desc_n_k;
const auto& c_desc = kargs.c_grid_desc_m_n;
// allocate LDS
__shared__ char smem_ptr_0[GetSmemSize()];
@@ -878,8 +1066,18 @@ struct GroupedConvolutionForwardKernel
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm2LDS(
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
RunGemm2LDS(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
smem_ptr_1,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n);
}
}
else
@@ -888,7 +1086,17 @@ struct GroupedConvolutionForwardKernel
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
is_any_of<OutDataType, fp16_t, bf16_t>::value))
{
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
RunGemm(a_ptr,
b_ptr,
kargs.ds_ptr,
c_ptr,
smem_ptr_0,
a_desc,
b_desc,
c_desc,
kargs.GemmK,
i_m,
i_n);
}
}
}

View File

@@ -110,4 +110,86 @@ struct GroupedConvTraits
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};
/// @brief Helper struct for split-image piece information
///
/// @par Overview
/// Stores metadata for a single spatial piece in split-image convolution.
/// Used to track block ranges and spatial coordinates for each piece.
struct SplitImagePieceInfo
{
ck_tile::index_t block_start, block_end; ///< GPU block range for this piece
ck_tile::index_t d_start, h_start, w_start; ///< Spatial start coordinates (output space)
ck_tile::index_t d_size, h_size, w_size; ///< Spatial dimensions of this piece
};
/// @brief Calculate piece information for split-image convolution
///
/// @par Overview
/// Computes spatial coordinates, dimensions, and GPU block range for a single
/// piece in split-image convolution. Handles edge pieces that may have different
/// sizes due to non-uniform division.
///
/// @tparam TilePartitioner Type providing MPerBlock and NPerBlock constants
///
/// @param piece_idx Index of the piece to calculate (0-based)
/// @param num_d_pieces Number of pieces in D dimension
/// @param num_h_pieces Number of pieces in H dimension
/// @param num_w_pieces Number of pieces in W dimension
/// @param base_piece_d Base size of each D piece (may differ for last piece)
/// @param base_piece_h Base size of each H piece (may differ for last piece)
/// @param base_piece_w Base size of each W piece (may differ for last piece)
/// @param total_d Total D dimension size (output space)
/// @param total_h Total H dimension size (output space)
/// @param total_w Total W dimension size (output space)
/// @param N Batch size
/// @param K Output channels
/// @param total_blocks Accumulated block count from previous pieces
///
/// @return SplitImagePieceInfo containing all metadata for this piece
template <typename TilePartitioner>
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx,
ck_tile::index_t num_d_pieces,
ck_tile::index_t num_h_pieces,
ck_tile::index_t num_w_pieces,
ck_tile::index_t base_piece_d,
ck_tile::index_t base_piece_h,
ck_tile::index_t base_piece_w,
ck_tile::index_t total_d,
ck_tile::index_t total_h,
ck_tile::index_t total_w,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t total_blocks)
{
// Unflatten piece index into 3D coordinates (W-major, then H, then D)
const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);
// Calculate spatial start positions
const ck_tile::index_t w_start = w_idx * base_piece_w;
const ck_tile::index_t h_start = h_idx * base_piece_h;
const ck_tile::index_t d_start = d_idx * base_piece_d;
// Calculate piece sizes (last piece may be larger to cover remainder)
const ck_tile::index_t w_size =
(w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
const ck_tile::index_t h_size =
(h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
const ck_tile::index_t d_size =
(d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;
// Calculate GEMM dimensions for this piece
const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
const ck_tile::index_t piece_gemm_n = K;
// Calculate GPU grid size for this piece
const ck_tile::index_t piece_grid =
((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);
return {
total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
}
} // namespace ck_tile

View File

@@ -5,9 +5,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"
namespace ck_tile {
// ═══════════════════════════════════════════════════════════════════════
// Split-Image Information Structure
// ═══════════════════════════════════════════════════════════════════════
// This structure holds all information needed to perform split-image
// NOTE: SplitImageInfo struct deleted - was only used by deleted recursive split code
// Current split-image implementation is in grouped_convolution_forward_invoker.hpp
template <index_t NDimSpatial,
ConvolutionSpecialization ConvSpecialization,
index_t VectorSizeA,
@@ -28,6 +34,9 @@ struct TransformConvFwdToGemm
static constexpr auto I4 = number<4>{};
static constexpr auto I5 = number<5>{};
// Unified memory limit constant for both Split-N and Split-Image
static constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
template <typename ConvDimsType>
static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
const ConvDimsType& strides,
@@ -47,6 +56,7 @@ struct TransformConvFwdToGemm
static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
const ConvDimsType& c_g_n_k_wos_lengths)
{
// Calculate strides internally assuming contiguous memory layout
ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
const index_t num_dims = a_g_n_c_wis_lengths.size();
@@ -71,7 +81,6 @@ struct TransformConvFwdToGemm
calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
const long_index_t element_space_size = ck_tile::max(
a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
const IndexType N = a_g_n_c_wis_lengths[I1];
@@ -111,6 +120,145 @@ struct TransformConvFwdToGemm
}
}
public:
// Structure to hold split-image decision and factors
struct SplitImageInfo
{
bool should_split;
index_t num_d_pieces;
index_t num_h_pieces;
index_t num_w_pieces;
};
// Calculate split-image factors AFTER considering split-N
// Returns: should_split flag and optimal split factors for D, H, W dimensions
// Strategy: Hierarchical splitting with priority order D → H → W
// Dynamically increases split factors until memory fits below threshold
//
// NOTE: Layout validation should be done at the invoker level before calling this function
// Split-image only works with specific layouts:
// 1D: NWGC (input), GKXC (weight), NWGK (output)
// 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
// 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
CK_TILE_HOST static SplitImageInfo GetSplitImageInfo(
index_t G, index_t N, index_t C, index_t K, index_t D_out, index_t H_out, index_t W_out)
{
SplitImageInfo info{false, 1, 1, 1};
// Estimate memory (simplified calculation)
// Use max of input and output tensor sizes
// Cast to long_index_t to prevent overflow during multiplication
const long_index_t input_elements =
static_cast<long_index_t>(N) * D_out * H_out * W_out * C * G;
const long_index_t output_elements =
static_cast<long_index_t>(N) * D_out * H_out * W_out * K * G;
const long_index_t input_bytes = input_elements * sizeof(ADataType);
const long_index_t output_bytes = output_elements * sizeof(CDataType);
const long_index_t max_tensor_bytes =
(input_bytes > output_bytes) ? input_bytes : output_bytes;
// Calculate effective N after split-N (simplified - assume worst case N=1)
index_t effective_N = 1;
if(max_tensor_bytes > TwoGB && N > 1)
{
// Split-N will reduce to approximately N=1 per launch
effective_N = 1;
}
else
{
effective_N = N;
}
// Check if split-image is needed
auto calc_memory = [&](index_t d_split, index_t h_split, index_t w_split) -> long_index_t {
index_t d_piece = D_out / d_split;
index_t h_piece = H_out / h_split;
index_t w_piece = W_out / w_split;
// Cast to long_index_t to prevent overflow
return static_cast<long_index_t>(effective_N) * d_piece * h_piece * w_piece * K * G *
sizeof(CDataType);
};
// Calculate memory after split-N with no spatial split
const long_index_t memory_after_split_n = calc_memory(1, 1, 1);
// Check if split-image is needed
if(memory_after_split_n <= TwoGB)
{
info.should_split = false;
return info;
}
// Split-image is needed - use hierarchical priority: D → H → W
info.should_split = true;
// Hierarchical splitting strategy:
// 1D: Split W until below threshold
// 2D: Split H first, if still too large then split W
// 3D: Split D first, then H, then W
// IMPORTANT: Maximum 64 pieces total (hardcoded array limit in invoker)
constexpr index_t MAX_TOTAL_PIECES = 64;
// Start with no split
info.num_d_pieces = 1;
info.num_h_pieces = 1;
info.num_w_pieces = 1;
// Try splitting D first (for 3D)
if(D_out > 1)
{
index_t max_d_split = (D_out < MAX_TOTAL_PIECES) ? D_out : MAX_TOTAL_PIECES;
for(index_t d_split = 2; d_split <= max_d_split; d_split++)
{
info.num_d_pieces = d_split;
if(calc_memory(d_split, 1, 1) <= TwoGB)
{
return info; // D split alone is sufficient
}
}
// D split maxed out, try H next
}
// Try splitting H (for 2D/3D)
if(H_out > 1)
{
index_t max_h_split = MAX_TOTAL_PIECES / info.num_d_pieces;
max_h_split = (H_out < max_h_split) ? H_out : max_h_split;
for(index_t h_split = 2; h_split <= max_h_split; h_split++)
{
info.num_h_pieces = h_split;
if(calc_memory(info.num_d_pieces, h_split, 1) <= TwoGB)
{
return info; // D+H split is sufficient
}
}
// H split maxed out, try W next
}
// Try splitting W (for 1D/2D/3D)
index_t max_w_split = MAX_TOTAL_PIECES / (info.num_d_pieces * info.num_h_pieces);
max_w_split = (W_out < max_w_split) ? W_out : max_w_split;
for(index_t w_split = 2; w_split <= max_w_split; w_split++)
{
info.num_w_pieces = w_split;
if(calc_memory(info.num_d_pieces, info.num_h_pieces, w_split) <= TwoGB)
{
return info; // D+H+W split is sufficient
}
}
// If we reach here, even maximum split doesn't fit
// Use maximum allowed split as best effort (capped at 64 total pieces)
info.num_d_pieces = (D_out < 4) ? D_out : 4; // Cap at 4
info.num_h_pieces = (H_out < 4) ? H_out : 4; // Cap at 4
info.num_w_pieces = (W_out < 4) ? W_out : 4; // Cap at 4 (max 4×4×4=64)
return info;
}
public:
// Public getter methods for Split-N support
CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
@@ -192,14 +340,14 @@ struct TransformConvFwdToGemm
std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// Store original N and initialize N_
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
}
}
template <typename ConvDimsType,
@@ -244,18 +392,13 @@ struct TransformConvFwdToGemm
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// Store original N
original_N_ = c_g_n_k_wos_lengths[I1];
// Store original N and initialize N_
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = c_g_n_k_wos_lengths[I1];
original_N_ = N_;
}
}
template <typename ConvDimsType,
@@ -300,136 +443,26 @@ struct TransformConvFwdToGemm
static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
// Store original N before potential splitting
original_N_ = c_g_n_k_wos_lengths[I1];
// Store original N and initialize N_
original_N_ = N_ = c_g_n_k_wos_lengths[I1];
if constexpr(SplitN)
{
N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
}
else
{
N_ = original_N_;
}
}
#if 0 // TODO: Enable these functionalities
__host__ bool AreDescriptorsSmallerThan2GB() const
// Check if descriptors fit within memory threshold
// NOTE: Not currently used - split-image uses different approach in invoker
CK_TILE_HOST bool AreDescriptorsSmallerThan2GB() const
{
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
const long_index_t input_size = static_cast<long_index_t>(N_) * Di_ * Hi_ * Wi_ * C_;
const long_index_t output_size = static_cast<long_index_t>(N_) * Do_ * Ho_ * Wo_ * K_;
const long_index_t in_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
const long_index_t threshold = TwoGB / sizeof(ADataType);
return (input_size < threshold) && (output_size < threshold);
}
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
// Create copies
auto conv_to_gemm_transformer_left = *this;
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
;
// Calcualte offsets
a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
c_right_offset = (Do_ / 2) * DoStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
c_right_offset = (Ho_ / 2) * HoStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
conv_to_gemm_transformer_right.Wi_ =
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
c_right_offset = (Wo_ / 2) * WoStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck_tile::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
// TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
// properties
template <typename ALayout,
@@ -1510,6 +1543,18 @@ struct TransformConvFwdToGemm
}
}
// ═══════════════════════════════════════════════════════════════════════
// Split-Image Calculation (AFTER Split-N)
// ═══════════════════════════════════════════════════════════════════════
// This method calculates split-image information using N_ (after Split-N).
// This ensures correct offset calculations when both Split-N and Split-Image
// are active simultaneously.
// NOTE: Deleted CalculateSplitImage() and LaunchWithRecursiveSplit() - dead code
// Current split-image implementation is in grouped_convolution_forward_invoker.hpp
public:
private:
IndexType G_, N_, original_N_;
IndexType Di_, Hi_, Wi_;
IndexType Do_, Ho_, Wo_;