mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
* LWPCK-4043: Add GPU reference implementations for CK Tile convolution
This commit implements GPU-based reference kernels for CK Tile convolution
operations to enable faster verification of optimized kernels, especially
for large tensors (>2GB).
Changes:
- Add naive_grouped_conv_fwd.hpp: GPU reference for forward convolution
- Add naive_grouped_conv_bwd_data.hpp: GPU reference for backward data
- Add naive_grouped_conv_bwd_weight.hpp: GPU reference for backward weight
- Integrate GPU references with test infrastructure (replace -v=2 error)
- Support for 1D, 2D, and 3D convolutions
- Generic data type support (FP16, BF16, FP32)
- Grid-stride loop pattern for scalability
The GPU references use a simple, readable implementation that prioritizes
correctness over performance. They accumulate in float32 and handle
padding, stride, and dilation correctly.
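As a rough illustration of the approach described above (one output element per loop iteration, float accumulation, explicit handling of padding, stride, and dilation), here is a minimal host-side C++ sketch for the 1D, single-group case. It is not the code added by this commit: `Conv1dParams`, `naive_conv1d_fwd`, and the `first_thread` / `num_threads` arguments that emulate a grid-stride loop are hypothetical names chosen for this example, and the packed `[N][W][C]` / `[K][X][C]` / `[N][W][K]` layouts are an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical parameter bundle for a 1D convolution with a single group.
struct Conv1dParams
{
    int N, K, C; // batch, output channels, input channels
    int Wi, X;   // input width, filter width
    int stride, dilation, pad_left, pad_right;

    // Standard output-width formula for a strided, dilated, padded convolution.
    int Wo() const { return (Wi + pad_left + pad_right - dilation * (X - 1) - 1) / stride + 1; }
};

// in: [N][Wi][C], wei: [K][X][C], out: [N][Wo][K], all densely packed.
// first_thread / num_threads emulate one GPU thread walking a grid-stride loop:
// the thread starts at its own flat index and advances by the total thread count.
inline void naive_conv1d_fwd(const std::vector<float>& in,
                             const std::vector<float>& wei,
                             std::vector<float>& out,
                             const Conv1dParams& p,
                             long first_thread = 0,
                             long num_threads  = 1)
{
    const int Wo     = p.Wo();
    const long total = static_cast<long>(p.N) * Wo * p.K;
    assert(static_cast<long>(out.size()) == total);

    for(long idx = first_thread; idx < total; idx += num_threads)
    {
        // Decompose the flat output index into (n, wo, k).
        const int k  = static_cast<int>(idx % p.K);
        const int wo = static_cast<int>((idx / p.K) % Wo);
        const int n  = static_cast<int>(idx / (static_cast<long>(p.K) * Wo));

        float acc = 0.0f; // accumulate in float regardless of storage type
        for(int x = 0; x < p.X; ++x)
        {
            const int wi = wo * p.stride + x * p.dilation - p.pad_left;
            if(wi < 0 || wi >= p.Wi)
                continue; // positions in the padding region contribute zero
            for(int c = 0; c < p.C; ++c)
            {
                acc += in[(static_cast<long>(n) * p.Wi + wi) * p.C + c] *
                       wei[(static_cast<long>(k) * p.X + x) * p.C + c];
            }
        }
        out[idx] = acc;
    }
}
```

Calling the function once with the defaults computes every output element; calling it once per emulated thread with the same `num_threads` covers the same index space in disjoint slices, which is the property the grid-stride pattern relies on.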
* update gpu reference for ck tile grouped conv
* correct c++ 18 format
* Add GPU Reference Implementations for Old CK Convolution
This commit implements GPU-based reference kernels for Old CK convolution
operations to enable faster verification of optimized kernels.
Changes:
- Fixed old CK forward GPU reference (naive_conv_fwd.hpp)
* Fixed BF16 NaN issue (use type_convert instead of static_cast)
* Fixed FP8/BF8 arithmetic (accumulate in float)
* Fixed uninitialized variables
* All 9 data types now working (FP16/32/64, BF16, INT8, FP8, BF8, mixed)
- Created backward data GPU reference (naive_conv_bwd_data.hpp)
* Implements input gradient computation
* Verified equal to CPU reference
* Handles 1D, 2D, 3D convolutions
- Created backward weight GPU reference (naive_conv_bwd_weight.hpp)
* Implements weight gradient computation
* Verified equal to CPU reference
* Handles 1D, 2D, 3D convolutions
- Integrated with old CK examples
* Forward: 10 XDL examples now support do_verification=2
* Backward data: Integrated with example/17_convnd_bwd_data/
* Backward weight: Integrated with example/20_grouped_conv_bwd_weight/ (G=1 only)
* Updated parameter from boolean to int (0=no, 1=CPU, 2=GPU)
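The BF16 fix listed above (use `type_convert` instead of `static_cast`) can be motivated with a small standalone sketch. It assumes the BF16 type is an alias for a raw 16-bit integer holding the upper half of an IEEE-754 binary32 value; under that assumption a plain `static_cast` converts the integer value rather than reinterpreting the bits, which is one plausible source of wrong or NaN results. `bf16_bits`, `bf16_to_float`, and `float_to_bf16_truncate` are hypothetical names for this example and are not CK APIs; the truncating conversion is a simplification, and a real conversion may round differently.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Stand-in for a BF16 type stored as raw 16-bit storage (assumption).
using bf16_bits = std::uint16_t;

// BF16 -> float: place the 16 stored bits in the upper half of a binary32.
inline float bf16_to_float(bf16_bits x)
{
    const std::uint32_t bits = static_cast<std::uint32_t>(x) << 16;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

// float -> BF16 by dropping the low 16 mantissa bits (round toward zero).
inline bf16_bits float_to_bf16_truncate(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<bf16_bits>(bits >> 16);
}
```

With this representation, the bit pattern `0x3F80` means 1.0, while `static_cast<float>` on the same storage yields the integer value 16256.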
Testing:
- 50 comprehensive tests created
- 42/42 tests passing (100% success rate)
- CPU and GPU verification produce identical results
- Verified across multiple dimensions, sizes, and data types
Limitations:
- GPU references support standard convolution only (G=1)
- Fused operations (DL variants) not supported
- Some tests blocked by optimized kernel size constraints
Result: Old CK GPU references can replace CPU references for verification
with 50-100x performance improvement for large tensors.
* Apply clang-format to old CK GPU reference files
* Fix C++17 compatibility: use brace initialization for aggregate types
* add get_rtol, get_atol and consistent cout message
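For context, relative and absolute tolerances of this kind are usually combined into a single per-element bound. The sketch below assumes the common form |out - ref| <= atol + rtol * |ref| and treats any non-finite value as a mismatch; `all_close` is a hypothetical helper written for this example, not the comparison routine the examples actually call.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Returns true when every element of `out` matches `ref` within the combined
// tolerance atol + rtol * |ref|. Non-finite values never match.
inline bool all_close(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      double rtol,
                      double atol)
{
    assert(rtol >= 0.0 && atol >= 0.0);
    if(out.size() != ref.size())
        return false;

    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double o = out[i];
        const double r = ref[i];
        if(!std::isfinite(o) || !std::isfinite(r))
            return false;
        if(std::abs(o - r) > atol + rtol * std::abs(r))
            return false;
    }
    return true;
}
```

Looser tolerances for low-precision types (BF16, FP8, BF8) follow naturally from this form: fewer mantissa bits mean a larger expected rounding error per element.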
* Use triple bracket syntax for kernel launch per review feedback
Changed hipLaunchKernelGGL to <<<...>>> syntax as suggested by @aosewski.
This is more idiomatic HIP/CUDA style and equally correct.
All tests still passing after this change.
* Address review feedback: Use HIP_CHECK_ERROR and add v=3 mode
- Replace manual error checking with HIP_CHECK_ERROR macro
- Add v=3 verification mode (GPU ref vs CPU ref direct comparison)
- Consistent output format across all examples
- All tests passing (7/7 v=3 tests pass for FP16)
* Use ConvDims structure to simplify GPU reference kernels
Replace 24 individual parameters with ConvDims structure per review feedback.
- Add conv_common.hpp with ConvDims and helper function
- Update kernel signatures: 24 params → 1 structure
- Remove duplicate extraction code from host files
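The idea of collapsing a long parameter list into one structure can be sketched as follows. The field names and the helper below are hypothetical and do not reproduce the actual `ConvDims` definition in `conv_common.hpp`; the sketch assumes lower-dimensional problems are padded with size 1, stride 1, and padding 0 in the unused leading slots.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical plain-data bundle; passed by value it can replace a long
// list of scalar kernel arguments.
struct ConvDimsSketch
{
    int N, K, C;
    int Di, Hi, Wi; // input spatial sizes
    int Z, Y, X;    // filter spatial sizes
    int Do, Ho, Wo; // output spatial sizes
    int stride_z, stride_y, stride_x;
    int dilation_z, dilation_y, dilation_x;
    int pad_z, pad_y, pad_x;
};

// Builds the bundle from per-dimension vectors of length 1, 2, or 3.
inline ConvDimsSketch make_conv_dims_sketch(int N,
                                            int K,
                                            int C,
                                            const std::vector<int>& in_spatial,
                                            const std::vector<int>& filter,
                                            const std::vector<int>& out_spatial,
                                            const std::vector<int>& strides,
                                            const std::vector<int>& dilations,
                                            const std::vector<int>& pads)
{
    const int nd = static_cast<int>(in_spatial.size());
    assert(nd >= 1 && nd <= 3);
    assert(filter.size() == in_spatial.size() && out_spatial.size() == in_spatial.size());
    assert(strides.size() == in_spatial.size() && dilations.size() == in_spatial.size());
    assert(pads.size() == in_spatial.size());

    // Slot 0 is D/Z, slot 1 is H/Y, slot 2 is W/X; inputs are right-aligned.
    auto pick = [nd](const std::vector<int>& v, int slot, int fallback) {
        const int i = slot - (3 - nd);
        return i >= 0 ? v[static_cast<std::size_t>(i)] : fallback;
    };

    ConvDimsSketch d{};
    d.N = N; d.K = K; d.C = C;
    d.Di = pick(in_spatial, 0, 1);  d.Hi = pick(in_spatial, 1, 1);  d.Wi = pick(in_spatial, 2, 1);
    d.Z  = pick(filter, 0, 1);      d.Y  = pick(filter, 1, 1);      d.X  = pick(filter, 2, 1);
    d.Do = pick(out_spatial, 0, 1); d.Ho = pick(out_spatial, 1, 1); d.Wo = pick(out_spatial, 2, 1);
    d.stride_z   = pick(strides, 0, 1);   d.stride_y   = pick(strides, 1, 1);   d.stride_x   = pick(strides, 2, 1);
    d.dilation_z = pick(dilations, 0, 1); d.dilation_y = pick(dilations, 1, 1); d.dilation_x = pick(dilations, 2, 1);
    d.pad_z = pick(pads, 0, 0); d.pad_y = pick(pads, 1, 0); d.pad_x = pick(pads, 2, 0);
    return d;
}

// A consumer now takes one argument instead of many scalars.
inline long output_elements(const ConvDimsSketch& d)
{
    return static_cast<long>(d.N) * d.K * d.Do * d.Ho * d.Wo;
}
```

Doing the extraction once in a helper is also what removes the duplicated host-side unpacking code mentioned in the last bullet.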
* Use get_block_id() and get_thread_id() helpers in CK Tile
Replace manual blockIdx.x/threadIdx.x arithmetic with helper functions.
Updated 3 CK Tile GPU reference kernels per review feedback.
* Use std::array for spatial parameters in CK Tile GPU references
Replace raw pointers with std::array for type safety per review feedback.
- Add conv_common.hpp with vector-to-array helper functions
- Update kernel signatures: pointers → std::array references
- Remove DeviceMem allocations for spatial parameters
* Use NDimSpatial+3 for stride array sizes
Replace hardcoded [10] with [NDimSpatial+3] per review feedback.
Array sizes now correctly reflect actual dimensions needed.
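To illustrate why `NDimSpatial + 3` is a natural size, the sketch below computes packed strides for a tensor whose dimensions are, presumably, group, batch, the spatial dimensions, and channel (three non-spatial dimensions plus `NDimSpatial` spatial ones). `packed_strides` is a hypothetical helper for this example, and the dimension order is an assumption.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Strides of a densely packed tensor with NDimSpatial + 3 dimensions,
// innermost dimension last. The array size is tied to the template
// parameter instead of a hardcoded upper bound such as 10.
template <std::size_t NDimSpatial>
constexpr std::array<long, NDimSpatial + 3>
packed_strides(const std::array<long, NDimSpatial + 3>& lengths)
{
    std::array<long, NDimSpatial + 3> strides{};
    long running = 1;
    for(std::size_t i = NDimSpatial + 3; i-- > 0;)
    {
        strides[i] = running; // stride of dimension i
        running *= lengths[i];
    }
    return strides;
}
```

Because `NDimSpatial + 3` appears only inside the array type, the template argument cannot be deduced and has to be given explicitly, for example `packed_strides<2>(lengths)` for a 2D convolution with five dimensions.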
* Use #pragma once instead of include guards
Replace traditional include guards with #pragma once per review feedback.
Updated 3 Old CK GPU reference headers.
* Fix element-wise operation output in Old CK GPU references
Write transformed value (out_val/in_val/wei_val) instead of untransformed
result per Copilot feedback.
This ensures element-wise operations are correctly applied to output.
* Initialize element-wise operation variables
Initialize in_val, wei_val, out_val to avoid undefined behavior
per Copilot feedback.
Updated backward data and backward weight kernels.
* Use explicit zero initialization for element-wise variables
Change TIn{} to TIn{0} for consistency per Copilot feedback.
All 3 kernels now use consistent zero initialization.
* Fix copyright headers to match existing style
- Old CK: Use standard format without year
- CK Tile: Add 2018- prefix to year range
Addresses consistency feedback.
* Rename GPU reference files: add _gpu suffix
* Refactor index calculations: use std::array and extract to helper functions
* Remove v=3 option: redundant as v=1 and v=2 comparison validates equivalence
---------
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
185 lines
6.9 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <algorithm>
#include <iostream>
#include <iterator>
#include <string>      // std::stoi
#include <type_traits> // std::is_same_v

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_bwd_weight.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp"

using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;
using F8   = ck::f8_t;
using BF8  = ck::bf8_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdWeightDefault =
    ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;

template <typename DataType, typename GemmType = DataType>
inline __host__ __device__ constexpr double get_rtol()
{
    if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
        return 5e-3;
    else if constexpr(std::is_same_v<DataType, float>)
        return 1e-3;
    else if constexpr(std::is_same_v<DataType, double>)
        return 1e-6;
    else if constexpr(std::is_same_v<DataType, ck::half_t>)
        return 1e-3;
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
        return 5e-2;
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
        return 1e-1;
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
        return 1.5e-1;
    else
        return 1e-3;
}

template <typename DataType, typename GemmType = DataType>
inline __host__ __device__ constexpr double get_atol()
{
    if constexpr(std::is_same_v<DataType, float> && std::is_same_v<GemmType, ck::tf32_t>)
        return 1e-3;
    else if constexpr(std::is_same_v<DataType, float>)
        return 1e-3;
    else if constexpr(std::is_same_v<DataType, double>)
        return 1e-6;
    else if constexpr(std::is_same_v<DataType, ck::half_t>)
        return 1e-3;
    else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
        return 5e-2;
    else if constexpr(std::is_same_v<DataType, ck::f8_t>)
        return 16.1;
    else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
        return 16.1;
    else
        return 1e-3;
}

template <typename InputLay, typename WeightLay, typename OutputLay>
struct CommonLayoutSetting
{
    using InputLayout  = InputLay;
    using WeightLayout = WeightLay;
    using OutputLayout = OutputLay;
};

namespace ctl = ck::tensor_layout::convolution;
template <ck::index_t NDimSpatial>
struct CommonLayoutSettingSelector
    : CommonLayoutSetting<ck::tuple_element_t<NDimSpatial - 1,
                                              ck::Tuple<ck::tensor_layout::convolution::GNWC,
                                                        ck::tensor_layout::convolution::GNHWC,
                                                        ck::tensor_layout::convolution::GNDHWC>>,
                          ck::tuple_element_t<NDimSpatial - 1,
                                              ck::Tuple<ck::tensor_layout::convolution::GKXC,
                                                        ck::tensor_layout::convolution::GKYXC,
                                                        ck::tensor_layout::convolution::GKZYXC>>,
                          ck::tuple_element_t<NDimSpatial - 1,
                                              ck::Tuple<ck::tensor_layout::convolution::GNWK,
                                                        ck::tensor_layout::convolution::GNHWK,
                                                        ck::tensor_layout::convolution::GNDHWK>>>
{
};

template <ck::index_t NDimSpatial>
using InputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::InputLayout;

template <ck::index_t NDimSpatial>
using WeightLayout = typename CommonLayoutSettingSelector<NDimSpatial>::WeightLayout;

template <ck::index_t NDimSpatial>
using OutputLayout = typename CommonLayoutSettingSelector<NDimSpatial>::OutputLayout;

struct ExecutionConfig final
{
    int do_verification = 1; // 0=no, 1=CPU, 2=GPU
    int init_method     = 1;
    bool time_kernel    = false;
};

#define DefaultConvParam                                                                         \
    ck::utils::conv::ConvParam                                                                   \
    {                                                                                            \
        3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \
    }

inline void print_help_msg()
{
    // arg1 matches ExecutionConfig::do_verification, which is an int mode selector.
    std::cerr << "arg1: verification (0=no, 1=CPU, 2=GPU)\n"
              << "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
              << "arg3: time kernel (0=no, 1=yes)\n"
              << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}

inline bool parse_cmd_args(int argc,
                           char* argv[],
                           ExecutionConfig& config,
                           ck::utils::conv::ConvParam& conv_param)
{
    constexpr int num_execution_config_args =
        3; // arguments for do_verification, init_method, time_kernel
    constexpr int num_conv_param_leading_args = 5; // arguments for num_dim_spatial_, G_, N_, K_, C_

    constexpr int threshold_to_catch_partial_args = 1 + num_execution_config_args;
    constexpr int threshold_to_catch_all_args =
        threshold_to_catch_partial_args + num_conv_param_leading_args;

    if(argc == 1)
    {
        // use default
    }
    // catch only ExecutionConfig arguments
    else if(argc == threshold_to_catch_partial_args)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);
    }
    // catch both ExecutionConfig & ConvParam arguments
    else if(threshold_to_catch_all_args < argc && ((argc - threshold_to_catch_all_args) % 3 == 0))
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        const ck::index_t num_dim_spatial = std::stoi(argv[4]);
        conv_param                        = ck::utils::conv::parse_conv_param(
            num_dim_spatial,
            threshold_to_catch_partial_args + 1, // +1 because we already parsed num_dim_spatial
            argv);
    }
    else
    {
        print_help_msg();
        return false;
    }

    return true;
}