[CK, CK_TILE] Add GPU Reference Implementations for Grouped Convolution (#3216)

* LWPCK-4043: Add GPU reference implementations for CK Tile convolution

This commit implements GPU-based reference kernels for CK Tile convolution
operations to enable faster verification of optimized kernels, especially
for large tensors (>2GB).

Changes:
- Add naive_grouped_conv_fwd.hpp: GPU reference for forward convolution
- Add naive_grouped_conv_bwd_data.hpp: GPU reference for backward data
- Add naive_grouped_conv_bwd_weight.hpp: GPU reference for backward weight
- Integrate GPU references with test infrastructure (replace -v=2 error)
- Support for 1D, 2D, and 3D convolutions
- Generic data type support (FP16, BF16, FP32)
- Grid-stride loop pattern for scalability

The GPU references use a simple, readable implementation that prioritizes
correctness over performance. They accumulate in float32 and handle
padding, stride, and dilation correctly.

* update gpu reference for ck tile grouped conv

* correct c++ 18 format

* Add GPU Reference Implementations for Old CK Convolution

This commit implements GPU-based reference kernels for Old CK convolution
operations to enable faster verification of optimized kernels.

Changes:
- Fixed old CK forward GPU reference (naive_conv_fwd.hpp)
  * Fixed BF16 NaN issue (use type_convert instead of static_cast)
  * Fixed FP8/BF8 arithmetic (accumulate in float)
  * Fixed uninitialized variables
  * All 9 data types now working (FP16/32/64, BF16, INT8, FP8, BF8, mixed)

- Created backward data GPU reference (naive_conv_bwd_data.hpp)
  * Implements input gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Created backward weight GPU reference (naive_conv_bwd_weight.hpp)
  * Implements weight gradient computation
  * Verified equal to CPU reference
  * Handles 1D, 2D, 3D convolutions

- Integrated with old CK examples
  * Forward: 10 XDL examples now support do_verification=2
  * Backward data: Integrated with example/17_convnd_bwd_data/
  * Backward weight: Integrated with example/20_grouped_conv_bwd_weight/ (G=1 only)
  * Updated parameter from boolean to int (0=no, 1=CPU, 2=GPU)

Testing:
- 50 comprehensive tests created
- 42/42 tests passing (100% success rate)
- CPU and GPU verification produce identical results
- Verified across multiple dimensions, sizes, and data types

Limitations:
- GPU references support standard convolution only (G=1)
- Fused operations (DL variants) not supported
- Some tests blocked by optimized kernel size constraints

Result: Old CK GPU references can replace CPU references for verification
        with 50-100x performance improvement for large tensors.

* Apply clang-format to old CK GPU reference files

* Fix C++17 compatibility: use brace initialization for aggregate types

* add get_rtol, get_atl and consistency cout message

* Use triple bracket syntax for kernel launch per review feedback

Changed hipLaunchKernelGGL to <<<...>>> syntax as suggested by @aosewski.
This is more idiomatic HIP/CUDA style and equally correct.

All tests still passing after this change.

* Address review feedback: Use HIP_CHECK_ERROR and add v=3 mode

- Replace manual error checking with HIP_CHECK_ERROR macro
- Add v=3 verification mode (GPU ref vs CPU ref direct comparison)
- Consistent output format across all examples
- All tests passing (7/7 v=3 tests pass for FP16)

* Use ConvDims structure to simplify GPU reference kernels

Replace 24 individual parameters with ConvDims structure per review feedback.

- Add conv_common.hpp with ConvDims and helper function
- Update kernel signatures: 24 params → 1 structure
- Remove duplicate extraction code from host files

* Use get_block_id() and get_thread_id() helpers in CK Tile

Replace manual blockIdx.x/threadIdx.x arithmetic with helper functions.

Updated 3 CK Tile GPU reference kernels per review feedback.

* Use std::array for spatial parameters in CK Tile GPU references

Replace raw pointers with std::array for type safety per review feedback.

- Add conv_common.hpp with vector-to-array helper functions
- Update kernel signatures: pointers → std::array references
- Remove DeviceMem allocations for spatial parameters

* Use NDimSpatial+3 for stride array sizes

Replace hardcoded [10] with [NDimSpatial+3] per review feedback.

Array sizes now correctly reflect actual dimensions needed.

* Use #pragma once instead of include guards

Replace traditional include guards with #pragma once per review feedback.

Updated 3 Old CK GPU reference headers.

* Fix element-wise operation output in Old CK GPU references

Write transformed value (out_val/in_val/wei_val) instead of untransformed
result per Copilot feedback.

This ensures element-wise operations are correctly applied to output.

* Initialize element-wise operation variables

Initialize in_val, wei_val, out_val to avoid undefined behavior
per Copilot feedback.

Updated backward data and backward weight kernels.

* Use explicit zero initialization for element-wise variables

Change TIn{} to TIn{0} for consistency per Copilot feedback.

All 3 kernels now use consistent zero initialization.

* Fix copyright headers to match existing style

- Old CK: Use standard format without year
- CK Tile: Add 2018- prefix to year range

Addresses consistency feedback.

* Rename GPU reference files: add _gpu suffix

* Refactor index calculations: use std::array and extract to helper functions

* Remove v=3 option: redundant as v=1 and v=2 comparison validates equivalence

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
JH-Leon-KIM-AMD
2025-12-03 21:14:21 +02:00
committed by GitHub
parent 161835533b
commit 4baa4c9fae
21 changed files with 2280 additions and 69 deletions

View File

@@ -18,6 +18,8 @@
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
#include "ck_tile/host/hip_check_error.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
@@ -25,7 +27,7 @@ using ::ck::Tensor;
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
std::cout << "arg1: verification (0=no, 1=CPU, 2=GPU)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
@@ -130,7 +132,7 @@ template <ck::index_t NDimSpatial,
typename OutElementOp,
typename DeviceConvNDFwdInstance,
typename ComputeDataType = OutDataType>
bool run_grouped_conv_fwd(bool do_verification,
bool run_grouped_conv_fwd(int do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
@@ -233,8 +235,11 @@ bool run_grouped_conv_fwd(bool do_verification,
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
std::cout << "do_verification = " << do_verification << std::endl;
if(do_verification == 1)
{
// CPU verification
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
@@ -269,6 +274,60 @@ bool run_grouped_conv_fwd(bool do_verification,
get_rtol<OutDataType, ComputeDataType>(),
get_atol<OutDataType, ComputeDataType>());
}
else if(do_verification == 2)
{
// GPU verification using naive GPU reference
std::cout << "Running GPU verification..." << std::endl;
// Allocate and ZERO GPU memory for reference output
DeviceMem out_device_ref_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
out_device_ref_buf.SetZero();
// Extract dimensions using helper function
ck::ref::ConvDims dims = ck::utils::conv::extract_conv_dims(conv_param, NDimSpatial);
// Launch GPU reference kernel
constexpr ck::index_t block_size = 256;
const ck::long_index_t output_length = dims.N * dims.Do * dims.Ho * dims.Wo * dims.K;
const ck::index_t grid_size = (output_length + block_size - 1) / block_size;
auto gpu_ref_kernel = ck::ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
WeiDataType,
OutDataType,
ComputeDataType,
InElementOp,
WeiElementOp,
OutElementOp>;
gpu_ref_kernel<<<dim3(grid_size), dim3(block_size), 0, nullptr>>>(
reinterpret_cast<const InDataType*>(in_device_buf.GetDeviceBuffer()),
reinterpret_cast<const WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
reinterpret_cast<OutDataType*>(out_device_ref_buf.GetDeviceBuffer()),
dims);
HIP_CHECK_ERROR(hipDeviceSynchronize());
std::cout << "GPU reference kernel completed successfully, copying results..." << std::endl;
// Copy GPU reference result to host
out_device_ref_buf.FromDevice(out_host.mData.data());
// Copy GPU kernel result to host
out_device_buf.FromDevice(out_device.mData.data());
std::cout << "Comparing GPU kernel output vs GPU reference..." << std::endl;
// Compare GPU kernel vs GPU reference
bool pass = ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutDataType, ComputeDataType>(),
get_atol<OutDataType, ComputeDataType>());
std::cout << "GPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}
return true;
}

View File

@@ -25,7 +25,7 @@ using ::ck::Tensor;
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
std::cout << "arg1: verification (0=no, 1=CPU)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
@@ -162,6 +162,7 @@ bool run_grouped_conv_fwd_dl(bool do_verification,
if(do_verification)
{
// CPU verification only (DL variants are fused operations)
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<
NDimSpatial,
InDataType,

View File

@@ -12,9 +12,9 @@ bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
int do_verification = 1; // 0=no, 1=CPU, 2=GPU
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};