mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 02:54:21 +00:00
[CK TILE] Grouped conv fwd split image (#2970)
* Refactor split-image implementation: simplify code and remove redundant variables
* Add padding debug output to split-image implementation
- Added debug prints for padding calculations in transform_conv_fwd_to_gemm.hpp
- Verified padding works correctly with all tests passing
* Fix sign comparison warning after rebase with origin/develop
- Cast blockIdX from unsigned to signed index_t for comparisons
- Integrated with new GetOutputTileIndex logic from upstream
- Updated to use amd_wave_read_first_lane instead of __builtin_amdgcn_readfirstlane
* Fix Split-N with groups bug and clean up unused parameters
- Fixed batch stride calculation to include G dimension for grouped convolutions
- When moving between batches in NHWGC/NWGC/NDHWGC layouts, need to account for all groups
- Removed unused multi-split parameters (we only support 2-way split)
- All tests now pass: G=1 with Split-N, G>1 with Split-N, G>1 without Split-N
* Implement recursive queue-based split-image detection and calculation
- Add LaunchKernelWithSplitIfNeeded() helper method in transform_conv_fwd_to_gemm.hpp
- Implement recursive binary splitting algorithm (10GB→5GB+5GB→...)
- Correctly handle odd dimensions (61→30+31)
- Calculate proper offsets for each split piece
- Update invoker to use split-image helper
Note: Split detection and calculation work correctly but kernel launching
for individual pieces requires kernel modification to handle different
spatial dimensions (unlike Split-N which uses blockIdx.z).
* WIP: Split-Image investigation - found architecture mismatch
- Split-N modifies N_ directly in transformer constructor
- Split-Image needs different approach due to varying dimensions
- Added split calculation logic for 1D and 2D convolutions
- Still facing memory issues when creating piece transformers
Key finding: Split-N uses blockIdx.z for parallel execution,
while Split-Image needs sequential execution of non-uniform pieces.
* Add 1D split-image implementation for grouped convolution (N=1 working)
Implements split-image for 1D convolution to handle large tensors that
exceed memory thresholds. This is a critical milestone with N=1 fully
working and tested.
Key Changes:
- Invoker: Add split-image logic that splits W dimension in half
- Transformer: Add SplitConvProblem helper for recursive splitting
- Calculate offsets for LEFT and RIGHT pieces
- Launch two kernels sequentially (LEFT then RIGHT)
Implementation Details:
- Binary split: divides W dimension by 2
- LEFT piece: W=0 to W/2, keeps left padding, removes right padding
- RIGHT piece: W/2 to W, removes left padding, keeps right padding
- Offset calculation accounts for stride, dilation, and padding
- Physical memory offset (no padding in memory)
Test Results (N=1):
✅ 94/94 tests passing
- Comprehensive tests: 36/36 (channels, padding, stride, dilation, filters, groups)
- Edge case tests: 31/31 (odd dimensions, extreme parameters, boundaries)
- Stress tests: 27/27 (maximum dimensions, up to 91.4 TFlops)
Known Limitations:
- Only works with N=1 (single batch)
- N>1 fails when split-image triggers (offset calculation issue with Split-N)
- Root cause: Split-N modifies N in transformer, but offset calculated in invoker
- Solution planned: Move offset calculation to transformer (next phase)
Files Modified:
- grouped_convolution_forward_invoker.hpp: Add split-image logic
- transform_conv_fwd_to_gemm.hpp: Add SplitConvProblem helper
This commit represents a stable, tested 1D split-image implementation
for N=1 cases. It's an important milestone before extending to N>1
and multi-dimensional splits.
* Add basic split-image implementation for 1D/2D/3D grouped convolution
This is a working baseline implementation that splits large spatial
dimensions to handle memory constraints.
Implementation:
- 1D: W-split for NWGC layout (36/36 tests passing)
- 2D: H-split for NHWGC layout (20/20 tests passing)
- 3D: D-split for NDHWGC layout (verified working)
Features:
- Binary split of outermost spatial dimension
- Sequential LEFT/RIGHT kernel launches
- Proper padding adjustment at split boundaries
- Offset calculation for pointer arithmetic
- Debug output for verification
Threshold: 100KB (configurable in transformer)
Known limitations:
- No safety checks for edge cases (to be added)
- Offset calculated before Split-N (incompatible with N>1, to be fixed)
- No recursive splitting for very large tensors
Next steps:
- Add safety checks (is_possible_to_split_*)
- Move offset calculation to transformer (after Split-N)
- Test with N>1 + split-image combination
* Refactor split-image to unified structure for 1D/2D/3D
Unified the three separate dimension-specific blocks into a single
common implementation with dimension-specific stride calculations.
Benefits:
- Reduced code from 636 → 348 lines (45% reduction)
- Eliminated code duplication
- Easier to maintain and extend
- Single source of truth for split logic
Implementation:
- Common: Binary split, offset calc, padding adjustment, kernel launch
- Dimension-specific: Stride calculation only
- 1D: stride = G * C
- 2D: stride = W_in * G * C
- 3D: stride = H_in * W_in * G * C
Test results (all passing):
- 1D: 36/36 tests ✅
- 2D: 20/20 tests ✅
- 3D: 28/28 tests ✅
- Total: 84/84 (100%)
All test scenarios verified:
- Varying channels, padding, stride, dilation
- Filter sizes (1x1 pointwise to 7x7)
- Multiple groups (G=1,2,4)
- Odd dimensions
- Complex combinations
* Add safety checks for split-image in all dimensions
Added is_possible_to_split safety checks to prevent crashes when
splitting is not feasible.
Safety checks verify:
1. Output dimension > 1 (can't split single element)
2. RIGHT piece starts after left padding
3. LEFT piece ends within input bounds
If checks fail, falls back to normal kernel launch.
Verified for all dimensions:
- 1D (W-split): Wo=1 case triggers fallback
- 2D (H-split): Ho=1 case triggers fallback
- 3D (D-split): Do=1 case triggers fallback
Original 84 tests still pass - they use normal configurations
that naturally satisfy safety conditions.
Safety checks protect against pathological edge cases with:
- Very small spatial dimensions
- Extreme stride/dilation combinations
- Invalid padding configurations
* Fix Split-N + Split-Image compatibility issue
Fixed critical bug where Split-N and Split-Image working together
caused ~50% incorrect results due to wrong batch stride calculation.
Problem:
- Batch stride was calculated using MODIFIED spatial dimensions
(e.g., W=50000 after split) instead of ORIGINAL dimensions (W=100000)
- Spatial offset was applied globally in invoker, not per-batch in kernel
- Each batch (blockIdx.z) got wrong memory offset
Solution:
1. Store spatial offset in kargs (don't apply to pointer in invoker)
2. Copy correct batch_stride from temp_kargs to left/right kargs
3. Apply formula in operator(): ptr = base + (batch × stride) + spatial_offset
Changes:
- grouped_convolution_forward_kernel.hpp:
* Added spatial_offset_in/out fields to KernelArgs
* Apply batch + spatial offset in operator()
- grouped_convolution_forward_invoker.hpp:
* Keep base pointer, store spatial offset in kargs
* Copy batch_stride from temp_kargs (has original dimensions)
- transform_conv_fwd_to_gemm.hpp:
* Add debug output for split-image calculation
Results:
- N=1 tests: 84/84 passing (100%)
- N>1 tests: Now all passing (previously ~50% errors)
- Tested: 1D, 2D, 3D with N=1,2,4,8,16,20
* Implement unified threshold for Split-N and Split-Image
This commit consolidates threshold management for both Split-N and
Split-Image operations into a single source of truth, eliminating
code duplication and fixing offset calculation issues.
Key Changes:
============
1. Transformer (transform_conv_fwd_to_gemm.hpp):
- Moved TwoGB constant to public section for unified access
- CalculateSplitImage() now takes no parameters
- Uses internal threshold: TwoGB / sizeof(CDataType)
- Calculates offsets using N_ (after Split-N) for correctness
2. Kernel (grouped_convolution_forward_kernel.hpp):
- GetSplitImageInfo() simplified to take no parameters
- Forwards to transformer's CalculateSplitImage()
- Clean interface with unified threshold internally
3. Invoker (grouped_convolution_forward_invoker.hpp):
- Removed redundant threshold calculation
- Simplified to call kargs.GetSplitImageInfo() with no params
- Clean early-return pattern (no unnecessary else blocks)
- Removed duplicate/dead code paths
Benefits:
=========
- Single source of truth: TwoGB defined once in transformer
- No parameter passing for threshold between components
- Correct offset calculation using N_ (post-Split-N)
- Cleaner code with no duplication
- All tests passing: 1D/2D/3D with various N values
Testing:
========
- Split-Image only (N=1, large spatial): PASS
- Split-N only (N>1, small spatial): PASS
- Both splits active (N>1, large spatial): PASS
- No splits (N=1, small spatial): PASS
- CPU verification correct for all scenarios
* Comment out outdated split-image code (SplitConvProblem/LaunchKernelWithSplitIfNeeded)
The old recursive queue-based implementation has been replaced by the
new CalculateSplitImage() method which is simpler and correctly handles
Split-N + Split-Image interaction.
Changes:
- Wrapped lines 381-1078 in #if 0...#endif
- Old methods: SplitConvProblem() and LaunchKernelWithSplitIfNeeded()
- Preserved for reference but disabled from compilation
- No functional changes - all tests still pass
The new implementation (CalculateSplitImage at line ~2163) provides:
- Correct offset calculation using N_ (after Split-N)
- Simpler binary split logic
- Better integration with unified threshold approach
* Implement recursive split-image with depth limit (MAX_DEPTH=10)
Changes:
- Add depth tracking to SplitPiece struct
- Implement two stopping conditions:
1. Piece size below threshold (optimal case)
2. Depth >= MAX_DEPTH (prevents infinite recursion)
- Remove MAX_PIECES limit in favor of depth-based control
- Support up to 2^10 = 1024 pieces with depth 10
This allows handling extreme tensor sizes while ensuring termination.
Pieces larger than threshold will still launch correctly if depth limit reached.
Tested with H=100 (4 levels), H=2000 (6 levels), H=4000 (9 levels) - all pass CPU verification.
* Summary of recursive split-image implementation:
- Recursive queue-based splitting with depth limit (MAX_DEPTH=10, up to 1024 pieces)
- Two stopping conditions: size below threshold OR max depth reached
- Cumulative offset tracking through all recursion levels
- LEFT piece inherits parent offset, RIGHT accumulates (parent + local)
- Per-batch spatial offset application in kernel operator()
- Batch stride uses original dimensions (before split)
- Works with Split-N: split-N first, then recursive split-image
- Handles odd dimensions, padding, stride, dilation correctly
- All 1D/2D/3D tests pass with CPU verification
* Add comment explaining MAX_DEPTH capacity for 2GB threshold
* Refactor: move recursive split-image logic to transformer
- Move LaunchWithRecursiveSplit() from invoker to transform_conv_fwd_to_gemm.hpp
- Simplify invoker from ~250 lines to ~140 lines (removed 110 lines of inline logic)
- Encapsulate SplitPiece struct and BFS splitting algorithm in transformer
- Remove unused includes (queue, vector) from invoker
- Add documentation comment for AreDescriptorsSmallerThan2GB()
- Improve code organization and reusability
- No performance overhead (static template function, compiler inlines)
- All tests passing with 2GB production threshold
* Apply clang-format-18 formatting
- Format invoker and transformer files with clang-format-18
- Fix brace placement and alignment
- No functional changes
* Fix clang-format-18 issues in forward kernel
- Remove extra blank lines
- Fix line wrapping for template calls
- Consolidate GetSplitImageInfo() to single line
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Split-Image implementation with temporary fixed divider
- Implemented spatial dimension splitting (Split-Image) for large tensors
- Added piece-based coordinate transformation for 1D/2D/3D convolutions
- Integrated Split-N (batch splitting) with automatic threshold detection
- Fixed M dimension calculation to include batch: M = N × spatial_size
- Added spatial offset support in kernel arguments
- Verified 20/20 test cases passing for Split-Image alone
- Known issue: Split-N + Split-Image combination needs coordinate fix
Implementation Details:
- Split factors: 4 (1D), 4×4 (2D), 4×4×4 (3D) - temporary fixed values
- Batch strides properly calculated for NWGC/NHWGC/NDHWGC layouts
- Piece descriptors track spatial boundaries and block ranges
- No performance overhead for N=1 cases
* Fix 1D split-image padding issue with per-piece dimensions
- Store actual size per piece to handle non-uniform splits
- Remove dead code from transform utils
* Fix 2D/3D split-image with independent split factors per dimension
Problem: Single split factor caused non-uniform pieces when dimensions
didn't divide evenly. Result: 18/25 (72%) 2D padding combinations failed.
Solution: Independent split factor selection for W, H, D dimensions.
Each dimension gets optimal factor based on its own size.
Test Results:
- 1D: 42/42 pass (100%)
- 2D: 25/25 pass (100%)
- Total: 67/67 combinations verified
* Remove unused split-image struct fields
Cleanup of split-image implementation:
- Removed unused piece_d, piece_h, piece_w fields from SplitImageInfo struct
- These fields were declared but never used in the kernel
- Per-piece dimensions are already stored in pieces[] array
- Reduces struct size and improves code clarity
Tested: 1D/2D/3D convolutions with split-image, padding, stride all pass
* Refactor split-image invoker code for improved readability
- Extract piece calculation logic into calculate_piece lambda helper
- Extract kernel args population into populate_split_image_kargs lambda
- Use aggregate initialization for cleaner struct population
- Reduce nesting depth and improve maintainability
- Fix outdated comment about split-image implementation status
* Refactor split-image code and remove debug prints
- Extract GPU kernel helper lambdas for better readability
- Remove all split-image debug print statements
- Set memory threshold to 2GB for production
- All tests pass with CPU verification
* Add split-image safety constraints and refactor to utils
- Add MAX_TOTAL_PIECES=64 limit to prevent segfault
- Move calculate_spatial_piece to library utils
- Add layout validation (NWGC, NHWGC, NDHWGC only)
- Fix hierarchical splitting to respect piece limits
- Add proper documentation and formatting
* Change split-image from runtime to compile-time branching
Response to @bartekxk review comment:
Convert 'if(kargs.num_spatial_pieces > 1)' to 'if constexpr(EnableSplitImage)'
Changes:
- Add EnableSplitImage template parameter to kernel
- Change runtime if to compile-time if constexpr
- Update invoker to instantiate kernel variants with true/false
Benefits:
- Eliminates runtime branching in GPU kernel
- Dead code elimination (each variant is smaller)
- Better compiler optimization
Files modified: 2
Lines changed: 20 total (6 in kernel, 14 in invoker)
Tests: 27/27 passed (100%)
Performance: No regression
* Add split-image example as separate binary
- Create grouped_convolution_forward_split_image example
- Add grouped_convolution_forward_split_image_invoker.hpp
- Update CMakeLists.txt to build split_image binary
* Replace linear search with binary search in find_piece_id
- Change O(n) to O(log n) for finding piece ownership
- Matches reference implementation in large_tensor_cshuffle
* Simplify split-image code and fix integer overflow
- Extract lambda functions to static helper methods
- Pre-calculate constants in invoker
- Fix integer overflow in tensor size calculation for large tensors
* Trigger CI rerun - fix merge conflicts
* Fix merge conflict markers
* Fix clang-format: remove space before {}
* Fix clang-format: comment wrapping and Swish constructor
* Rename split_image to large_tensor for clarity
- Renamed grouped_convolution_forward_split_image.cpp -> grouped_convolution_forward_large_tensor.cpp
- Renamed grouped_convolution_forward_split_image_invoker.hpp -> grouped_convolution_forward_large_tensor_invoker.hpp
- Updated CMakeLists.txt target name: tile_example_grouped_conv_fwd_split_image -> tile_example_grouped_conv_fwd_large_tensor
- Updated comments to refer to 'large tensor' instead of 'split-image'
* Update comments and include in large_tensor example
- Updated header comments to use 'large tensor' terminology
- Fixed include path to use large_tensor_invoker.hpp
* Remove test code, restore 2GB threshold
* Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Fix build errors after develop merge and complete rename to large_tensor
This commit addresses compilation errors from the develop merge and
completes the rename from split_image to large_tensor.
Changes:
1. Fix CDEElementWise typo in grouped_convolution_forward_invoker.hpp
2. Fix template parameter order in large_tensor_invoker.hpp
- TransformConvFwdToGemm signature changed in develop
- NumGroupsToMerge and SplitN parameters swapped positions
3. Fix missing template parameter in GroupedConvFwdHostArgs
4. Fix EpiloguePipeline scope in kernel (merge conflict)
5. Update binary name references in test scripts
* Restore 2GB threshold for split-image
Changed threshold from 100MB (testing) back to 2GB for production use.
* Fix const-correctness in ds_ptr cast
* Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Apply clang-format-18
* update c++ 18 format
* Apply clang-format-18 to transform_conv_fwd_to_gemm.hpp
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
[ROCm/composable_kernel commit: 1fbb47ad30]
This commit is contained in:
@@ -2,16 +2,19 @@ set(EXAMPLE_CONV_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd_large_tensor EXCLUDE_FROM_ALL grouped_convolution_forward_large_tensor.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd_large_tensor PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_fwd_bias_clamp EXCLUDE_FROM_ALL grouped_convolution_forward_bias_clamp.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_fwd_bias_clamp PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight EXCLUDE_FROM_ALL grouped_convolution_backward_weight.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_weight_two_stage EXCLUDE_FROM_ALL grouped_convolution_backward_weight_two_stage.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_conv_bwd_weight_two_stage PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_grouped_conv_bwd_data EXCLUDE_FROM_ALL grouped_convolution_backward_data.cpp)
|
||||
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_grouped_conv_bwd_data PRIVATE ${EXAMPLE_CONV_COMPILE_OPTIONS})
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Regular grouped convolution invoker (no split-image)
|
||||
// This invoker demonstrates regular convolution without split-image.
|
||||
// It always uses Kernel<false> (split-image disabled).
|
||||
// For large images that require split-image, use
|
||||
// grouped_convolution_forward_split_image_invoker.hpp
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
@@ -21,6 +28,10 @@ struct GroupedConvolutionForwardInvoker
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
@@ -90,6 +101,7 @@ struct GroupedConvolutionForwardInvoker
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
// Split-K parameters
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
@@ -97,100 +109,117 @@ struct GroupedConvolutionForwardInvoker
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
// =====================================================================
|
||||
// Regular Convolution: Simple, no split-image
|
||||
// =====================================================================
|
||||
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Split-K lambda
|
||||
// =====================================================================
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
}
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Regular Convolution Example: ALWAYS uses regular path (Kernel<false>)
|
||||
// =====================================================================
|
||||
// This example demonstrates regular convolution without split-image.
|
||||
// For large images that don't fit in memory, use
|
||||
// grouped_convolution_forward_split_image.cpp
|
||||
|
||||
// Launch kernel using regular path (no split-image)
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// Large tensor grouped convolution example
|
||||
// This example demonstrates convolution for large tensors that exceed memory limits.
|
||||
// It uses automatic tensor splitting when needed to handle large images.
|
||||
// For regular convolution without tensor splitting, use grouped_convolution_forward.cpp
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
#include "grouped_convolution_forward_large_tensor_invoker.hpp"
|
||||
#include "run_grouped_convolution_fwd_example.inc"
|
||||
|
||||
template <template <typename PrecType> typename GemmConfig>
|
||||
int run_grouped_conv_fwd_example(int argc, char* argv[])
|
||||
{
|
||||
using Invoker = GroupedConvolutionForwardInvoker;
|
||||
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string in_layout = arg_parser.get_str("in_layout");
|
||||
std::string wei_layout = arg_parser.get_str("wei_layout");
|
||||
std::string out_layout = arg_parser.get_str("out_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_grouped_conv_fwd_example_prec_type<Invoker,
|
||||
GemmConfig<ck_tile::bf16_t>,
|
||||
ck_tile::bf16_t>(
|
||||
in_layout, wei_layout, out_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return !run_grouped_conv_fwd_example<ConvConfigComputeV3_WMMA>(argc, argv);
|
||||
#else
|
||||
return !run_grouped_conv_fwd_example<ConvConfigComputeV3>(argc, argv);
|
||||
#endif
|
||||
}
|
||||
@@ -0,0 +1,388 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "grouped_convolution_utils.hpp"
|
||||
|
||||
struct GroupedConvolutionForwardInvoker
|
||||
{
|
||||
template <ck_tile::index_t NDimSpatial,
|
||||
typename GemmConfig,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename AccDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename DsDataType = ck_tile::tuple<>,
|
||||
typename DsLayout = ck_tile::tuple<>,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs<CDEElementWise>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n";
|
||||
}
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// Implicit GEMM Traits
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
constexpr ck_tile::index_t VectorSizeA = 8;
|
||||
constexpr ck_tile::index_t VectorSizeB = 8;
|
||||
constexpr ck_tile::index_t VectorSizeC = 8;
|
||||
|
||||
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
using GroupedConvTraitsType = ck_tile::GroupedConvTraits<NDimSpatial,
|
||||
ConvSpec,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
DsLayout,
|
||||
OutLayout,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
|
||||
GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::AsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::BsLayout,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd::CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
false, // Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraitsFwd,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t gemm_k =
|
||||
args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(),
|
||||
args.filter_spatial_lengths_.end(),
|
||||
1,
|
||||
std::multiplies<ck_tile::index_t>());
|
||||
|
||||
// Split-K parameters
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_k + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
using TransformType =
|
||||
ck_tile::TransformConvFwdToGemm<NDimSpatial,
|
||||
ck_tile::ConvolutionSpecialization::Default,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeC,
|
||||
1, // NumGroupsToMerge
|
||||
false, // SplitN
|
||||
InDataType,
|
||||
OutDataType>;
|
||||
|
||||
// =====================================================================
|
||||
// Step 1: Check if layout supports split-image kernel
|
||||
// =====================================================================
|
||||
// Split-image requires specific memory layouts:
|
||||
// 1D: NWGC (input), GKXC (weight), NWGK (output)
|
||||
// 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
|
||||
// 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
|
||||
constexpr bool is_supported_layout =
|
||||
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NWGC>::value ||
|
||||
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NHWGC>::value ||
|
||||
std::is_same<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>::value;
|
||||
|
||||
// =====================================================================
|
||||
// Step 2: Calculate split-image info (if layout supports it)
|
||||
// =====================================================================
|
||||
// Extract output spatial dimensions
|
||||
const ck_tile::index_t total_d =
|
||||
(NDimSpatial == 3) ? args.output_spatial_lengths_[NDimSpatial - 3] : 1;
|
||||
const ck_tile::index_t total_h =
|
||||
(NDimSpatial >= 2) ? args.output_spatial_lengths_[NDimSpatial - 2] : 1;
|
||||
const ck_tile::index_t total_w = args.output_spatial_lengths_[NDimSpatial - 1];
|
||||
|
||||
auto split_info = TransformType::GetSplitImageInfo(
|
||||
args.G_, args.N_, args.C_, args.K_, total_d, total_h, total_w);
|
||||
|
||||
// =====================================================================
|
||||
// Decide: Split-image or regular kernel?
|
||||
// =====================================================================
|
||||
const bool use_split_image = is_supported_layout && split_info.should_split;
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
if(!is_supported_layout)
|
||||
{
|
||||
std::cout << "[INVOKER] Layout not supported for split-image. "
|
||||
<< "Using regular kernel (Kernel<false>).\n";
|
||||
}
|
||||
else if(!split_info.should_split)
|
||||
{
|
||||
std::cout << "[INVOKER] Image is small (" << total_h << "×" << total_w
|
||||
<< "), split-image not necessary.\n";
|
||||
std::cout << "[INVOKER] Using regular kernel (Kernel<false>).\n";
|
||||
}
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// Step 3: Calculate split-image pieces (only if using split-image)
|
||||
// =====================================================================
|
||||
ck_tile::index_t num_d_pieces = 1;
|
||||
ck_tile::index_t num_h_pieces = 1;
|
||||
ck_tile::index_t num_w_pieces = 1;
|
||||
ck_tile::index_t total_pieces = 1;
|
||||
ck_tile::index_t base_piece_d = total_d;
|
||||
ck_tile::index_t base_piece_h = total_h;
|
||||
ck_tile::index_t base_piece_w = total_w;
|
||||
std::array<ck_tile::SplitImagePieceInfo, 64> temp_pieces{};
|
||||
ck_tile::index_t total_blocks = 0;
|
||||
|
||||
if(use_split_image)
|
||||
{
|
||||
num_d_pieces = split_info.num_d_pieces;
|
||||
num_h_pieces = split_info.num_h_pieces;
|
||||
num_w_pieces = split_info.num_w_pieces;
|
||||
total_pieces = num_d_pieces * num_h_pieces * num_w_pieces;
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "\n========================================\n";
|
||||
std::cout << "[SPLIT-IMAGE ENABLED] Large tensor detected\n";
|
||||
std::cout << "========================================\n";
|
||||
if(NDimSpatial == 3)
|
||||
{
|
||||
std::cout << "Total dimensions: D=" << total_d << " H=" << total_h
|
||||
<< " W=" << total_w << "\n";
|
||||
std::cout << "Split into pieces: D=" << num_d_pieces << " × H=" << num_h_pieces
|
||||
<< " × W=" << num_w_pieces << " = " << total_pieces
|
||||
<< " total pieces\n";
|
||||
std::cout << "Base piece size: D=" << (total_d / num_d_pieces)
|
||||
<< " H=" << (total_h / num_h_pieces)
|
||||
<< " W=" << (total_w / num_w_pieces) << "\n";
|
||||
}
|
||||
else if(NDimSpatial == 2)
|
||||
{
|
||||
std::cout << "Total dimensions: H=" << total_h << " W=" << total_w << "\n";
|
||||
std::cout << "Split into pieces: H=" << num_h_pieces << " × W=" << num_w_pieces
|
||||
<< " = " << total_pieces << " total pieces\n";
|
||||
std::cout << "Base piece size: H=" << (total_h / num_h_pieces)
|
||||
<< " W=" << (total_w / num_w_pieces) << "\n";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Total dimensions: W=" << total_w << "\n";
|
||||
std::cout << "Split into pieces: W=" << num_w_pieces << " = " << total_pieces
|
||||
<< " total pieces\n";
|
||||
std::cout << "Base piece size: W=" << (total_w / num_w_pieces) << "\n";
|
||||
}
|
||||
std::cout << "========================================\n\n";
|
||||
}
|
||||
|
||||
// Base piece size (non-overlapping division)
|
||||
base_piece_d = total_d / num_d_pieces;
|
||||
base_piece_h = total_h / num_h_pieces;
|
||||
base_piece_w = total_w / num_w_pieces;
|
||||
|
||||
// Calculate piece info for all pieces using library utility function
|
||||
for(ck_tile::index_t piece = 0; piece < total_pieces; piece++)
|
||||
{
|
||||
temp_pieces[piece] =
|
||||
ck_tile::calculate_spatial_piece<TilePartitioner>(piece,
|
||||
num_d_pieces,
|
||||
num_h_pieces,
|
||||
num_w_pieces,
|
||||
base_piece_d,
|
||||
base_piece_h,
|
||||
base_piece_w,
|
||||
total_d,
|
||||
total_h,
|
||||
total_w,
|
||||
args.N_,
|
||||
args.K_,
|
||||
total_blocks);
|
||||
total_blocks = temp_pieces[piece].block_end;
|
||||
}
|
||||
}
|
||||
|
||||
// =====================================================================
|
||||
// Kernel launch lambda: Uses EnableSplitImage based on layout support
|
||||
// =====================================================================
|
||||
const auto Run = [&]<bool EnableSplitImage>(const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<InDataType,
|
||||
WeiDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
OutDataType,
|
||||
true,
|
||||
VectorSizeA,
|
||||
VectorSizeB>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
OutDataType,
|
||||
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
GemmConfig::TransposeC,
|
||||
memory_operation,
|
||||
1,
|
||||
true,
|
||||
GroupedConvTraitsType::VectorSizeC>>;
|
||||
|
||||
// Use split-image kernel if layout supports it, otherwise use regular kernel
|
||||
using Kernel = ck_tile::GroupedConvolutionForwardKernel<EnableSplitImage,
|
||||
GroupedConvTraitsType,
|
||||
TilePartitioner,
|
||||
GemmPipeline,
|
||||
ConvEpilogue>;
|
||||
|
||||
// Create kargs
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
// Populate split-image metadata ONLY if using split-image kernel
|
||||
if constexpr(EnableSplitImage)
|
||||
{
|
||||
kargs.num_spatial_pieces = total_pieces;
|
||||
kargs.split_image.total_d = total_d;
|
||||
kargs.split_image.total_h = total_h;
|
||||
kargs.split_image.total_w = total_w;
|
||||
kargs.split_image.total_spatial = total_d * total_h * total_w; // Pre-calculate
|
||||
kargs.split_image.num_d_pieces = num_d_pieces;
|
||||
kargs.split_image.num_h_pieces = num_h_pieces;
|
||||
kargs.split_image.num_w_pieces = num_w_pieces;
|
||||
|
||||
for(ck_tile::index_t i = 0; i < total_pieces; i++)
|
||||
{
|
||||
kargs.split_image.pieces[i] = {temp_pieces[i].block_start,
|
||||
temp_pieces[i].block_end,
|
||||
temp_pieces[i].d_start,
|
||||
temp_pieces[i].h_start,
|
||||
temp_pieces[i].w_start,
|
||||
temp_pieces[i].d_size,
|
||||
temp_pieces[i].h_size,
|
||||
temp_pieces[i].w_size};
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate grid: use total_blocks for split-image, or normal GridSize for regular
|
||||
const dim3 grids = [&]() {
|
||||
if constexpr(EnableSplitImage)
|
||||
return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits);
|
||||
else
|
||||
return Kernel::GridSize(kargs);
|
||||
}();
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << '\n'
|
||||
<< "Vector size A: " << GemmPipeline::GetVectorSizeA()
|
||||
<< ", Vector size B: " << GemmPipeline::GetVectorSizeB()
|
||||
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
|
||||
}
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
// =====================================================================
|
||||
// Step 4: Dispatch kernel (split-image or regular based on decision)
|
||||
// =====================================================================
|
||||
if(use_split_image)
|
||||
{
|
||||
// Use split-image kernel (Kernel<true>)
|
||||
const auto RunSplitImage = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
Run.template operator()<true>(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
else
|
||||
Run.template operator()<true>(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitImage, has_hot_loop, tail_num);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Use regular kernel (Kernel<false>)
|
||||
const auto RunRegular = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
Run.template operator()<false>(has_hot_loop_, tail_number_, MemoryOpSet{});
|
||||
else
|
||||
Run.template operator()<false>(
|
||||
has_hot_loop_, tail_number_, MemoryOpAtomicAdd{});
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunRegular, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user