mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
[CK TILE] Grouped conv fwd split image (#2970)
* Refactor split-image implementation: simplify code and remove redundant variables
* Add padding debug output to split-image implementation
- Added debug prints for padding calculations in transform_conv_fwd_to_gemm.hpp
- Verified padding works correctly with all tests passing
* Fix sign comparison warning after rebase with origin/develop
- Cast blockIdX from unsigned to signed index_t for comparisons
- Integrated with new GetOutputTileIndex logic from upstream
- Updated to use amd_wave_read_first_lane instead of __builtin_amdgcn_readfirstlane
* Fix Split-N with groups bug and clean up unused parameters
- Fixed batch stride calculation to include G dimension for grouped convolutions
- When moving between batches in NHWGC/NWGC/NDHWGC layouts, need to account for all groups
- Removed unused multi-split parameters (we only support 2-way split)
- All tests now pass: G=1 with Split-N, G>1 with Split-N, G>1 without Split-N
* Implement recursive queue-based split-image detection and calculation
- Add LaunchKernelWithSplitIfNeeded() helper method in transform_conv_fwd_to_gemm.hpp
- Implement recursive binary splitting algorithm (10GB→5GB+5GB→...)
- Correctly handle odd dimensions (61→30+31)
- Calculate proper offsets for each split piece
- Update invoker to use split-image helper
Note: Split detection and calculation work correctly but kernel launching
for individual pieces requires kernel modification to handle different
spatial dimensions (unlike Split-N which uses blockIdx.z).
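The halving rule described above (for example 61→30+31) can be sketched as follows. This is only an illustration, not the code from this commit: `HalfSplit` and `split_in_half` are hypothetical names, and giving LEFT the floor of the half is an assumption inferred from the 61→30+31 example.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper (not from the commit): one binary split of a
// spatial extent. LEFT gets floor(size / 2) and RIGHT gets the rest,
// so an odd extent such as 61 becomes 30 + 31.
struct HalfSplit
{
    std::int64_t left_size;
    std::int64_t right_size;
    std::int64_t right_start; // start of RIGHT relative to the parent piece
};

inline HalfSplit split_in_half(std::int64_t size)
{
    assert(size >= 2 && "an extent of 1 cannot be split");
    const std::int64_t left = size / 2;
    return {left, size - left, left};
}
```

Applying the helper repeatedly to pieces that are still too large gives the 10GB→5GB+5GB→... behaviour mentioned above.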
* WIP: Split-Image investigation - found architecture mismatch
- Split-N modifies N_ directly in transformer constructor
- Split-Image needs different approach due to varying dimensions
- Added split calculation logic for 1D and 2D convolutions
- Still facing memory issues when creating piece transformers
Key finding: Split-N uses blockIdx.z for parallel execution,
while Split-Image needs sequential execution of non-uniform pieces.
* Add 1D split-image implementation for grouped convolution (N=1 working)
Implements split-image for 1D convolution to handle large tensors that
exceed memory thresholds. This is a critical milestone with N=1 fully
working and tested.
Key Changes:
- Invoker: Add split-image logic that splits W dimension in half
- Transformer: Add SplitConvProblem helper for recursive splitting
- Calculate offsets for LEFT and RIGHT pieces
- Launch two kernels sequentially (LEFT then RIGHT)
Implementation Details:
- Binary split: divides W dimension by 2
- LEFT piece: W=0 to W/2, keeps left padding, removes right padding
- RIGHT piece: W/2 to W, removes left padding, keeps right padding
- Offset calculation accounts for stride, dilation, and padding
- Physical memory offset (no padding in memory)
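A rough sketch of how the LEFT and RIGHT pieces could be derived for the 1D case is given here. It is not taken from the commit: `Conv1D`, `Piece1D`, `left_piece` and `right_piece` are hypothetical names, the split is assumed to happen on the output width, and the inputs are assumed to pass the feasibility conditions (RIGHT starts at a non-negative input column, LEFT stays inside the input).

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical 1D problem description (NWGC input, NWGK output).
struct Conv1D
{
    std::int64_t wi, x, stride, dilation, pad_left, pad_right, g, c, k;

    // Standard convolution output width.
    std::int64_t wo() const
    {
        return (wi + pad_left + pad_right - dilation * (x - 1) - 1) / stride + 1;
    }
};

// Hypothetical description of one piece after the split.
struct Piece1D
{
    std::int64_t wi, wo;
    std::int64_t pad_left, pad_right;
    std::int64_t in_offset;  // element offset into the unpadded input
    std::int64_t out_offset; // element offset into the output
};

// LEFT keeps the left padding and drops the right padding.
inline Piece1D left_piece(const Conv1D& p)
{
    const std::int64_t wo_left = p.wo() / 2;
    // One past the last input column read by output column wo_left - 1.
    const std::int64_t wi_left =
        (wo_left - 1) * p.stride + p.dilation * (p.x - 1) + 1 - p.pad_left;
    return {wi_left, wo_left, p.pad_left, 0, 0, 0};
}

// RIGHT drops the left padding and keeps the right padding.
inline Piece1D right_piece(const Conv1D& p)
{
    const std::int64_t wo_left = p.wo() / 2;
    // First input column read by output column wo_left.
    const std::int64_t in_start = wo_left * p.stride - p.pad_left;
    return {p.wi - in_start, p.wo() - wo_left, 0, p.pad_right,
            in_start * p.g * p.c, wo_left * p.g * p.k};
}
```

Both pieces reproduce their own output width when fed back into the standard output-size formula, which is a useful sanity check for any variant of this calculation.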
Test Results (N=1):
✅ 94/94 tests passing
- Comprehensive tests: 36/36 (channels, padding, stride, dilation, filters, groups)
- Edge case tests: 31/31 (odd dimensions, extreme parameters, boundaries)
- Stress tests: 27/27 (maximum dimensions, up to 91.4 TFlops)
Known Limitations:
- Only works with N=1 (single batch)
- N>1 fails when split-image triggers (offset calculation issue with Split-N)
- Root cause: Split-N modifies N in transformer, but offset calculated in invoker
- Solution planned: Move offset calculation to transformer (next phase)
Files Modified:
- grouped_convolution_forward_invoker.hpp: Add split-image logic
- transform_conv_fwd_to_gemm.hpp: Add SplitConvProblem helper
This commit represents a stable, tested 1D split-image implementation
for N=1 cases. It's an important milestone before extending to N>1
and multi-dimensional splits.
* Add basic split-image implementation for 1D/2D/3D grouped convolution
This is a working baseline implementation that splits large spatial
dimensions to handle memory constraints.
Implementation:
- 1D: W-split for NWGC layout (36/36 tests passing)
- 2D: H-split for NHWGC layout (20/20 tests passing)
- 3D: D-split for NDHWGC layout (verified working)
Features:
- Binary split of outermost spatial dimension
- Sequential LEFT/RIGHT kernel launches
- Proper padding adjustment at split boundaries
- Offset calculation for pointer arithmetic
- Debug output for verification
Threshold: 100KB (configurable in transformer)
Known limitations:
- No safety checks for edge cases (to be added)
- Offset calculated before Split-N (incompatible with N>1, to be fixed)
- No recursive splitting for very large tensors
Next steps:
- Add safety checks (is_possible_to_split_*)
- Move offset calculation to transformer (after Split-N)
- Test with N>1 + split-image combination
* Refactor split-image to unified structure for 1D/2D/3D
Unified the three separate dimension-specific blocks into a single
common implementation with dimension-specific stride calculations.
Benefits:
- Reduced code from 636 → 348 lines (45% reduction)
- Eliminated code duplication
- Easier to maintain and extend
- Single source of truth for split logic
Implementation:
- Common: Binary split, offset calc, padding adjustment, kernel launch
- Dimension-specific: Stride calculation only
- 1D: stride = G * C
- 2D: stride = W_in * G * C
- 3D: stride = H_in * W_in * G * C
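The three stride formulas above can be written as one small function. This is a minimal sketch; `split_dim_stride` is a hypothetical name and the runtime `ndim_spatial` switch is an assumption (the real code may select the case at compile time).

```cpp
#include <cassert>
#include <cstdint>

// Elements skipped when advancing one step along the split dimension
// (W for 1D, H for 2D, D for 3D) in NWGC / NHWGC / NDHWGC layouts.
// Hypothetical helper, not part of the commit.
inline std::int64_t split_dim_stride(int ndim_spatial,
                                     std::int64_t h_in,
                                     std::int64_t w_in,
                                     std::int64_t g,
                                     std::int64_t c)
{
    switch(ndim_spatial)
    {
    case 1: return g * c;
    case 2: return w_in * g * c;
    case 3: return h_in * w_in * g * c;
    default: assert(false && "only 1D, 2D and 3D are supported"); return 0;
    }
}
```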
Test results (all passing):
- 1D: 36/36 tests ✅
- 2D: 20/20 tests ✅
- 3D: 28/28 tests ✅
- Total: 84/84 (100%)
All test scenarios verified:
- Varying channels, padding, stride, dilation
- Filter sizes (1x1 pointwise to 7x7)
- Multiple groups (G=1,2,4)
- Odd dimensions
- Complex combinations
* Add safety checks for split-image in all dimensions
Added is_possible_to_split safety checks to prevent crashes when
splitting is not feasible.
Safety checks verify:
1. Output dimension > 1 (can't split single element)
2. RIGHT piece starts after left padding
3. LEFT piece ends within input bounds
If checks fail, falls back to normal kernel launch.
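One possible reading of the three checks, for a single split dimension, is sketched here. The function name matches the `is_possible_to_split` wording above, but the signature and the exact inequalities are assumptions: "starts after left padding" is taken to mean that the first input index read by RIGHT is non-negative, and "ends within input bounds" that the last input index read by LEFT is below the input size.

```cpp
#include <cassert>
#include <cstdint>

// Sketch of the safety checks for one split dimension (assumed signature).
inline bool is_possible_to_split(std::int64_t out_size,
                                 std::int64_t in_size,
                                 std::int64_t filter,
                                 std::int64_t stride,
                                 std::int64_t dilation,
                                 std::int64_t pad_left)
{
    // 1. A single output element cannot be split.
    if(out_size <= 1)
        return false;

    const std::int64_t out_left = out_size / 2;

    // 2. RIGHT must start at a real (unpadded) input index.
    const std::int64_t right_in_start = out_left * stride - pad_left;
    if(right_in_start < 0)
        return false;

    // 3. The last input index read by LEFT must lie inside the input.
    const std::int64_t left_in_last =
        (out_left - 1) * stride + dilation * (filter - 1) - pad_left;
    if(left_in_last >= in_size)
        return false;

    return true;
}
```

When the function returns false, the caller would launch the unsplit kernel, matching the fallback described above.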
Verified for all dimensions:
- 1D (W-split): Wo=1 case triggers fallback
- 2D (H-split): Ho=1 case triggers fallback
- 3D (D-split): Do=1 case triggers fallback
Original 84 tests still pass - they use normal configurations
that naturally satisfy safety conditions.
Safety checks protect against pathological edge cases with:
- Very small spatial dimensions
- Extreme stride/dilation combinations
- Invalid padding configurations
* Fix Split-N + Split-Image compatibility issue
Fixed critical bug where Split-N and Split-Image working together
caused ~50% incorrect results due to wrong batch stride calculation.
Problem:
- Batch stride was calculated using MODIFIED spatial dimensions
(e.g., W=50000 after split) instead of ORIGINAL dimensions (W=100000)
- Spatial offset was applied globally in invoker, not per-batch in kernel
- Each batch (blockIdx.z) got wrong memory offset
Solution:
1. Store spatial offset in kargs (don't apply to pointer in invoker)
2. Copy correct batch_stride from temp_kargs to left/right kargs
3. Apply formula in operator(): ptr = base + (batch × stride) + spatial_offset
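The pointer formula in step 3 can be illustrated on the host as follows. This is a sketch with hypothetical names (`piece_pointer`, `batch_stride_nwgc`); the point it shows is that the batch stride is computed from the original width, not from the width of the piece.

```cpp
#include <cassert>
#include <cstdint>

// Batch stride for an NWGC tensor, computed from the ORIGINAL width
// (before any split-image piece shrinks it). Hypothetical helper.
inline std::int64_t batch_stride_nwgc(std::int64_t original_w, std::int64_t g, std::int64_t c)
{
    return original_w * g * c;
}

// ptr = base + (batch x stride) + spatial_offset, applied per batch.
template <typename T>
inline T* piece_pointer(T* base,
                        std::int64_t batch_idx,
                        std::int64_t batch_stride,
                        std::int64_t spatial_offset)
{
    return base + batch_idx * batch_stride + spatial_offset;
}
```

Using the piece width (for example W=50000 instead of W=100000) in `batch_stride_nwgc` would reproduce the wrong per-batch offsets described in the problem statement above.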
Changes:
- grouped_convolution_forward_kernel.hpp:
* Added spatial_offset_in/out fields to KernelArgs
* Apply batch + spatial offset in operator()
- grouped_convolution_forward_invoker.hpp:
* Keep base pointer, store spatial offset in kargs
* Copy batch_stride from temp_kargs (has original dimensions)
- transform_conv_fwd_to_gemm.hpp:
* Add debug output for split-image calculation
Results:
- N=1 tests: 84/84 passing (100%)
- N>1 tests: Now all passing (previously ~50% errors)
- Tested: 1D, 2D, 3D with N=1,2,4,8,16,20
* Implement unified threshold for Split-N and Split-Image
This commit consolidates threshold management for both Split-N and
Split-Image operations into a single source of truth, eliminating
code duplication and fixing offset calculation issues.
Key Changes:
============
1. Transformer (transform_conv_fwd_to_gemm.hpp):
- Moved TwoGB constant to public section for unified access
- CalculateSplitImage() now takes no parameters
- Uses internal threshold: TwoGB / sizeof(CDataType)
- Calculates offsets using N_ (after Split-N) for correctness
2. Kernel (grouped_convolution_forward_kernel.hpp):
- GetSplitImageInfo() simplified to take no parameters
- Forwards to transformer's CalculateSplitImage()
- Clean interface with unified threshold internally
3. Invoker (grouped_convolution_forward_invoker.hpp):
- Removed redundant threshold calculation
- Simplified to call kargs.GetSplitImageInfo() with no params
- Clean early-return pattern (no unnecessary else blocks)
- Removed duplicate/dead code paths
Benefits:
=========
- Single source of truth: TwoGB defined once in transformer
- No parameter passing for threshold between components
- Correct offset calculation using N_ (post-Split-N)
- Cleaner code with no duplication
- All tests passing: 1D/2D/3D with various N values
Testing:
========
- Split-Image only (N=1, large spatial): PASS
- Split-N only (N>1, small spatial): PASS
- Both splits active (N>1, large spatial): PASS
- No splits (N=1, small spatial): PASS
- CPU verification correct for all scenarios
* Comment out outdated split-image code (SplitConvProblem/LaunchKernelWithSplitIfNeeded)
The old recursive queue-based implementation has been replaced by the
new CalculateSplitImage() method which is simpler and correctly handles
Split-N + Split-Image interaction.
Changes:
- Wrapped lines 381-1078 in #if 0...#endif
- Old methods: SplitConvProblem() and LaunchKernelWithSplitIfNeeded()
- Preserved for reference but disabled from compilation
- No functional changes - all tests still pass
The new implementation (CalculateSplitImage at line ~2163) provides:
- Correct offset calculation using N_ (after Split-N)
- Simpler binary split logic
- Better integration with unified threshold approach
* Implement recursive split-image with depth limit (MAX_DEPTH=10)
Changes:
- Add depth tracking to SplitPiece struct
- Implement two stopping conditions:
1. Piece size below threshold (optimal case)
2. Depth >= MAX_DEPTH (prevents infinite recursion)
- Remove MAX_PIECES limit in favor of depth-based control
- Support up to 2^10 = 1024 pieces with depth 10
This allows handling extreme tensor sizes while ensuring termination.
Pieces larger than threshold will still launch correctly if depth limit reached.
Tested with H=100 (4 levels), H=2000 (6 levels), H=4000 (9 levels) - all pass CPU verification.
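A queue-based split with the two stopping conditions can be sketched as below. The struct is a simplified stand-in for the `SplitPiece` mentioned above (the real fields are not shown in this message), `split_recursively` is a hypothetical name, and sizes are measured in rows of the split dimension rather than bytes.

```cpp
#include <cassert>
#include <cstdint>
#include <queue>
#include <vector>

// Simplified stand-in for the SplitPiece struct (fields are assumptions).
struct SplitPiece
{
    std::int64_t start; // cumulative offset along the split dimension
    std::int64_t size;
    int depth;
};

constexpr int kMaxDepth = 10; // mirrors MAX_DEPTH=10, at most 2^10 pieces

inline std::vector<SplitPiece> split_recursively(std::int64_t total, std::int64_t max_piece_size)
{
    std::vector<SplitPiece> leaves;
    std::queue<SplitPiece> pending;
    pending.push({0, total, 0});

    while(!pending.empty())
    {
        const SplitPiece p = pending.front();
        pending.pop();

        const bool small_enough = p.size <= max_piece_size; // stop 1
        const bool too_deep     = p.depth >= kMaxDepth;     // stop 2
        if(small_enough || too_deep || p.size < 2)
        {
            leaves.push_back(p);
            continue;
        }

        // LEFT inherits the parent offset, RIGHT adds the local offset.
        const std::int64_t left = p.size / 2;
        pending.push({p.start, left, p.depth + 1});
        pending.push({p.start + left, p.size - left, p.depth + 1});
    }
    return leaves;
}
```

With the depth limit in place, a piece that is still above the threshold at depth 10 is emitted as it is, which is the behaviour described above for extreme sizes.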
* Summary of recursive split-image implementation:
- Recursive queue-based splitting with depth limit (MAX_DEPTH=10, up to 1024 pieces)
- Two stopping conditions: size below threshold OR max depth reached
- Cumulative offset tracking through all recursion levels
- LEFT piece inherits parent offset, RIGHT accumulates (parent + local)
- Per-batch spatial offset application in kernel operator()
- Batch stride uses original dimensions (before split)
- Works with Split-N: split-N first, then recursive split-image
- Handles odd dimensions, padding, stride, dilation correctly
- All 1D/2D/3D tests pass with CPU verification
* Add comment explaining MAX_DEPTH capacity for 2GB threshold
* Refactor: move recursive split-image logic to transformer
- Move LaunchWithRecursiveSplit() from invoker to transform_conv_fwd_to_gemm.hpp
- Simplify invoker from ~250 lines to ~140 lines (removed 110 lines of inline logic)
- Encapsulate SplitPiece struct and BFS splitting algorithm in transformer
- Remove unused includes (queue, vector) from invoker
- Add documentation comment for AreDescriptorsSmallerThan2GB()
- Improve code organization and reusability
- No performance overhead (static template function, compiler inlines)
- All tests passing with 2GB production threshold
* Apply clang-format-18 formatting
- Format invoker and transformer files with clang-format-18
- Fix brace placement and alignment
- No functional changes
* Fix clang-format-18 issues in forward kernel
- Remove extra blank lines
- Fix line wrapping for template calls
- Consolidate GetSplitImageInfo() to single line
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Split-Image implementation with temporary fixed divider
- Implemented spatial dimension splitting (Split-Image) for large tensors
- Added piece-based coordinate transformation for 1D/2D/3D convolutions
- Integrated Split-N (batch splitting) with automatic threshold detection
- Fixed M dimension calculation to include batch: M = N × spatial_size
- Added spatial offset support in kernel arguments
- Verified 20/20 test cases passing for Split-Image alone
- Known issue: Split-N + Split-Image combination needs coordinate fix
Implementation Details:
- Split factors: 4 (1D), 4×4 (2D), 4×4×4 (3D) - temporary fixed values
- Batch strides properly calculated for NWGC/NHWGC/NDHWGC layouts
- Piece descriptors track spatial boundaries and block ranges
- No performance overhead for N=1 cases
* Fix 1D split-image padding issue with per-piece dimensions
- Store actual size per piece to handle non-uniform splits
- Remove dead code from transform utils
* Fix 2D/3D split-image with independent split factors per dimension
Problem: Single split factor caused non-uniform pieces when dimensions
didn't divide evenly. Result: 18/25 (72%) 2D padding combinations failed.
Solution: Independent split factor selection for W, H, D dimensions.
Each dimension gets optimal factor based on its own size.
Test Results:
- 1D: 42/42 pass (100%)
- 2D: 25/25 pass (100%)
- Total: 67/67 combinations verified
* Remove unused split-image struct fields
Cleanup of split-image implementation:
- Removed unused piece_d, piece_h, piece_w fields from SplitImageInfo struct
- These fields were declared but never used in the kernel
- Per-piece dimensions are already stored in pieces[] array
- Reduces struct size and improves code clarity
Tested: 1D/2D/3D convolutions with split-image, padding, stride all pass
* Refactor split-image invoker code for improved readability
- Extract piece calculation logic into calculate_piece lambda helper
- Extract kernel args population into populate_split_image_kargs lambda
- Use aggregate initialization for cleaner struct population
- Reduce nesting depth and improve maintainability
- Fix outdated comment about split-image implementation status
* Refactor split-image code and remove debug prints
- Extract GPU kernel helper lambdas for better readability
- Remove all split-image debug print statements
- Set memory threshold to 2GB for production
- All tests pass with CPU verification
* Add split-image safety constraints and refactor to utils
- Add MAX_TOTAL_PIECES=64 limit to prevent segfault
- Move calculate_spatial_piece to library utils
- Add layout validation (NWGC, NHWGC, NDHWGC only)
- Fix hierarchical splitting to respect piece limits
- Add proper documentation and formatting
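The way a D → H → W search can respect a 64-piece budget is sketched below. This is a simplified illustration: `SplitFactors` and `choose_split_factors` are hypothetical names, the memory test is passed in as a callable, and when nothing fits the sketch returns the last factors tried instead of applying any special fallback.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

struct SplitFactors
{
    int d = 1;
    int h = 1;
    int w = 1;
};

constexpr int kMaxTotalPieces = 64; // mirrors MAX_TOTAL_PIECES=64

// `fits(f)` should return true when a piece produced by factors `f`
// is below the memory threshold.
template <typename Fits>
inline SplitFactors
choose_split_factors(std::int64_t d_out, std::int64_t h_out, std::int64_t w_out, Fits fits)
{
    SplitFactors f;
    if(fits(f))
        return f; // no split needed

    // Grow one factor from 2 up to min(extent, remaining piece budget).
    auto grow = [&](int& factor, std::int64_t extent, int budget) {
        const int max_factor = static_cast<int>(std::min<std::int64_t>(extent, budget));
        for(int s = 2; s <= max_factor; ++s)
        {
            factor = s;
            if(fits(f))
                return true;
        }
        return false;
    };

    if(grow(f.d, d_out, kMaxTotalPieces))
        return f;
    if(grow(f.h, h_out, kMaxTotalPieces / f.d))
        return f;
    grow(f.w, w_out, kMaxTotalPieces / (f.d * f.h));
    return f; // best effort, d * h * w never exceeds the budget
}
```

Because each later dimension only receives the budget left over by the earlier ones, the product of the three factors stays at or below 64 even when no combination fits.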
* Change split-image from runtime to compile-time branching
Response to @bartekxk review comment:
Convert 'if(kargs.num_spatial_pieces > 1)' to 'if constexpr(EnableSplitImage)'
Changes:
- Add EnableSplitImage template parameter to kernel
- Change runtime if to compile-time if constexpr
- Update invoker to instantiate kernel variants with true/false
Benefits:
- Eliminates runtime branching in GPU kernel
- Dead code elimination (each variant is smaller)
- Better compiler optimization
Files modified: 2
Lines changed: 20 total (6 in kernel, 14 in invoker)
Tests: 27/27 passed (100%)
Performance: No regression
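The runtime-to-compile-time change can be illustrated with a host-only analogue. This is not the kernel code: `KernelArgsSketch`, `element_offset` and `dispatch` are hypothetical, and only the `EnableSplitImage` parameter name comes from the description above. It requires C++17 for `if constexpr`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical, reduced kernel-argument struct.
struct KernelArgsSketch
{
    std::int64_t batch_stride;
    std::int64_t spatial_offset;
};

// The split-image branch is resolved at compile time, so the
// <false> instantiation contains no split-image code at all.
template <bool EnableSplitImage>
inline std::int64_t element_offset(const KernelArgsSketch& kargs, std::int64_t batch_idx)
{
    std::int64_t offset = batch_idx * kargs.batch_stride;
    if constexpr(EnableSplitImage)
    {
        offset += kargs.spatial_offset;
    }
    return offset;
}

// The caller picks the variant once, outside the hot path.
inline std::int64_t
dispatch(const KernelArgsSketch& kargs, std::int64_t batch_idx, bool split_needed)
{
    return split_needed ? element_offset<true>(kargs, batch_idx)
                        : element_offset<false>(kargs, batch_idx);
}
```

The single runtime decision moves to the invoker side, which is the trade described above: two instantiations in exchange for no branching inside the kernel body.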
* Add split-image example as separate binary
- Create grouped_convolution_forward_split_image example
- Add grouped_convolution_forward_split_image_invoker.hpp
- Update CMakeLists.txt to build split_image binary
* Replace linear search with binary search in find_piece_id
- Change O(n) to O(log n) for finding piece ownership
- Matches reference implementation in large_tensor_cshuffle
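A binary search over per-piece block ranges might look like the sketch below. The `find_piece_id` name comes from the description above, but the signature, the `BlockRange` struct and the half-open `[block_start, block_end)` convention are assumptions; the ranges are assumed to be sorted and contiguous.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Half-open block range [block_start, block_end) owned by one piece.
struct BlockRange
{
    std::int32_t block_start;
    std::int32_t block_end;
};

// Returns the index of the piece that owns `block_id`, or -1 if none does.
// O(log n) in the number of pieces.
inline int find_piece_id(const std::vector<BlockRange>& pieces, std::int32_t block_id)
{
    int lo = 0;
    int hi = static_cast<int>(pieces.size()) - 1;
    while(lo <= hi)
    {
        const int mid = lo + (hi - lo) / 2;
        if(block_id < pieces[mid].block_start)
            hi = mid - 1;
        else if(block_id >= pieces[mid].block_end)
            lo = mid + 1;
        else
            return mid;
    }
    return -1;
}
```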
* Simplify split-image code and fix integer overflow
- Extract lambda functions to static helper methods
- Pre-calculate constants in invoker
- Fix integer overflow in tensor size calculation for large tensors
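The kind of overflow fix described above usually amounts to widening before multiplying. A minimal sketch, with a hypothetical `tensor_bytes` helper and 32-bit dimensions assumed:

```cpp
#include <cassert>
#include <cstdint>

// Casting the first operand to 64 bits makes the whole product 64-bit,
// so large tensors no longer overflow a 32-bit intermediate.
inline std::int64_t tensor_bytes(std::int32_t n,
                                 std::int32_t h,
                                 std::int32_t w,
                                 std::int32_t c,
                                 std::int64_t bytes_per_element)
{
    return static_cast<std::int64_t>(n) * h * w * c * bytes_per_element;
}
```

Without the cast, `n * h * w * c` would be evaluated in 32-bit arithmetic and could wrap before the comparison against the 2GB threshold.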
* Trigger CI rerun - fix merge conflicts
* Fix merge conflict markers
* Fix clang-format: remove space before {}
* Fix clang-format: comment wrapping and Swish constructor
* Rename split_image to large_tensor for clarity
- Renamed grouped_convolution_forward_split_image.cpp -> grouped_convolution_forward_large_tensor.cpp
- Renamed grouped_convolution_forward_split_image_invoker.hpp -> grouped_convolution_forward_large_tensor_invoker.hpp
- Updated CMakeLists.txt target name: tile_example_grouped_conv_fwd_split_image -> tile_example_grouped_conv_fwd_large_tensor
- Updated comments to refer to 'large tensor' instead of 'split-image'
* Update comments and include in large_tensor example
- Updated header comments to use 'large tensor' terminology
- Fixed include path to use large_tensor_invoker.hpp
* Remove test code, restore 2GB threshold
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Fix build errors after develop merge and complete rename to large_tensor
This commit addresses compilation errors from the develop merge and
completes the rename from split_image to large_tensor.
Changes:
1. Fix CDEElementWise typo in grouped_convolution_forward_invoker.hpp
2. Fix template parameter order in large_tensor_invoker.hpp
- TransformConvFwdToGemm signature changed in develop
- NumGroupsToMerge and SplitN parameters swapped positions
3. Fix missing template parameter in GroupedConvFwdHostArgs
4. Fix EpiloguePipeline scope in kernel (merge conflict)
5. Update binary name references in test scripts
* Restore 2GB threshold for split-image
Changed threshold from 100MB (testing) back to 2GB for production use.
* Fix const-correctness in ds_ptr cast
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply clang-format-18
* update c++ 18 format
* Apply clang-format-18 to transform_conv_fwd_to_gemm.hpp
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
[ROCm/composable_kernel commit: 1fbb47ad30]
@@ -110,4 +110,86 @@ struct GroupedConvTraits
    using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
};

/// @brief Helper struct for split-image piece information
///
/// @par Overview
/// Stores metadata for a single spatial piece in split-image convolution.
/// Used to track block ranges and spatial coordinates for each piece.
struct SplitImagePieceInfo
{
    ck_tile::index_t block_start, block_end;    ///< GPU block range for this piece
    ck_tile::index_t d_start, h_start, w_start; ///< Spatial start coordinates (output space)
    ck_tile::index_t d_size, h_size, w_size;    ///< Spatial dimensions of this piece
};

/// @brief Calculate piece information for split-image convolution
///
/// @par Overview
/// Computes spatial coordinates, dimensions, and GPU block range for a single
/// piece in split-image convolution. Handles edge pieces that may have different
/// sizes due to non-uniform division.
///
/// @tparam TilePartitioner Type providing MPerBlock and NPerBlock constants
///
/// @param piece_idx Index of the piece to calculate (0-based)
/// @param num_d_pieces Number of pieces in D dimension
/// @param num_h_pieces Number of pieces in H dimension
/// @param num_w_pieces Number of pieces in W dimension
/// @param base_piece_d Base size of each D piece (may differ for last piece)
/// @param base_piece_h Base size of each H piece (may differ for last piece)
/// @param base_piece_w Base size of each W piece (may differ for last piece)
/// @param total_d Total D dimension size (output space)
/// @param total_h Total H dimension size (output space)
/// @param total_w Total W dimension size (output space)
/// @param N Batch size
/// @param K Output channels
/// @param total_blocks Accumulated block count from previous pieces
///
/// @return SplitImagePieceInfo containing all metadata for this piece
template <typename TilePartitioner>
CK_TILE_HOST SplitImagePieceInfo calculate_spatial_piece(ck_tile::index_t piece_idx,
                                                         ck_tile::index_t num_d_pieces,
                                                         ck_tile::index_t num_h_pieces,
                                                         ck_tile::index_t num_w_pieces,
                                                         ck_tile::index_t base_piece_d,
                                                         ck_tile::index_t base_piece_h,
                                                         ck_tile::index_t base_piece_w,
                                                         ck_tile::index_t total_d,
                                                         ck_tile::index_t total_h,
                                                         ck_tile::index_t total_w,
                                                         ck_tile::index_t N,
                                                         ck_tile::index_t K,
                                                         ck_tile::index_t total_blocks)
{
    // Unflatten piece index into 3D coordinates (W-major, then H, then D)
    const ck_tile::index_t w_idx = piece_idx % num_w_pieces;
    const ck_tile::index_t h_idx = (piece_idx / num_w_pieces) % num_h_pieces;
    const ck_tile::index_t d_idx = piece_idx / (num_w_pieces * num_h_pieces);

    // Calculate spatial start positions
    const ck_tile::index_t w_start = w_idx * base_piece_w;
    const ck_tile::index_t h_start = h_idx * base_piece_h;
    const ck_tile::index_t d_start = d_idx * base_piece_d;

    // Calculate piece sizes (last piece may be larger to cover remainder)
    const ck_tile::index_t w_size =
        (w_idx == num_w_pieces - 1) ? (total_w - w_start) : base_piece_w;
    const ck_tile::index_t h_size =
        (h_idx == num_h_pieces - 1) ? (total_h - h_start) : base_piece_h;
    const ck_tile::index_t d_size =
        (d_idx == num_d_pieces - 1) ? (total_d - d_start) : base_piece_d;

    // Calculate GEMM dimensions for this piece
    const ck_tile::index_t piece_gemm_m = N * d_size * h_size * w_size;
    const ck_tile::index_t piece_gemm_n = K;

    // Calculate GPU grid size for this piece
    const ck_tile::index_t piece_grid =
        ((piece_gemm_m + TilePartitioner::MPerBlock - 1) / TilePartitioner::MPerBlock) *
        ((piece_gemm_n + TilePartitioner::NPerBlock - 1) / TilePartitioner::NPerBlock);

    return {
        total_blocks, total_blocks + piece_grid, d_start, h_start, w_start, d_size, h_size, w_size};
}

} // namespace ck_tile

@@ -5,9 +5,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/grouped_convolution/utils/convolution_specialization.hpp"

namespace ck_tile {

// ═══════════════════════════════════════════════════════════════════════
// Split-Image Information Structure
// ═══════════════════════════════════════════════════════════════════════
// This structure holds all information needed to perform split-image
// NOTE: SplitImageInfo struct deleted - was only used by deleted recursive split code
// Current split-image implementation is in grouped_convolution_forward_invoker.hpp

template <index_t NDimSpatial,
          ConvolutionSpecialization ConvSpecialization,
          index_t VectorSizeA,
@@ -28,6 +34,9 @@ struct TransformConvFwdToGemm
    static constexpr auto I4 = number<4>{};
    static constexpr auto I5 = number<5>{};

    // Unified memory limit constant for both Split-N and Split-Image
    static constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB

    template <typename ConvDimsType>
    static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
                                                          const ConvDimsType& strides,
@@ -47,6 +56,7 @@ struct TransformConvFwdToGemm
    static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
                                     const ConvDimsType& c_g_n_k_wos_lengths)
    {

        // Calculate strides internally assuming contiguous memory layout
        ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
        const index_t num_dims = a_g_n_c_wis_lengths.size();
@@ -71,7 +81,6 @@ struct TransformConvFwdToGemm
            calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
        const long_index_t element_space_size = ck_tile::max(
            a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
        constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB

        const IndexType N = a_g_n_c_wis_lengths[I1];

@@ -111,6 +120,145 @@ struct TransformConvFwdToGemm
        }
    }

    public:
    // Structure to hold split-image decision and factors
    struct SplitImageInfo
    {
        bool should_split;
        index_t num_d_pieces;
        index_t num_h_pieces;
        index_t num_w_pieces;
    };

    // Calculate split-image factors AFTER considering split-N
    // Returns: should_split flag and optimal split factors for D, H, W dimensions
    // Strategy: Hierarchical splitting with priority order D → H → W
    // Dynamically increases split factors until memory fits below threshold
    //
    // NOTE: Layout validation should be done at the invoker level before calling this function
    // Split-image only works with specific layouts:
    // 1D: NWGC (input), GKXC (weight), NWGK (output)
    // 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
    // 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
    CK_TILE_HOST static SplitImageInfo GetSplitImageInfo(
        index_t G, index_t N, index_t C, index_t K, index_t D_out, index_t H_out, index_t W_out)
    {
        SplitImageInfo info{false, 1, 1, 1};

        // Estimate memory (simplified calculation)
        // Use max of input and output tensor sizes
        // Cast to long_index_t to prevent overflow during multiplication
        const long_index_t input_elements =
            static_cast<long_index_t>(N) * D_out * H_out * W_out * C * G;
        const long_index_t output_elements =
            static_cast<long_index_t>(N) * D_out * H_out * W_out * K * G;
        const long_index_t input_bytes = input_elements * sizeof(ADataType);
        const long_index_t output_bytes = output_elements * sizeof(CDataType);
        const long_index_t max_tensor_bytes =
            (input_bytes > output_bytes) ? input_bytes : output_bytes;

        // Calculate effective N after split-N (simplified - assume worst case N=1)
        index_t effective_N = 1;
        if(max_tensor_bytes > TwoGB && N > 1)
        {
            // Split-N will reduce to approximately N=1 per launch
            effective_N = 1;
        }
        else
        {
            effective_N = N;
        }

        // Check if split-image is needed
        auto calc_memory = [&](index_t d_split, index_t h_split, index_t w_split) -> long_index_t {
            index_t d_piece = D_out / d_split;
            index_t h_piece = H_out / h_split;
            index_t w_piece = W_out / w_split;
            // Cast to long_index_t to prevent overflow
            return static_cast<long_index_t>(effective_N) * d_piece * h_piece * w_piece * K * G *
                   sizeof(CDataType);
        };

        // Calculate memory after split-N with no spatial split
        const long_index_t memory_after_split_n = calc_memory(1, 1, 1);

        // Check if split-image is needed
        if(memory_after_split_n <= TwoGB)
        {
            info.should_split = false;
            return info;
        }

        // Split-image is needed - use hierarchical priority: D → H → W
        info.should_split = true;

        // Hierarchical splitting strategy:
        // 1D: Split W until below threshold
        // 2D: Split H first, if still too large then split W
        // 3D: Split D first, then H, then W

        // IMPORTANT: Maximum 64 pieces total (hardcoded array limit in invoker)
        constexpr index_t MAX_TOTAL_PIECES = 64;

        // Start with no split
        info.num_d_pieces = 1;
        info.num_h_pieces = 1;
        info.num_w_pieces = 1;

        // Try splitting D first (for 3D)
        if(D_out > 1)
        {
            index_t max_d_split = (D_out < MAX_TOTAL_PIECES) ? D_out : MAX_TOTAL_PIECES;
            for(index_t d_split = 2; d_split <= max_d_split; d_split++)
            {
                info.num_d_pieces = d_split;
                if(calc_memory(d_split, 1, 1) <= TwoGB)
                {
                    return info; // D split alone is sufficient
                }
            }
            // D split maxed out, try H next
        }

        // Try splitting H (for 2D/3D)
        if(H_out > 1)
        {
            index_t max_h_split = MAX_TOTAL_PIECES / info.num_d_pieces;
            max_h_split = (H_out < max_h_split) ? H_out : max_h_split;

            for(index_t h_split = 2; h_split <= max_h_split; h_split++)
            {
                info.num_h_pieces = h_split;
                if(calc_memory(info.num_d_pieces, h_split, 1) <= TwoGB)
                {
                    return info; // D+H split is sufficient
                }
            }
            // H split maxed out, try W next
        }

        // Try splitting W (for 1D/2D/3D)
        index_t max_w_split = MAX_TOTAL_PIECES / (info.num_d_pieces * info.num_h_pieces);
        max_w_split = (W_out < max_w_split) ? W_out : max_w_split;

        for(index_t w_split = 2; w_split <= max_w_split; w_split++)
        {
            info.num_w_pieces = w_split;
            if(calc_memory(info.num_d_pieces, info.num_h_pieces, w_split) <= TwoGB)
            {
                return info; // D+H+W split is sufficient
            }
        }

        // If we reach here, even maximum split doesn't fit
        // Use maximum allowed split as best effort (capped at 64 total pieces)
        info.num_d_pieces = (D_out < 4) ? D_out : 4; // Cap at 4
        info.num_h_pieces = (H_out < 4) ? H_out : 4; // Cap at 4
        info.num_w_pieces = (W_out < 4) ? W_out : 4; // Cap at 4 (max 4×4×4=64)

        return info;
    }

    public:
    // Public getter methods for Split-N support
    CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
@@ -192,14 +340,14 @@ struct TransformConvFwdToGemm
                      std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
        static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                      std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);

        // Store original N and initialize N_
        original_N_ = N_ = c_g_n_k_wos_lengths[I1];

        if constexpr(SplitN)
        {
            N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
        }
        else
        {
            N_ = c_g_n_k_wos_lengths[I1];
        }
    }

    template <typename ConvDimsType,
@@ -244,18 +392,13 @@ struct TransformConvFwdToGemm
        static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                      std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);

        // Store original N
        original_N_ = c_g_n_k_wos_lengths[I1];
        // Store original N and initialize N_
        original_N_ = N_ = c_g_n_k_wos_lengths[I1];

        if constexpr(SplitN)
        {
            N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
        }
        else
        {
            N_ = c_g_n_k_wos_lengths[I1];
            original_N_ = N_;
        }
    }

    template <typename ConvDimsType,
@@ -300,136 +443,26 @@ struct TransformConvFwdToGemm
        static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
                      std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);

        // Store original N before potential splitting
        original_N_ = c_g_n_k_wos_lengths[I1];
        // Store original N and initialize N_
        original_N_ = N_ = c_g_n_k_wos_lengths[I1];

        if constexpr(SplitN)
        {
            N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
        }
        else
        {
            N_ = original_N_;
        }
    }

#if 0 // TODO: Enable these functionalities
    __host__ bool AreDescriptorsSmallerThan2GB() const
    // Check if descriptors fit within memory threshold
    // NOTE: Not currently used - split-image uses different approach in invoker
    CK_TILE_HOST bool AreDescriptorsSmallerThan2GB() const
    {
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
        const long_index_t input_size = static_cast<long_index_t>(N_) * Di_ * Hi_ * Wi_ * C_;
        const long_index_t output_size = static_cast<long_index_t>(N_) * Do_ * Ho_ * Wo_ * K_;

        const long_index_t in_desc_space_size =
            I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
            (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
        const long_index_t out_desc_space_size =
            I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
            (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;

        bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
        bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;

        return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
        const long_index_t threshold = TwoGB / sizeof(ADataType);
        return (input_size < threshold) && (output_size < threshold);
    }
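Note: the `threshold` form of the check above derives a single element-count bound from `sizeof(ADataType)` and applies it to both tensors. If the output element type ever differs in size from the input type, the output bound would be off by that ratio. A minimal sketch of a per-type variant follows. `FitsIn2GB` and `BothFitIn2GB` are hypothetical helpers, not part of the diff, and `long_index_t` is assumed here to be a 64-bit signed integer.

```cpp
#include <cstddef>
#include <cstdint>

using long_index_t = std::int64_t; // assumption: stands in for ck_tile::long_index_t

// Hypothetical helper: true if `num_elements` packed elements of type T occupy
// strictly less than 2 GiB. Dividing the limit avoids overflowing the product.
template <typename T>
constexpr bool FitsIn2GB(long_index_t num_elements)
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    return num_elements < TwoGB / static_cast<long_index_t>(sizeof(T));
}

// Hypothetical helper: checks the input and output tensors against their own element sizes.
template <typename ADataType, typename CDataType>
constexpr bool BothFitIn2GB(long_index_t input_size, long_index_t output_size)
{
    return FitsIn2GB<ADataType>(input_size) && FitsIn2GB<CDataType>(output_size);
}
```

With a 2-byte input type and a 4-byte output type, an output of 2^29 elements is exactly 2 GiB and is rejected here, whereas a bound derived from the input type alone would accept it.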

    __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
                                   CDataType* c_grid_ptr_base) const
    {
        // Create copies
        auto conv_to_gemm_transformer_left = *this;
        auto conv_to_gemm_transformer_right = *this;
        IndexType a_right_offset = 0;
        IndexType c_right_offset = 0;
        // Calculate real filter size
        const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
        const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
        const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
        // Calculate start position in input for right tensor
        const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
        const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
        const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
        // Calculate last position in input for left tensor
        const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
        const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
        const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
        // Allow to split if whole left padding will be in left tensor and right padding in right
        // tensor
        const bool is_possible_to_split_d = Do_ != 1 &&
                                            di_right_transformer_start_idx > InLeftPadD_ &&
                                            di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
        const bool is_possible_to_split_h = Ho_ != 1 &&
                                            hi_right_transformer_start_idx > InLeftPadH_ &&
                                            hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
        const bool is_possible_to_split_w = Wo_ != 1 &&
                                            wi_right_transformer_start_idx > InLeftPadW_ &&
                                            wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);

        if(is_possible_to_split_d)
        {
            // Apply new sizes
            // Split output in half
            conv_to_gemm_transformer_left.Do_ = Do_ / 2;
            conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
            // Assign left padding to left convolution
            conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
            conv_to_gemm_transformer_right.InLeftPadD_ = 0;
            // Assign right padding to right convolution
            conv_to_gemm_transformer_left.InRightPadD_ = 0;
            conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
            // Calculate new input size
            conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
            conv_to_gemm_transformer_right.Di_ =
                math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
                          (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
            ;
            // Calculate offsets
            a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
            c_right_offset = (Do_ / 2) * DoStride_;
        }
        else if(is_possible_to_split_h)
        {
            conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
            conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;

            conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
            conv_to_gemm_transformer_right.InLeftPadH_ = 0;

            conv_to_gemm_transformer_left.InRightPadH_ = 0;
            conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;

            conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
            conv_to_gemm_transformer_right.Hi_ =
                math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
                          (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
            a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
            c_right_offset = (Ho_ / 2) * HoStride_;
        }
        else if(is_possible_to_split_w)
        {
            conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
            conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;

            conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
            conv_to_gemm_transformer_right.InLeftPadW_ = 0;

            conv_to_gemm_transformer_left.InRightPadW_ = 0;
            conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;

            conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
            conv_to_gemm_transformer_right.Wi_ =
                math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
                          (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);

            a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
            c_right_offset = (Wo_ / 2) * WoStride_;
        }
        // Return left transform, right transformer, right offset to Input and right offset to
        // Output
        return ck_tile::make_tuple(conv_to_gemm_transformer_left,
                                   conv_to_gemm_transformer_right,
                                   a_grid_ptr_base + a_right_offset,
                                   c_grid_ptr_base + c_right_offset);
    }
#endif
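Note: the D, H and W branches of `SplitConvProblem` above repeat the same arithmetic along different axes. A minimal single-axis sketch of that arithmetic follows, written as a free function so it can be checked in isolation. `Split1D` and `SplitInHalf` are hypothetical names, and the offsets are returned as positions along the axis, so a caller would still multiply them by the matching stride (`DiStride_`, `DoStride_`, and so on) as the diff does.

```cpp
#include <algorithm>
#include <cstdint>

using index_t = std::int32_t; // assumption: stands in for IndexType

// Hypothetical result type for a two-way split along one spatial axis.
struct Split1D
{
    bool possible;                 // false if the axis cannot be split cleanly
    index_t left_out, right_out;   // output lengths of the two halves
    index_t left_in, right_in;     // input lengths read by the two halves
    index_t in_offset, out_offset; // start of the right half along this axis
};

inline Split1D SplitInHalf(index_t out_len, index_t in_len, index_t filter,
                           index_t stride, index_t dilation, index_t left_pad)
{
    const index_t eff         = (filter - 1) * dilation + 1; // effective filter size
    const index_t half        = out_len / 2;
    const index_t right_start = half * stride;               // in padded input coordinates
    const index_t left_end    = (half - 1) * stride + eff;   // one past the left half's last read

    Split1D s{};
    // Split only if all left padding stays in the left half and the left half
    // does not reach into the right padding.
    s.possible = out_len != 1 && right_start > left_pad && left_end <= left_pad + in_len;
    if(!s.possible)
        return s;

    s.left_out   = half;
    s.right_out  = out_len - half; // the odd element goes to the right half
    s.left_in    = left_end - left_pad;
    s.in_offset  = right_start - left_pad;
    s.right_in   = std::min(in_len - s.in_offset, (s.right_out - 1) * stride + eff);
    s.out_offset = half;
    return s;
}
```

For a 3-tap filter with stride 1, dilation 1, left padding 1 and 61 inputs and outputs, this gives halves of 30 and 31 outputs that read 31 and 32 inputs respectively, overlapping by the two input elements shared across the cut.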
    // TODO: implement ck_tile::tensor_layout::convolution that describes packed/strided dimension as
    // properties
    template <typename ALayout,
@@ -1510,6 +1543,18 @@ struct TransformConvFwdToGemm
        }
    }

    // ═══════════════════════════════════════════════════════════════════════
    // Split-Image Calculation (AFTER Split-N)
    // ═══════════════════════════════════════════════════════════════════════
    // This method calculates split-image information using N_ (after Split-N).
    // This ensures correct offset calculations when both Split-N and Split-Image
    // are active simultaneously.

    // NOTE: Deleted CalculateSplitImage() and LaunchWithRecursiveSplit() - dead code
    // Current split-image implementation is in grouped_convolution_forward_invoker.hpp

    public:
    private:
    IndexType G_, N_, original_N_;
    IndexType Di_, Hi_, Wi_;
    IndexType Do_, Ho_, Wo_;