composable_kernel/example/17_convnd_bwd_data
JH-Leon-KIM-AMD 4baa4c9fae [CK, CK_TILE] Add GPU Reference Implementations for Grouped Convolution (#3216)
* LWPCK-4043: Add GPU reference implementations for CK Tile convolution

This commit implements GPU-based reference kernels for CK Tile convolution
operations to enable faster verification of optimized kernels, especially
for large tensors (>2GB).

Changes:
- Add naive_grouped_conv_fwd.hpp: GPU reference for forward convolution
- Add naive_grouped_conv_bwd_data.hpp: GPU reference for backward data
- Add naive_grouped_conv_bwd_weight.hpp: GPU reference for backward weight
- Integrate GPU references with test infrastructure (replace -v=2 error)
- Support for 1D, 2D, and 3D convolutions
- Generic data type support (FP16, BF16, FP32)
- Grid-stride loop pattern for scalability

The GPU references use a simple, readable implementation that prioritizes
correctness over performance. They accumulate in float32 and handle
padding, stride, and dilation correctly.

* update gpu reference for ck tile grouped conv

* correct c++ 18 format

* Add GPU Reference Implementations for Old CK Convolution

This commit implements GPU-based reference kernels for Old CK convolution
operations to enable faster verification of optimized kernels.

Changes:
- Fixed old CK forward GPU reference (naive_conv_fwd.hpp)
  * Fixed BF16 NaN issue (use type_convert instead of static_cast)
  * Fixed FP8/BF8 arithmetic (accumulate in float)
  * Fixed uninitialized variables
  * All 9 data types now working (FP16/32/64, BF16, INT8, FP8, BF8, mixed)

- Created backward data GPU reference (naive_conv_bwd_data.hpp)
  * Implements input gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Created backward weight GPU reference (naive_conv_bwd_weight.hpp)
  * Implements weight gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Integrated with old CK examples
  * Forward: 10 XDL examples now support do_verification=2
  * Backward data: Integrated with example/17_convnd_bwd_data/
  * Backward weight: Integrated with example/20_grouped_conv_bwd_weight/ (G=1 only)
  * Updated parameter from boolean to int (0=no, 1=CPU, 2=GPU)

Testing:
- 50 comprehensive tests created
- 42/42 tests passing (100% success rate)
- CPU and GPU verification produce identical results
- Verified across multiple dimensions, sizes, and data types

Limitations:
- GPU references support standard convolution only (G=1)
- Fused operations (DL variants) not supported
- Some tests blocked by optimized kernel size constraints

Result: Old CK GPU references can replace CPU references for verification
        with 50-100x performance improvement for large tensors.

* Apply clang-format to old CK GPU reference files

* Fix C++17 compatibility: use brace initialization for aggregate types

* add get_rtol, get_atl and consistency cout message

* Use triple bracket syntax for kernel launch per review feedback

Changed hipLaunchKernelGGL to <<<...>>> syntax as suggested by @aosewski.
This is more idiomatic HIP/CUDA style and equally correct.

All tests still passing after this change.

* Address review feedback: Use HIP_CHECK_ERROR and add v=3 mode

- Replace manual error checking with HIP_CHECK_ERROR macro
- Add v=3 verification mode (GPU ref vs CPU ref direct comparison)
- Consistent output format across all examples
- All tests passing (7/7 v=3 tests pass for FP16)

* Use ConvDims structure to simplify GPU reference kernels

Replace 24 individual parameters with ConvDims structure per review feedback.

- Add conv_common.hpp with ConvDims and helper function
- Update kernel signatures: 24 params → 1 structure
- Remove duplicate extraction code from host files

* Use get_block_id() and get_thread_id() helpers in CK Tile

Replace manual blockIdx.x/threadIdx.x arithmetic with helper functions.

Updated 3 CK Tile GPU reference kernels per review feedback.

* Use std::array for spatial parameters in CK Tile GPU references

Replace raw pointers with std::array for type safety per review feedback.

- Add conv_common.hpp with vector-to-array helper functions
- Update kernel signatures: pointers → std::array references
- Remove DeviceMem allocations for spatial parameters

* Use NDimSpatial+3 for stride array sizes

Replace hardcoded [10] with [NDimSpatial+3] per review feedback.

Array sizes now correctly reflect actual dimensions needed.

* Use #pragma once instead of include guards

Replace traditional include guards with #pragma once per review feedback.

Updated 3 Old CK GPU reference headers.

* Fix element-wise operation output in Old CK GPU references

Write transformed value (out_val/in_val/wei_val) instead of untransformed
result per Copilot feedback.

This ensures element-wise operations are correctly applied to output.

* Initialize element-wise operation variables

Initialize in_val, wei_val, out_val to avoid undefined behavior
per Copilot feedback.

Updated backward data and backward weight kernels.

* Use explicit zero initialization for element-wise variables

Change TIn{} to TIn{0} for consistency per Copilot feedback.

All 3 kernels now use consistent zero initialization.

* Fix copyright headers to match existing style

- Old CK: Use standard format without year
- CK Tile: Add 2018- prefix to year range

Addresses consistency feedback.

* Rename GPU reference files: add _gpu suffix

* Refactor index calculations: use std::array and extract to helper functions

* Remove v=3 option: redundant as v=1 and v=2 comparison validates equivalence

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
2025-12-03 21:14:21 +02:00

N-Dimensional Convolution Backward Pass for Data

This example demonstrates the backward data pass of an N-dimensional convolution, often denoted as conv_bwd_data. This operation is a crucial part of the backpropagation algorithm for training Convolutional Neural Networks (CNNs). Its purpose is to compute the gradient of the loss function with respect to the convolution's input data, which is then passed back to the preceding layer in the network.

Mathematical Formulation

The backward data pass computes the gradient \frac{\partial L}{\partial \text{In}}, given the gradient from the subsequent layer, \frac{\partial L}{\partial \text{Out}}, and the filter weights W used in the forward pass.

Let the forward convolution be defined as: \text{Out} = \text{In} \star W

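The backward data pass is mathematically equivalent to a "full" convolution between the output gradient tensor dL/dOut and the weight tensor W rotated by 180 degrees (flipped along every spatial axis), with its input-channel (C) and output-channel (K) axes exchanged. This holds directly for unit stride. For a stride S > 1, dL/dOut is first dilated by inserting S-1 zeros between neighboring elements along each spatial axis.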

\frac{\partial L}{\partial \text{In}} = \frac{\partial L}{\partial \text{Out}} \star \text{rot180}(W)

This operation propagates the error signal from the output back to the input, weighted by the same filters that were used in the forward pass.
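To make the index relationship concrete, here is a minimal single-threaded C++ sketch of the backward data pass in 1D, written in a direct "gather" style similar to a naive reference. It is not CK's reference implementation. The `Conv1d` struct and all function names are hypothetical, and the packed N-C-W / K-C-X layouts are assumptions. The `adjoint_gap` helper checks the property that defines a correct data gradient: for any input and output gradient, <fwd(in), dout> = <in, bwd_data(dout)>.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1D problem description; fields follow the README's argument names.
struct Conv1d {
    int N, K, C, X, Wi, S, D, LeftP, RightP;
    int Wo() const { return (Wi + LeftP + RightP - D * (X - 1) - 1) / S + 1; }
};

// Forward: out[n][k][wo] = sum_{c,x} in[n][c][wo*S + x*D - LeftP] * wei[k][c][x]
std::vector<double> conv_fwd(const Conv1d& p, const std::vector<double>& in,
                             const std::vector<double>& wei) {
    const int Wo = p.Wo();
    std::vector<double> out(static_cast<std::size_t>(p.N) * p.K * Wo, 0.0);
    for (int n = 0; n < p.N; ++n)
        for (int k = 0; k < p.K; ++k)
            for (int wo = 0; wo < Wo; ++wo) {
                double acc = 0.0;
                for (int c = 0; c < p.C; ++c)
                    for (int x = 0; x < p.X; ++x) {
                        const int wi = wo * p.S + x * p.D - p.LeftP;
                        if (wi >= 0 && wi < p.Wi) // skip taps that land on padding
                            acc += in[(n * p.C + c) * p.Wi + wi] * wei[(k * p.C + c) * p.X + x];
                    }
                out[(n * p.K + k) * Wo + wo] = acc;
            }
    return out;
}

// Backward data: each input element gathers every (k, x, wo) that read it in the
// forward pass. The sum runs over output channels K, and the weight flip is implicit
// in the index t = wi + LeftP - x*D.
std::vector<double> conv_bwd_data(const Conv1d& p, const std::vector<double>& dout,
                                  const std::vector<double>& wei) {
    const int Wo = p.Wo();
    std::vector<double> din(static_cast<std::size_t>(p.N) * p.C * p.Wi, 0.0);
    for (int n = 0; n < p.N; ++n)
        for (int c = 0; c < p.C; ++c)
            for (int wi = 0; wi < p.Wi; ++wi) {
                double acc = 0.0;
                for (int k = 0; k < p.K; ++k)
                    for (int x = 0; x < p.X; ++x) {
                        const int t = wi + p.LeftP - x * p.D; // must equal wo * S
                        if (t < 0 || t % p.S != 0) continue;  // no output pixel read wi via tap x
                        const int wo = t / p.S;
                        if (wo < Wo)
                            acc += dout[(n * p.K + k) * Wo + wo] * wei[(k * p.C + c) * p.X + x];
                    }
                din[(n * p.C + c) * p.Wi + wi] = acc;
            }
    return din;
}

double dot(const std::vector<double>& a, const std::vector<double>& b) {
    double s = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) s += a[i] * b[i];
    return s;
}

// <fwd(in, wei), dout> - <in, bwd_data(dout, wei)>; exactly 0 for the small
// integer-valued data used here, since every partial sum is exact in double.
double adjoint_gap(const Conv1d& p) {
    std::vector<double> in(static_cast<std::size_t>(p.N) * p.C * p.Wi);
    std::vector<double> wei(static_cast<std::size_t>(p.K) * p.C * p.X);
    std::vector<double> dout(static_cast<std::size_t>(p.N) * p.K * p.Wo());
    for (std::size_t i = 0; i < in.size(); ++i) in[i] = double(i % 7) - 3.0;
    for (std::size_t i = 0; i < wei.size(); ++i) wei[i] = double(i % 5) - 2.0;
    for (std::size_t i = 0; i < dout.size(); ++i) dout[i] = double(i % 3) - 1.0;
    return dot(conv_fwd(p, in, wei), dout) - dot(in, conv_bwd_data(p, dout, wei));
}
```

For a single channel with filter {1, 2, 3}, stride 1 and padding 1, a unit gradient at out[0] comes back as {2, 3, 0}. out[0] read in[0] through tap 1 and in[1] through tap 2, which is the flipped-filter behaviour that rot180(W) describes.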

Algorithmic Strategy: Implicit GEMM

As with the forward pass, the standard way to implement the backward data pass efficiently on a GPU is to transform the convolution into a General Matrix-Matrix Multiplication (GEMM) problem.

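  1. Output Gradient Reshaping: The output gradient tensor dL/dOut is logically reshaped into a matrix dL/dOut' of shape [K, (N*Ho*Wo)]. Read in transposed order, as [(N*Ho*Wo), K], it becomes the "A" matrix in the GEMM.

  2. Weight Reshaping: The weight tensor W is logically reshaped into a matrix W' of shape [K, (C*Y*X)]. This becomes the "B" matrix in the GEMM.

  3. Implicit GEMM: The core computation is then formulated as a GEMM that contracts over the output-channel dimension K. However, the output of this GEMM is not a simple matrix; it is the dL/dIn tensor. (\text{dL/dIn})' = (W')^T \times (\text{dL/dOut})', or equivalently, with the operands in A-times-B order, ((\text{dL/dIn})')^T = ((\text{dL/dOut})')^T \times W', where (\text{dL/dIn})' is the [(C*Y*X), (N*Ho*Wo)] column-matrix form of dL/dIn.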

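    The key insight is that this operation can be performed without explicitly forming the matrices. The GEMM kernel is designed to read from dL/dOut and W and write its results directly to the appropriate locations in the dL/dIn tensor. This process is sometimes referred to as an "implicit col2im" (column-to-image), because it undoes the im2col transformation used in the forward pass. Strictly, col2im is the adjoint of im2col rather than its inverse: where the filter windows of neighboring output pixels overlap, their contributions to the same input element are summed rather than copied.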
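For comparison, the explicit form that implicit GEMM avoids can be sketched in 1D. A plain GEMM writes the full column matrix, and a separate col2im pass scatter-adds it into dL/dIn. The sketch computes the column matrix as [(N*Wo), K] × [K, (C*X)], the transpose of (W')^T × (dL/dOut)'. This is a hypothetical illustration (the struct and function names are not from CK), assuming packed N-K-W / K-C-X / N-C-W layouts.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical 1D problem description (not a CK type).
struct Conv1dShape {
    int N, K, C, X, Wi, S, D, LeftP, RightP;
    int Wo() const { return (Wi + LeftP + RightP - D * (X - 1) - 1) / S + 1; }
};

// Explicit form: (1) GEMM writes the full column matrix, (2) col2im scatter-adds it.
// dout is packed [N][K][Wo], wei is packed [K][C][X]; returns din packed [N][C][Wi].
std::vector<float> bwd_data_gemm_col2im(const Conv1dShape& p, const std::vector<float>& dout,
                                        const std::vector<float>& wei) {
    const int Wo = p.Wo();
    const int rows = p.N * Wo;    // GEMM M: one row per (n, wo)
    const int cols_n = p.C * p.X; // GEMM N: one column per (c, x)
    // (1) cols[(n,wo)][(c,x)] = sum_k A[(n,wo)][k] * B[k][(c,x)], with
    //     A[(n,wo)][k] = dout[n][k][wo] and B[k][(c,x)] = wei[k][c][x].
    std::vector<float> cols(static_cast<std::size_t>(rows) * cols_n, 0.0f);
    for (int n = 0; n < p.N; ++n)
        for (int wo = 0; wo < Wo; ++wo)
            for (int j = 0; j < cols_n; ++j) {
                float acc = 0.0f;
                for (int k = 0; k < p.K; ++k)
                    acc += dout[(n * p.K + k) * Wo + wo] * wei[k * cols_n + j];
                cols[(n * Wo + wo) * cols_n + j] = acc;
            }
    // (2) col2im: entry (n, wo, c, x) belongs to input position wi = wo*S + x*D - LeftP.
    //     Overlapping windows hit the same wi, so contributions are summed.
    std::vector<float> din(static_cast<std::size_t>(p.N) * p.C * p.Wi, 0.0f);
    for (int n = 0; n < p.N; ++n)
        for (int wo = 0; wo < Wo; ++wo)
            for (int c = 0; c < p.C; ++c)
                for (int x = 0; x < p.X; ++x) {
                    const int wi = wo * p.S + x * p.D - p.LeftP;
                    if (wi >= 0 && wi < p.Wi) // taps on padding have no input element
                        din[(n * p.C + c) * p.Wi + wi] += cols[(n * Wo + wo) * cols_n + c * p.X + x];
                }
    return din;
}
```

With a single channel, filter {1, 2, 3}, stride 2, padding 1, Wi = 5 and an all-ones output gradient, the windows overlap at in[1] and in[3]. The result is {2, 4, 2, 4, 2}, which shows the summing behaviour of col2im.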

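This "implicit GEMM" approach is efficient because it never materializes the intermediate column matrix. That matrix has (N*Ho*Wo) × (C*Y*X) elements and can be several times larger than dL/dIn itself, so skipping it saves both the memory and the bandwidth needed to write it out and read it back.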
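As a rough sense of scale, the sketch below counts the elements an explicit GEMM + col2im approach would need for the problem reported in the Result section below. The sizes are read from the tensor lengths in that log: N=128, C=128, a 3x3 filter, Hi=Wi=71, Ho=Wo=36. The 2-byte element size is an assumption (FP16/BF16), and the helper names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Element counts relevant to an explicit GEMM + col2im implementation.
struct ExplicitFootprint {
    std::int64_t col_elems; // column matrix: (N*Ho*Wo) x (C*Y*X)
    std::int64_t din_elems; // dL/dIn itself: N*C*Hi*Wi
};

ExplicitFootprint explicit_footprint(std::int64_t N, std::int64_t C, std::int64_t Y,
                                     std::int64_t X, std::int64_t Hi, std::int64_t Wi,
                                     std::int64_t Ho, std::int64_t Wo) {
    return {N * Ho * Wo * C * Y * X, N * C * Hi * Wi};
}

// Bytes for a given element size (e.g. 2 for FP16/BF16, 4 for FP32).
std::int64_t bytes(std::int64_t elems, std::int64_t bytes_per_elem) {
    return elems * bytes_per_elem;
}
```

With 2-byte elements the column matrix alone takes about 382 MB, roughly 2.3 times the size of dL/dIn. For a stride-1 3x3 layer with "same" padding, the ratio approaches Y*X = 9.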

Source Code Organization

Build and Run

Prerequisites

Please follow the instructions in the main Build Guide section as a prerequisite to building and running this example.

Build the Example

cd /path/to/composable_kernel/example/17_convnd_bwd_data
mkdir build && cd build

cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..

make -j

Run the Example

#arg1: verification (0=no, 1=CPU reference, 2=GPU reference)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: num_dim_spatial(1|2|3)
#arg5 to ...: N, K, C, [Z,] [Y,] X, [Di,] [Hi,] Wi, [Sz,] [Sy,] Sx, [Dz,] [Dy,] Dx, [LeftPz,] [LeftPy,] LeftPx, [RightPz,] [RightPy,] RightPx
./bin/example_convnd_bwd_data_xdl 0 1 5 
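The following sketch maps the 2D form of the #arg5 list onto the output lengths that the example prints. The argument order follows the comment above. Stride 2, dilation 1 and padding 1 are assumed as the defaults only because they reproduce the 71 → 36 spatial reduction shown in the Result section. The struct and function names are hypothetical.

```cpp
#include <array>
#include <cassert>

// Hypothetical container for the 2D form of "#arg5 to ...", in the documented order.
struct Conv2dArgs {
    int N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx;
};

// Output length of a padded, strided, dilated convolution along one spatial axis.
constexpr int out_len(int in, int filter, int stride, int dilation, int left_pad, int right_pad) {
    return (in + left_pad + right_pad - dilation * (filter - 1) - 1) / stride + 1;
}

// Logical {N, K, Ho, Wo} lengths, as in the "out_n_k_ho_wo: ... lengths {...}" log line.
std::array<int, 4> out_lengths(const Conv2dArgs& a) {
    return {a.N, a.K, out_len(a.Hi, a.Y, a.Sy, a.Dy, a.LeftPy, a.RightPy),
            out_len(a.Wi, a.X, a.Sx, a.Dx, a.LeftPx, a.RightPx)};
}
```

Passing stride 1 with the same padding keeps the spatial size at 71, which is a quick way to sanity-check a hand-written argument list before running the example.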

Result

in_n_c_hi_wi: dim 4, lengths {128, 128, 71, 71}, strides {645248, 1, 9088, 128}
wei_k_c_y_x: dim 4, lengths {256, 128, 3, 3}, strides {1152, 1, 384, 128}
out_n_k_ho_wo: dim 4, lengths {128, 256, 36, 36}, strides {331776, 1, 9216, 256}
arg.a_grid_desc_k0_m_k1_container_{128, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{128, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) 
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} 
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) 
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} 
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{64, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{64, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) 
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} 
Warm up
Start running 1 times...
arg.a_grid_desc_k0_m_k1_container_{32, 175232, 8}
arg.b_grid_desc_k0_n_k1_container_{32, 128, 8}
arg.c_grid_desc_m_n_container_{ 175232, 128}
arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( 2738, 2, 2, 2, 4, 2 ) 
launch_and_time_kernel: grid_dim {1369, 1, 1}, block_dim {256, 1, 1} 
Warm up
Start running 1 times...
Perf: 1.40031 ms, 69.8734 TFlops, 179.037 GB/s
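The figures in this log can be cross-checked by hand. The sketch below recomputes throughput from the reported 1.40031 ms using the usual 2·N·K·C·Ho·Wo·Y·X flop count and 2-byte elements. Both assumptions reproduce the printed 69.87 TFlops and 179.04 GB/s. It also models the four launches as one sub-GEMM per stride phase, each reducing over K times the filter taps in that phase. That stride-decomposition model is an assumption that happens to reproduce the printed K0 values (128, 64, 64, 32 with K1 = 8); it is not taken from CK's source, and all names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <vector>

// Logical problem sizes, as read from the tensor lengths in the log.
struct Problem { std::int64_t N, K, C, Y, X, Hi, Wi, Ho, Wo; };

// 2 flops (multiply + add) per (output pixel, k, c, y, x) term.
double tflops(const Problem& p, double ms) {
    const double flop = 2.0 * p.N * p.K * p.C * p.Ho * p.Wo * p.Y * p.X;
    return flop / (ms * 1e-3) / 1e12;
}

// Each tensor counted once at bytes_per_elem (2 assumed for FP16).
double gb_per_s(const Problem& p, double ms, int bytes_per_elem) {
    const double elems = double(p.N * p.C * p.Hi * p.Wi) + double(p.K * p.C * p.Y * p.X) +
                         double(p.N * p.K * p.Ho * p.Wo);
    return elems * bytes_per_elem / (ms * 1e-3) / 1e9;
}

// Hypothetical model: one sub-GEMM per (iy, ix) phase; returns reduction length / K1.
std::vector<std::int64_t> sub_gemm_k0(const Problem& p, int Sy, int Sx, int Dy, int Dx, int K1) {
    const int YTilde = Sy / std::gcd(Sy, Dy);
    const int XTilde = Sx / std::gcd(Sx, Dx);
    std::vector<std::int64_t> k0;
    for (int iy = 0; iy < YTilde; ++iy)
        for (int ix = 0; ix < XTilde; ++ix) {
            const std::int64_t ydot = (p.Y - iy + YTilde - 1) / YTilde; // taps y = iy, iy+YTilde, ...
            const std::int64_t xdot = (p.X - ix + XTilde - 1) / XTilde;
            if (ydot > 0 && xdot > 0) k0.push_back(p.K * ydot * xdot / K1);
        }
    return k0;
}
```

Under this model, a stride-2 3x3 problem splits into 2x2 = 4 phases with 4, 2, 2 and 1 filter taps. This matches the four "arg.a_grid_desc_k0_m_k1_container_" blocks in the log.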

Relationship to Other Passes

The training of a single convolutional layer requires three distinct steps:

  1. Forward Pass (conv_fwd): Computes the output feature maps.
    • Out = In * W
  2. Backward Data Pass (conv_bwd_data): Computes the gradient with respect to the input, propagating the error to the previous layer. This is the focus of the current example.
    • dL/dIn = dL/dOut * rot180(W)
  3. Backward Weight Pass (conv_bwd_weight): Computes the gradient with respect to the weights, which is needed for the weight update.
    • dL/dW = In * dL/dOut

All three passes are critical for training a CNN, and all are typically implemented as high-performance implicit GEMM operations.
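In the explicit im2col/col2im view, all three passes reduce to GEMMs over the same operands, which shows why they cost the same number of flops. The sketch below records the (M, N, reduction) shape of each pass. It uses the matrix conventions of the Implicit GEMM section (W' is [K, (C*Y*X)], dL/dOut' is [K, (N*Ho*Wo)]), plus an assumed im2col matrix cols(In) of shape [(C*Y*X), (N*Ho*Wo)]. The names are hypothetical, and optimized kernels may arrange these problems differently.

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// One GEMM problem C[M][N] = A[M][Kr] * B[Kr][N]; Kr is the reduction length.
struct GemmShape {
    std::int64_t M, N, Kr;
    std::int64_t flops() const { return 2 * M * N * Kr; } // one multiply + one add per term
};

struct ConvDims2d { std::int64_t N, K, C, Y, X, Ho, Wo; };

// Shapes of the three training passes, using
//   W'       : [K] x [C*Y*X]
//   dOut'    : [K] x [N*Ho*Wo]
//   cols(In) : [C*Y*X] x [N*Ho*Wo]
std::array<GemmShape, 3> training_gemms(const ConvDims2d& d) {
    const std::int64_t pix = d.N * d.Ho * d.Wo; // output pixels
    const std::int64_t taps = d.C * d.Y * d.X;  // filter taps per output channel
    return {{
        {d.K, pix, taps}, // fwd:        Out'        = W'      x cols(In)
        {pix, taps, d.K}, // bwd_data:   cols(dIn)^T = dOut'^T x W'  (then col2im)
        {d.K, taps, pix}, // bwd_weight: dW'         = dOut'   x cols(In)^T
    }};
}
```

For the example problem, each pass comes to about 97.8 GFLOP, which matches the flop count implied by the Perf line.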