mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK TILE] Grouped conv fwd split image (#2970)
* Refactor split-image implementation: simplify code and remove redundant variables * Add padding debug output to split-image implementation - Added debug prints for padding calculations in transform_conv_fwd_to_gemm.hpp - Verified padding works correctly with all tests passing * Fix sign comparison warning after rebase with origin/develop - Cast blockIdX from unsigned to signed index_t for comparisons - Integrated with new GetOutputTileIndex logic from upstream - Updated to use amd_wave_read_first_lane instead of __builtin_amdgcn_readfirstlane * Fix Split-N with groups bug and clean up unused parameters - Fixed batch stride calculation to include G dimension for grouped convolutions - When moving between batches in NHWGC/NWGC/NDHWGC layouts, need to account for all groups - Removed unused multi-split parameters (we only support 2-way split) - All tests now pass: G=1 with Split-N, G>1 with Split-N, G>1 without Split-N * Implement recursive queue-based split-image detection and calculation - Add LaunchKernelWithSplitIfNeeded() helper method in transform_conv_fwd_to_gemm.hpp - Implement recursive binary splitting algorithm (10GB→5GB+5GB→...) - Correctly handle odd dimensions (61→30+31) - Calculate proper offsets for each split piece - Update invoker to use split-image helper Note: Split detection and calculation work correctly but kernel launching for individual pieces requires kernel modification to handle different spatial dimensions (unlike Split-N which uses blockIdx.z). * WIP: Split-Image investigation - found architecture mismatch - Split-N modifies N_ directly in transformer constructor - Split-Image needs different approach due to varying dimensions - Added split calculation logic for 1D and 2D convolutions - Still facing memory issues when creating piece transformers Key finding: Split-N uses blockIdx.z for parallel execution, while Split-Image needs sequential execution of non-uniform pieces. * Add 1D split-image implementation for grouped convolution (N=1 working) Implements split-image for 1D convolution to handle large tensors that exceed memory thresholds. This is a critical milestone with N=1 fully working and tested. Key Changes: - Invoker: Add split-image logic that splits W dimension in half - Transformer: Add SplitConvProblem helper for recursive splitting - Calculate offsets for LEFT and RIGHT pieces - Launch two kernels sequentially (LEFT then RIGHT) Implementation Details: - Binary split: divides W dimension by 2 - LEFT piece: W=0 to W/2, keeps left padding, removes right padding - RIGHT piece: W/2 to W, removes left padding, keeps right padding - Offset calculation accounts for stride, dilation, and padding - Physical memory offset (no padding in memory) Test Results (N=1): ✅ 94/94 tests passing - Comprehensive tests: 36/36 (channels, padding, stride, dilation, filters, groups) - Edge case tests: 31/31 (odd dimensions, extreme parameters, boundaries) - Stress tests: 27/27 (maximum dimensions, up to 91.4 TFlops) Known Limitations: - Only works with N=1 (single batch) - N>1 fails when split-image triggers (offset calculation issue with Split-N) - Root cause: Split-N modifies N in transformer, but offset calculated in invoker - Solution planned: Move offset calculation to transformer (next phase) Files Modified: - grouped_convolution_forward_invoker.hpp: Add split-image logic - transform_conv_fwd_to_gemm.hpp: Add SplitConvProblem helper This commit represents a stable, tested 1D split-image implementation for N=1 cases. It's an important milestone before extending to N>1 and multi-dimensional splits. * Add basic split-image implementation for 1D/2D/3D grouped convolution This is a working baseline implementation that splits large spatial dimensions to handle memory constraints. Implementation: - 1D: W-split for NWGC layout (36/36 tests passing) - 2D: H-split for NHWGC layout (20/20 tests passing) - 3D: D-split for NDHWGC layout (verified working) Features: - Binary split of outermost spatial dimension - Sequential LEFT/RIGHT kernel launches - Proper padding adjustment at split boundaries - Offset calculation for pointer arithmetic - Debug output for verification Threshold: 100KB (configurable in transformer) Known limitations: - No safety checks for edge cases (to be added) - Offset calculated before Split-N (incompatible with N>1, to be fixed) - No recursive splitting for very large tensors Next steps: - Add safety checks (is_possible_to_split_*) - Move offset calculation to transformer (after Split-N) - Test with N>1 + split-image combination * Refactor split-image to unified structure for 1D/2D/3D Unified the three separate dimension-specific blocks into a single common implementation with dimension-specific stride calculations. Benefits: - Reduced code from 636 → 348 lines (45% reduction) - Eliminated code duplication - Easier to maintain and extend - Single source of truth for split logic Implementation: - Common: Binary split, offset calc, padding adjustment, kernel launch - Dimension-specific: Stride calculation only - 1D: stride = G * C - 2D: stride = W_in * G * C - 3D: stride = H_in * W_in * G * C Test results (all passing): - 1D: 36/36 tests ✅ - 2D: 20/20 tests ✅ - 3D: 28/28 tests ✅ - Total: 84/84 (100%) All test scenarios verified: - Varying channels, padding, stride, dilation - Filter sizes (1x1 pointwise to 7x7) - Multiple groups (G=1,2,4) - Odd dimensions - Complex combinations * Add safety checks for split-image in all dimensions Added is_possible_to_split safety checks to prevent crashes when splitting is not feasible. Safety checks verify: 1. Output dimension > 1 (can't split single element) 2. RIGHT piece starts after left padding 3. LEFT piece ends within input bounds If checks fail, falls back to normal kernel launch. Verified for all dimensions: - 1D (W-split): Wo=1 case triggers fallback - 2D (H-split): Ho=1 case triggers fallback - 3D (D-split): Do=1 case triggers fallback Original 84 tests still pass - they use normal configurations that naturally satisfy safety conditions. Safety checks protect against pathological edge cases with: - Very small spatial dimensions - Extreme stride/dilation combinations - Invalid padding configurations * Fix Split-N + Split-Image compatibility issue Fixed critical bug where Split-N and Split-Image working together caused ~50% incorrect results due to wrong batch stride calculation. Problem: - Batch stride was calculated using MODIFIED spatial dimensions (e.g., W=50000 after split) instead of ORIGINAL dimensions (W=100000) - Spatial offset was applied globally in invoker, not per-batch in kernel - Each batch (blockIdx.z) got wrong memory offset Solution: 1. Store spatial offset in kargs (don't apply to pointer in invoker) 2. Copy correct batch_stride from temp_kargs to left/right kargs 3. Apply formula in operator(): ptr = base + (batch × stride) + spatial_offset Changes: - grouped_convolution_forward_kernel.hpp: * Added spatial_offset_in/out fields to KernelArgs * Apply batch + spatial offset in operator() - grouped_convolution_forward_invoker.hpp: * Keep base pointer, store spatial offset in kargs * Copy batch_stride from temp_kargs (has original dimensions) - transform_conv_fwd_to_gemm.hpp: * Add debug output for split-image calculation Results: - N=1 tests: 84/84 passing (100%) - N>1 tests: Now all passing (previously ~50% errors) - Tested: 1D, 2D, 3D with N=1,2,4,8,16,20 * Implement unified threshold for Split-N and Split-Image This commit consolidates threshold management for both Split-N and Split-Image operations into a single source of truth, eliminating code duplication and fixing offset calculation issues. Key Changes: ============ 1. Transformer (transform_conv_fwd_to_gemm.hpp): - Moved TwoGB constant to public section for unified access - CalculateSplitImage() now takes no parameters - Uses internal threshold: TwoGB / sizeof(CDataType) - Calculates offsets using N_ (after Split-N) for correctness 2. Kernel (grouped_convolution_forward_kernel.hpp): - GetSplitImageInfo() simplified to take no parameters - Forwards to transformer's CalculateSplitImage() - Clean interface with unified threshold internally 3. Invoker (grouped_convolution_forward_invoker.hpp): - Removed redundant threshold calculation - Simplified to call kargs.GetSplitImageInfo() with no params - Clean early-return pattern (no unnecessary else blocks) - Removed duplicate/dead code paths Benefits: ========= - Single source of truth: TwoGB defined once in transformer - No parameter passing for threshold between components - Correct offset calculation using N_ (post-Split-N) - Cleaner code with no duplication - All tests passing: 1D/2D/3D with various N values Testing: ======== - Split-Image only (N=1, large spatial): PASS - Split-N only (N>1, small spatial): PASS - Both splits active (N>1, large spatial): PASS - No splits (N=1, small spatial): PASS - CPU verification correct for all scenarios * Comment out outdated split-image code (SplitConvProblem/LaunchKernelWithSplitIfNeeded) The old recursive queue-based implementation has been replaced by the new CalculateSplitImage() method which is simpler and correctly handles Split-N + Split-Image interaction. Changes: - Wrapped lines 381-1078 in #if 0...#endif - Old methods: SplitConvProblem() and LaunchKernelWithSplitIfNeeded() - Preserved for reference but disabled from compilation - No functional changes - all tests still pass The new implementation (CalculateSplitImage at line ~2163) provides: - Correct offset calculation using N_ (after Split-N) - Simpler binary split logic - Better integration with unified threshold approach * Implement recursive split-image with depth limit (MAX_DEPTH=10) Changes: - Add depth tracking to SplitPiece struct - Implement two stopping conditions: 1. Piece size below threshold (optimal case) 2. Depth >= MAX_DEPTH (prevents infinite recursion) - Remove MAX_PIECES limit in favor of depth-based control - Support up to 2^10 = 1024 pieces with depth 10 This allows handling extreme tensor sizes while ensuring termination. Pieces larger than threshold will still launch correctly if depth limit reached. Tested with H=100 (4 levels), H=2000 (6 levels), H=4000 (9 levels) - all pass CPU verification. * Summary of recursive split-image implementation: - Recursive queue-based splitting with depth limit (MAX_DEPTH=10, up to 1024 pieces) - Two stopping conditions: size below threshold OR max depth reached - Cumulative offset tracking through all recursion levels - LEFT piece inherits parent offset, RIGHT accumulates (parent + local) - Per-batch spatial offset application in kernel operator() - Batch stride uses original dimensions (before split) - Works with Split-N: split-N first, then recursive split-image - Handles odd dimensions, padding, stride, dilation correctly - All 1D/2D/3D tests pass with CPU verification * Add comment explaining MAX_DEPTH capacity for 2GB threshold * Refactor: move recursive split-image logic to transformer - Move LaunchWithRecursiveSplit() from invoker to transform_conv_fwd_to_gemm.hpp - Simplify invoker from ~250 lines to ~140 lines (removed 110 lines of inline logic) - Encapsulate SplitPiece struct and BFS splitting algorithm in transformer - Remove unused includes (queue, vector) from invoker - Add documentation comment for AreDescriptorsSmallerThan2GB() - Improve code organization and reusability - No performance overhead (static template function, compiler inlines) - All tests passing with 2GB production threshold * Apply clang-format-18 formatting - Format invoker and transformer files with clang-format-18 - Fix brace placement and alignment - No functional changes * Fix clang-format-18 issues in forward kernel - Remove extra blank lines - Fix line wrapping for template calls - Consolidate GetSplitImageInfo() to single line * Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Split-Image implementation with temporary fixed divider - Implemented spatial dimension splitting (Split-Image) for large tensors - Added piece-based coordinate transformation for 1D/2D/3D convolutions - Integrated Split-N (batch splitting) with automatic threshold detection - Fixed M dimension calculation to include batch: M = N × spatial_size - Added spatial offset support in kernel arguments - Verified 20/20 test cases passing for Split-Image alone - Known issue: Split-N + Split-Image combination needs coordinate fix Implementation Details: - Split factors: 4 (1D), 4×4 (2D), 4×4×4 (3D) - temporary fixed values - Batch strides properly calculated for NWGC/NHWGC/NDHWGC layouts - Piece descriptors track spatial boundaries and block ranges - No performance overhead for N=1 cases * Fix 1D split-image padding issue with per-piece dimensions - Store actual size per piece to handle non-uniform splits - Remove dead code from transform utils * Fix 2D/3D split-image with independent split factors per dimension Problem: Single split factor caused non-uniform pieces when dimensions didn't divide evenly. Result: 18/25 (72%) 2D padding combinations failed. Solution: Independent split factor selection for W, H, D dimensions. Each dimension gets optimal factor based on its own size. Test Results: - 1D: 42/42 pass (100%) - 2D: 25/25 pass (100%) - Total: 67/67 combinations verified * Remove unused split-image struct fields Cleanup of split-image implementation: - Removed unused piece_d, piece_h, piece_w fields from SplitImageInfo struct - These fields were declared but never used in the kernel - Per-piece dimensions are already stored in pieces[] array - Reduces struct size and improves code clarity Tested: 1D/2D/3D convolutions with split-image, padding, stride all pass * Refactor split-image invoker code for improved readability - Extract piece calculation logic into calculate_piece lambda helper - Extract kernel args population into populate_split_image_kargs lambda - Use aggregate initialization for cleaner struct population - Reduce nesting depth and improve maintainability - Fix outdated comment about split-image implementation status * Refactor split-image code and remove debug prints - Extract GPU kernel helper lambdas for better readability - Remove all split-image debug print statements - Set memory threshold to 2GB for production - All tests pass with CPU verification * Add split-image safety constraints and refactor to utils - Add MAX_TOTAL_PIECES=64 limit to prevent segfault - Move calculate_spatial_piece to library utils - Add layout validation (NWGC, NHWGC, NDHWGC only) - Fix hierarchical splitting to respect piece limits - Add proper documentation and formatting * Change split-image from runtime to compile-time branching Response to @bartekxk review comment: Convert 'if(kargs.num_spatial_pieces > 1)' to 'if constexpr(EnableSplitImage)' Changes: - Add EnableSplitImage template parameter to kernel - Change runtime if to compile-time if constexpr - Update invoker to instantiate kernel variants with true/false Benefits: - Eliminates runtime branching in GPU kernel - Dead code elimination (each variant is smaller) - Better compiler optimization Files modified: 2 Lines changed: 20 total (6 in kernel, 14 in invoker) Tests: 27/27 passed (100%) Performance: No regression * Add split-image example as separate binary - Create grouped_convolution_forward_split_image example - Add grouped_convolution_forward_split_image_invoker.hpp - Update CMakeLists.txt to build split_image binary * Replace linear search with binary search in find_piece_id - Change O(n) to O(log n) for finding piece ownership - Matches reference implementation in large_tensor_cshuffle * Simplify split-image code and fix integer overflow - Extract lambda functions to static helper methods - Pre-calculate constants in invoker - Fix integer overflow in tensor size calculation for large tensors * Trigger CI rerun - fix merge conflicts * Fix merge conflict markers * Fix clang-format: remove space before {} * Fix clang-format: comment wrapping and Swish constructor * Rename split_image to large_tensor for clarity - Renamed grouped_convolution_forward_split_image.cpp -> grouped_convolution_forward_large_tensor.cpp - Renamed grouped_convolution_forward_split_image_invoker.hpp -> grouped_convolution_forward_large_tensor_invoker.hpp - Updated CMakeLists.txt target name: tile_example_grouped_conv_fwd_split_image -> tile_example_grouped_conv_fwd_large_tensor - Updated comments to refer to 'large tensor' instead of 'split-image' * Update comments and include in large_tensor example - Updated header comments to use 'large tensor' terminology - Fixed include path to use large_tensor_invoker.hpp * Remove test code, restore 2GB threshold * Update include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Fix build errors after develop merge and complete rename to large_tensor This commit addresses compilation errors from the develop merge and completes the rename from split_image to large_tensor. Changes: 1. Fix CDEElementWise typo in grouped_convolution_forward_invoker.hpp 2. Fix template parameter order in large_tensor_invoker.hpp - TransformConvFwdToGemm signature changed in develop - NumGroupsToMerge and SplitN parameters swapped positions 3. Fix missing template parameter in GroupedConvFwdHostArgs 4. Fix EpiloguePipeline scope in kernel (merge conflict) 5. Update binary name references in test scripts * Restore 2GB threshold for split-image Changed threshold from 100MB (testing) back to 2GB for production use. * Fix const-correctness in ds_ptr cast * Update include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Apply clang-format-18 * update c++ 18 format * Apply clang-format-18 to transform_conv_fwd_to_gemm.hpp --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -78,23 +78,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -106,13 +104,16 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 1D convolution (NWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0];
|
||||
// Calculate batch strides using the original argument dimensions.
|
||||
// These are the original dimensions passed to the constructor, not modified by the invoker
|
||||
// yet. (The invoker modifies args after calling MakeKernelArgs.) VERIFIED: G_ MUST be
|
||||
// included - NWGC layout has all groups within each batch
|
||||
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0];
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0];
|
||||
@@ -169,23 +170,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -197,15 +196,16 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 2D convolution (NHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NHWGC layout
|
||||
// VERIFIED: G_ MUST be included - NHWGC layout has all groups within each batch
|
||||
input_batch_stride =
|
||||
args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
args.G_ * args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
|
||||
output_batch_stride =
|
||||
args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
args.G_ * args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
|
||||
@@ -270,23 +270,21 @@ struct GroupedConvFwdKernelArgs
|
||||
}
|
||||
out_ptr = args.out_ptr;
|
||||
|
||||
ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
// Create and STORE transformer (for split-image support)
|
||||
transformer_ = ConvToGemmFwdTransformer{in_g_n_c_wis_lengths,
|
||||
wei_g_k_c_xs_lengths,
|
||||
out_g_n_k_wos_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
|
||||
a_grid_desc_m_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
transformer_.template MakeADescriptor_M_K<typename GroupedConvTraitsType_::InLayout>();
|
||||
b_grid_desc_n_k =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
transformer_.template MakeBDescriptor_N_K<typename GroupedConvTraitsType_::WeiLayout>();
|
||||
c_grid_desc_m_n =
|
||||
conv_to_gemm_transformer
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
transformer_.template MakeCDescriptor_M_N<typename GroupedConvTraitsType_::OutLayout>();
|
||||
|
||||
group_stride_a = args.C_;
|
||||
group_stride_b = args.K_ * args.C_ *
|
||||
@@ -298,14 +296,15 @@ struct GroupedConvFwdKernelArgs
|
||||
|
||||
// Initialize Split-N support fields for 3D convolution (NDHWGC layout)
|
||||
// Get the actual split N from transformer
|
||||
n_per_split = conv_to_gemm_transformer.GetN();
|
||||
original_n = conv_to_gemm_transformer.GetOriginalN();
|
||||
n_per_split = transformer_.GetN();
|
||||
original_n = transformer_.GetOriginalN();
|
||||
n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split);
|
||||
|
||||
// Calculate batch strides for NDHWGC layout
|
||||
input_batch_stride = args.C_ * args.input_spatial_lengths_[0] *
|
||||
// VERIFIED: G_ MUST be included - NDHWGC layout has all groups within each batch
|
||||
input_batch_stride = args.G_ * args.C_ * args.input_spatial_lengths_[0] *
|
||||
args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
|
||||
output_batch_stride = args.K_ * args.output_spatial_lengths_[0] *
|
||||
output_batch_stride = args.G_ * args.K_ * args.output_spatial_lengths_[0] *
|
||||
args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];
|
||||
|
||||
// Update GemmM to use split N (not original N)
|
||||
@@ -359,6 +358,42 @@ struct GroupedConvFwdKernelArgs
|
||||
index_t original_n = 1; // Original batch size before splitting
|
||||
index_t input_batch_stride = 0; // Stride to next batch in input tensor
|
||||
index_t output_batch_stride = 0; // Stride to next batch in output tensor
|
||||
|
||||
// Split-image support - spatial offsets (applied per-batch in operator())
|
||||
long_index_t spatial_offset_in = 0; // Spatial offset for input (e.g., W/2 for 1D split)
|
||||
long_index_t spatial_offset_out = 0; // Spatial offset for output (e.g., W/2 for 1D split)
|
||||
|
||||
// Split-image support - transformer instance
|
||||
ConvToGemmFwdTransformer transformer_;
|
||||
|
||||
// Forward declare descriptor types (will be defined after using declarations)
|
||||
using ConvToGemmFwdTransformer_t = ConvToGemmFwdTransformer;
|
||||
using AGridDescMK_t = AGridDescMK;
|
||||
using CGridDescMN_t = CGridDescMN;
|
||||
|
||||
// Split-image support: Common data for all pieces
|
||||
struct SplitImageInfo
|
||||
{
|
||||
// Common dimensions (same for all pieces)
|
||||
index_t total_d = 1, total_h = 1, total_w = 1; // Total tensor dimensions
|
||||
index_t total_spatial = 1; // Pre-calculated: total_d * total_h * total_w
|
||||
index_t num_d_pieces = 1, num_h_pieces = 1, num_w_pieces = 1; // Split factors
|
||||
|
||||
// Minimal per-piece data (only unique values)
|
||||
struct PieceInfo
|
||||
{
|
||||
index_t block_start; // Starting block index for this piece
|
||||
index_t block_end; // Ending block index (exclusive)
|
||||
index_t d_start, h_start, w_start; // Piece starting position in OUTPUT space
|
||||
index_t d_size, h_size, w_size; // Piece size in OUTPUT space
|
||||
};
|
||||
|
||||
static constexpr index_t MaxPieces = 64; // Max pieces: 4 (1D), 16 (2D), 64 (3D)
|
||||
std::array<PieceInfo, MaxPieces> pieces; // Array of minimal piece descriptors
|
||||
};
|
||||
|
||||
index_t num_spatial_pieces = 1; // Number of spatial pieces (1 = no split)
|
||||
SplitImageInfo split_image; // Nested structure with common + per-piece data
|
||||
};
|
||||
|
||||
/// @brief The Grouped Convolution Forward kernel template.
|
||||
@@ -399,13 +434,15 @@ struct GroupedConvFwdKernelArgs
|
||||
/// multiplication implementation. It is responsible for storing
|
||||
/// results calculated by @ref GemmPipeline_ "GemmPipeline" to
|
||||
/// the output C tensor in global memory.
|
||||
template <typename GroupedConvTraitsType_,
|
||||
template <bool EnableSplitImage_,
|
||||
typename GroupedConvTraitsType_,
|
||||
typename TilePartitioner_,
|
||||
typename GemmPipeline_,
|
||||
typename EpiloguePipeline_>
|
||||
struct GroupedConvolutionForwardKernel
|
||||
{
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr bool EnableSplitImage = EnableSplitImage_;
|
||||
static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
|
||||
static constexpr ConvolutionSpecialization ConvSpecialization =
|
||||
GroupedConvTraitsType_::ConvSpecialization;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -435,7 +472,6 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
using GroupedConvFwdKernelArgsSpecialized = GroupedConvFwdKernelArgs<GroupedConvTraitsType_>;
|
||||
|
||||
// TODO: Enable this
|
||||
static constexpr bool IsSplitKSupported = false;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
@@ -449,6 +485,77 @@ struct GroupedConvolutionForwardKernel
|
||||
static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
|
||||
static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
|
||||
|
||||
// Helper struct for spatial coordinates
|
||||
struct SpatialCoords
|
||||
{
|
||||
index_t d, h, w;
|
||||
};
|
||||
|
||||
// Helper: Convert flat spatial index to (d,h,w) coordinates
|
||||
CK_TILE_DEVICE static SpatialCoords
|
||||
UnflattenSpatial(index_t flat, index_t h_size, index_t w_size)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return SpatialCoords{0, 0, flat};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return SpatialCoords{0, flat / w_size, flat % w_size};
|
||||
}
|
||||
else // NDimSpatial == 3
|
||||
{
|
||||
const index_t hw = h_size * w_size;
|
||||
const index_t d = flat / hw;
|
||||
const index_t remainder = flat % hw;
|
||||
return SpatialCoords{d, remainder / w_size, remainder % w_size};
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: Convert (d,h,w) to flat spatial index
|
||||
CK_TILE_DEVICE static index_t
|
||||
FlattenSpatial(index_t d, index_t h, index_t w, index_t total_h, index_t total_w)
|
||||
{
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
return w;
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
return h * total_w + w;
|
||||
}
|
||||
else // NDimSpatial == 3
|
||||
{
|
||||
return (d * total_h + h) * total_w + w;
|
||||
}
|
||||
}
|
||||
|
||||
// Helper: Find which piece owns a block using binary search
|
||||
template <typename SplitImageInfo>
|
||||
CK_TILE_DEVICE static index_t
|
||||
FindPieceId(index_t block_id, const SplitImageInfo& split_info, index_t num_pieces)
|
||||
{
|
||||
index_t left = 0;
|
||||
index_t right = num_pieces - 1;
|
||||
index_t piece_id = (left + right) / 2;
|
||||
|
||||
while(!(block_id >= split_info.pieces[piece_id].block_start &&
|
||||
block_id < split_info.pieces[piece_id].block_end) &&
|
||||
left <= right)
|
||||
{
|
||||
if(block_id < split_info.pieces[piece_id].block_start)
|
||||
{
|
||||
right = piece_id - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
left = piece_id + 1;
|
||||
}
|
||||
piece_id = (left + right) / 2;
|
||||
}
|
||||
return piece_id;
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
@@ -475,7 +582,8 @@ struct GroupedConvolutionForwardKernel
|
||||
CK_TILE_HOST static constexpr GroupedConvFwdKernelArgsSpecialized
|
||||
MakeKernelArgs(const GroupedConvFwdHostArgs<CDElementwise>& hostArgs)
|
||||
{
|
||||
return GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
auto kargs = GroupedConvFwdKernelArgsSpecialized(hostArgs);
|
||||
return kargs;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
@@ -499,17 +607,6 @@ struct GroupedConvolutionForwardKernel
|
||||
}
|
||||
}
|
||||
|
||||
// Check Split-K and Split-N conflict (both use blockIdx.z)
|
||||
if(kargs.k_batch > 1 && kargs.n_splits > 1)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR(
|
||||
"Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
|
||||
const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
|
||||
|
||||
@@ -618,27 +715,32 @@ struct GroupedConvolutionForwardKernel
|
||||
return true;
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set,
|
||||
typename ADescType,
|
||||
typename BDescType,
|
||||
typename CDescType>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTensorViews(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs)
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc)
|
||||
{
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
return make_tensor_view<address_space_enum::global>(a_ptr, a_desc);
|
||||
}();
|
||||
|
||||
const auto& b_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
return make_tensor_view<address_space_enum::global>(b_ptr, b_desc);
|
||||
}();
|
||||
|
||||
// TODO: enable vector write for C in ColMajor
|
||||
const auto& c_tensor_view = [&]() {
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
|
||||
return make_tensor_view<address_space_enum::global>(c_ptr, c_desc);
|
||||
}();
|
||||
|
||||
const auto& ds_tensor_view = generate_tuple(
|
||||
@@ -651,7 +753,7 @@ struct GroupedConvolutionForwardKernel
|
||||
"Not supported!");
|
||||
|
||||
return make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
|
||||
static_cast<const OutDataType*>(ds_ptr[i]), c_desc);
|
||||
},
|
||||
number<NumDTensor>{});
|
||||
|
||||
@@ -743,31 +845,39 @@ struct GroupedConvolutionForwardKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input D tensors pointer array
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The start memory pointer of the shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param a_desc Input tensor A descriptor
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <typename ADescType, typename BDescType, typename CDescType>
|
||||
CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -780,9 +890,8 @@ struct GroupedConvolutionForwardKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{kargs.elfunc}
|
||||
.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -792,32 +901,40 @@ struct GroupedConvolutionForwardKernel
|
||||
*
|
||||
* @param a_ptr input A pointer
|
||||
* @param b_ptr input B pointer
|
||||
* @param ds_ptr input D tensors pointer array
|
||||
* @param c_ptr output C pointer
|
||||
* @param smem_ptr_0 The starting pointer of 1st shared memory block.
|
||||
* @param smem_ptr_1 The starting pointer of 2nd shared memory block.
|
||||
* @param kargs Grouped Convolution Forward kernel arguments
|
||||
* @param a_desc Input tensor A descriptor
|
||||
* @param b_desc Weight tensor B descriptor
|
||||
* @param c_desc Output tensor C descriptor
|
||||
* @param gemm_k The GEMM K dimension
|
||||
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
|
||||
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
|
||||
*
|
||||
*/
|
||||
template <typename ADescType, typename BDescType, typename CDescType>
|
||||
CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
|
||||
const WeiDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
OutDataType* c_ptr,
|
||||
void* __restrict__ smem_ptr_0,
|
||||
void* __restrict__ smem_ptr_1,
|
||||
const GroupedConvFwdKernelArgsSpecialized& kargs,
|
||||
const ADescType& a_desc,
|
||||
const BDescType& b_desc,
|
||||
const CDescType& c_desc,
|
||||
const index_t gemm_k,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
|
||||
a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(kargs.GemmK));
|
||||
const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k));
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
@@ -837,12 +954,8 @@ struct GroupedConvolutionForwardKernel
|
||||
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
|
||||
{
|
||||
const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
|
||||
const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
|
||||
const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
|
||||
const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
|
||||
const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
|
||||
const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
|
||||
@@ -860,14 +973,89 @@ struct GroupedConvolutionForwardKernel
|
||||
static_cast<long_index_t>(batch_offset) *
|
||||
static_cast<long_index_t>(kargs.output_batch_stride);
|
||||
|
||||
// Adjust pointers: combine group offset and batch offset
|
||||
const InDataType* a_ptr =
|
||||
// Calculate base pointers with group and batch offsets
|
||||
const InDataType* base_a_ptr =
|
||||
static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a + input_batch_offset;
|
||||
const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) +
|
||||
group_offset_b; // No batch offset for weights!
|
||||
OutDataType* c_ptr =
|
||||
OutDataType* base_c_ptr =
|
||||
static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c + output_batch_offset;
|
||||
|
||||
// =====================================================================
|
||||
// Split-image: Map local block to global tile index (if enabled)
|
||||
// =====================================================================
|
||||
const InDataType* a_ptr;
|
||||
OutDataType* c_ptr;
|
||||
index_t i_m = 0;
|
||||
index_t i_n = 0;
|
||||
|
||||
// Pre-calculate block_id (used in both split-image and non-split paths)
|
||||
const index_t block_id = static_cast<index_t>(blockIdX);
|
||||
|
||||
if constexpr(EnableSplitImage)
|
||||
{
|
||||
// Add spatial offsets for split-image (constexpr optimization)
|
||||
a_ptr = base_a_ptr + kargs.spatial_offset_in;
|
||||
c_ptr = base_c_ptr + kargs.spatial_offset_out;
|
||||
|
||||
// Find which piece owns this block using binary search
|
||||
// Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
|
||||
const index_t piece_id =
|
||||
FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces);
|
||||
const auto& piece = kargs.split_image.pieces[piece_id];
|
||||
const auto& split_info = kargs.split_image;
|
||||
|
||||
// Calculate local block ID and tile indices
|
||||
const index_t local_block_id = block_id - piece.block_start;
|
||||
const index_t local_gemm_m =
|
||||
kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size;
|
||||
const auto [local_tile_m, local_tile_n] =
|
||||
TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id);
|
||||
|
||||
// Extract batch and spatial coordinates from local tile
|
||||
const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock;
|
||||
const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size;
|
||||
const index_t local_n = local_m_start / spatial_per_batch;
|
||||
const index_t local_spatial_flat = local_m_start % spatial_per_batch;
|
||||
|
||||
// Convert to local spatial coordinates
|
||||
const auto local_coords =
|
||||
UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size);
|
||||
|
||||
// Convert to global spatial coordinates
|
||||
const index_t global_n = local_n;
|
||||
const index_t global_d = piece.d_start + local_coords.d;
|
||||
const index_t global_h = piece.h_start + local_coords.h;
|
||||
const index_t global_w = piece.w_start + local_coords.w;
|
||||
|
||||
// Convert to global M index
|
||||
const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated
|
||||
const index_t global_spatial_flat = FlattenSpatial(
|
||||
global_d, global_h, global_w, split_info.total_h, split_info.total_w);
|
||||
const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat;
|
||||
|
||||
// Set tile indices for GEMM operation
|
||||
i_m = amd_wave_read_first_lane(global_m);
|
||||
i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock);
|
||||
}
|
||||
else
|
||||
{
|
||||
// No spatial offsets needed for regular path
|
||||
a_ptr = base_a_ptr;
|
||||
c_ptr = base_c_ptr;
|
||||
|
||||
// No split-image: use standard tile partitioning
|
||||
const auto [iM, iN] =
|
||||
TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id);
|
||||
i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
|
||||
i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
|
||||
}
|
||||
|
||||
// Use global descriptors for all cases
|
||||
const auto& a_desc = kargs.a_grid_desc_m_k;
|
||||
const auto& b_desc = kargs.b_grid_desc_n_k;
|
||||
const auto& c_desc = kargs.c_grid_desc_m_n;
|
||||
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
@@ -878,8 +1066,18 @@ struct GroupedConvolutionForwardKernel
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm2LDS(
|
||||
a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
else
|
||||
@@ -888,7 +1086,17 @@ struct GroupedConvolutionForwardKernel
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
{
|
||||
RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
|
||||
RunGemm(a_ptr,
|
||||
b_ptr,
|
||||
kargs.ds_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
a_desc,
|
||||
b_desc,
|
||||
c_desc,
|
||||
kargs.GemmK,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user