[CK] Integrate GPU reference into ckProfiler for convolutions (#3379)

Refactor and integrate CK GPU references into ckProfiler.
- All convolution layouts and groupings supported for all three directions
- Unit tests verifying GPU and CPU reference is the same
- Support added to profiler (do_verification = 2 enables GPU reference)
- One profiler-based test per direction changed to GPU reference to demonstrate usag

Closes AICK-427
This commit is contained in:
Johannes Graner
2025-12-18 07:59:45 +01:00
committed by GitHub
parent 87dd073887
commit bb8445dca8
31 changed files with 3351 additions and 953 deletions

View File

@@ -7,11 +7,12 @@
#include <iostream>
#include <memory>
#include <sstream>
#include "conv_util.hpp"
#include "device.hpp"
#include "device_conv_fwd.hpp"
#include "common_header.hpp"
#include "naive_conv_fwd_gpu.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/stream_config.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
namespace ck {
namespace tensor_operation {
@@ -26,7 +27,16 @@ template <typename InDataType,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
: public DeviceConvFwd<3,
ck::tensor_layout::convolution::NDHWC,
ck::tensor_layout::convolution::KZYXC,
ck::tensor_layout::convolution::NDHWK,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K;
@@ -57,6 +67,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op)
: params_{3,
1, // G (group count, always 1 for non-grouped)
N,
K,
C,
@@ -78,7 +89,7 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
}
// private:
utils::conv::ConvParams params_;
utils::conv::ConvParam params_;
std::vector<index_t> out_spatial_lengths_;
const InDataType* p_in_;
@@ -97,46 +108,28 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto naive_conv3d_fwd =
ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk<InDataType,
WeiDataType,
OutDataType,
AccDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>;
using InLayout = ck::tensor_layout::convolution::GNCDHW;
using WeiLayout = ck::tensor_layout::convolution::GKCZYX;
using OutLayout = ck::tensor_layout::convolution::GNKDHW;
float ave_time = launch_and_time_kernel(stream_config,
naive_conv3d_fwd,
dim3(256),
dim3(256),
0,
arg.p_in_,
arg.p_wei_,
arg.p_out_,
arg.N_,
arg.K_,
arg.C_,
arg.in_spatial_lengths_[0],
arg.in_spatial_lengths_[1],
arg.in_spatial_lengths_[2],
arg.filter_spatial_lengths_[0],
arg.filter_spatial_lengths_[1],
arg.filter_spatial_lengths_[2],
arg.out_spatial_lengths_[0],
arg.out_spatial_lengths_[1],
arg.out_spatial_lengths_[2],
arg.conv_filter_strides_[0],
arg.conv_filter_strides_[1],
arg.conv_filter_strides_[2],
arg.conv_filter_dilations_[0],
arg.conv_filter_dilations_[1],
arg.conv_filter_dilations_[2],
arg.in_left_pads_[0],
arg.in_left_pads_[1],
arg.in_left_pads_[2]);
return ave_time;
// Use simplified ConvParam-based API
ref::naive_conv_fwd<InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>(arg.p_in_,
arg.p_wei_,
arg.p_out_,
arg.params_,
arg.in_element_op_,
arg.wei_element_op_,
arg.out_element_op_,
stream_config.stream_id_);
return 0; // No timing for naive implementation
}
// polymorphic
@@ -155,7 +148,9 @@ struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_W
static bool IsSupportedArgument(const Argument& arg)
{
std::vector<index_t> out_spatial_lengths = arg.params_.GetOutputSpatialLengths();
auto out_spatial_lengths_long = arg.params_.GetOutputSpatialLengths();
std::vector<index_t> out_spatial_lengths(out_spatial_lengths_long.begin(),
out_spatial_lengths_long.end());
bool out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] &&
out_spatial_lengths[1] == arg.out_spatial_lengths_[1] &&