[CK] Integrate GPU reference into ckProfiler for convolutions (#3379)

Refactor and integrate CK GPU references into ckProfiler.
- All convolution layouts and groupings supported for all three directions
- Unit tests verifying that the GPU and CPU references produce the same results
- Support added to profiler (do_verification = 2 enables GPU reference)
- One profiler-based test per direction switched to the GPU reference to demonstrate usage

Closes AICK-427
Johannes Graner
2025-12-18 07:59:45 +01:00
committed by GitHub
parent 87dd073887
commit bb8445dca8
31 changed files with 3351 additions and 953 deletions

View File

@@ -1,73 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#ifndef CONV_COMMON_HPP
#define CONV_COMMON_HPP
#include "ck/ck.hpp"
namespace ck {
namespace ref {
// Device-compatible dimension structure for GPU reference kernels
// Replaces passing 24 individual parameters
struct ConvDims
{
index_t N, K, C;
index_t Di, Hi, Wi;
index_t Z, Y, X;
index_t Do, Ho, Wo;
index_t stride_z, stride_y, stride_x;
index_t dilation_z, dilation_y, dilation_x;
index_t pad_z, pad_y, pad_x;
};
} // namespace ref
// Helper function to extract dimensions from ConvParam for GPU kernels
// Defined in ck::utils::conv namespace for convenience
namespace utils {
namespace conv {
inline ck::ref::ConvDims
extract_conv_dims(const ConvParam& conv_param, ck::index_t NDimSpatial, bool apply_group = true)
{
ck::ref::ConvDims dims;
dims.N = conv_param.N_;
dims.K = conv_param.K_;
dims.C = apply_group ? (conv_param.C_ * conv_param.G_) : conv_param.C_;
dims.Di = (NDimSpatial >= 3) ? conv_param.input_spatial_lengths_[0] : 1;
dims.Hi = (NDimSpatial >= 2) ? conv_param.input_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
dims.Wi = conv_param.input_spatial_lengths_[NDimSpatial - 1];
dims.Z = (NDimSpatial >= 3) ? conv_param.filter_spatial_lengths_[0] : 1;
dims.Y = (NDimSpatial >= 2) ? conv_param.filter_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
dims.X = conv_param.filter_spatial_lengths_[NDimSpatial - 1];
dims.Do = (NDimSpatial >= 3) ? conv_param.output_spatial_lengths_[0] : 1;
dims.Ho = (NDimSpatial >= 2) ? conv_param.output_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
dims.Wo = conv_param.output_spatial_lengths_[NDimSpatial - 1];
dims.stride_z = (NDimSpatial >= 3) ? conv_param.conv_filter_strides_[0] : 1;
dims.stride_y =
(NDimSpatial >= 2) ? conv_param.conv_filter_strides_[NDimSpatial >= 3 ? 1 : 0] : 1;
dims.stride_x = conv_param.conv_filter_strides_[NDimSpatial - 1];
dims.dilation_z = (NDimSpatial >= 3) ? conv_param.conv_filter_dilations_[0] : 1;
dims.dilation_y =
(NDimSpatial >= 2) ? conv_param.conv_filter_dilations_[NDimSpatial >= 3 ? 1 : 0] : 1;
dims.dilation_x = conv_param.conv_filter_dilations_[NDimSpatial - 1];
dims.pad_z = (NDimSpatial >= 3) ? conv_param.input_left_pads_[0] : 0;
dims.pad_y = (NDimSpatial >= 2) ? conv_param.input_left_pads_[NDimSpatial >= 3 ? 1 : 0] : 0;
dims.pad_x = conv_param.input_left_pads_[NDimSpatial - 1];
return dims;
}
} // namespace conv
} // namespace utils
} // namespace ck
#endif

View File

@@ -4,146 +4,515 @@
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace ref {
// Optimized backward data convolution kernel working with packed (contiguous) tensors
// Computes gradients w.r.t. input from output gradients and weights
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
const OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
InDataType in_val = InDataType{0};
WeiDataType wei_val = WeiDataType{0};
OutDataType out_val = OutDataType{0};
if constexpr(NDimSpatial == 1)
{
const long_index_t num_in = G * N * C * Wi;
const long_index_t out_stride_g = N * K * Wo;
const long_index_t out_stride_n = K * Wo;
const long_index_t out_stride_k = Wo;
const long_index_t wei_stride_g = K * C * X;
const long_index_t wei_stride_k = C * X;
const long_index_t wei_stride_c = X;
const long_index_t in_stride_g = N * C * Wi;
const long_index_t in_stride_n = C * Wi;
const long_index_t in_stride_c = Wi;
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
{
index_t remaining = idx;
const index_t wi = remaining % Wi;
remaining /= Wi;
const index_t c = remaining % C;
remaining /= C;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
for(index_t x = 0; x < X; ++x)
{
long_index_t w_tmp = wi + pad_x - x * dilation_x;
if(w_tmp % stride_x == 0)
{
long_index_t wo = w_tmp / stride_x;
if(wo >= 0 && wo < Wo)
{
const OutDataType* out_gnk = out_gn;
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnk[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
}
}
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
}
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t num_in = G * N * C * Hi * Wi;
const long_index_t out_stride_g = N * K * Ho * Wo;
const long_index_t out_stride_n = K * Ho * Wo;
const long_index_t out_stride_k = Ho * Wo;
const long_index_t out_stride_h = Wo;
const long_index_t wei_stride_g = K * C * Y * X;
const long_index_t wei_stride_k = C * Y * X;
const long_index_t wei_stride_c = Y * X;
const long_index_t wei_stride_y = X;
const long_index_t in_stride_g = N * C * Hi * Wi;
const long_index_t in_stride_n = C * Hi * Wi;
const long_index_t in_stride_c = Hi * Wi;
const long_index_t in_stride_h = Wi;
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
{
index_t remaining = idx;
const index_t wi = remaining % Wi;
remaining /= Wi;
const index_t hi = remaining % Hi;
remaining /= Hi;
const index_t c = remaining % C;
remaining /= C;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
for(index_t y = 0; y < Y; ++y)
{
long_index_t h_tmp = hi + pad_y - y * dilation_y;
if(h_tmp % stride_y == 0)
{
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t w_tmp = wi + pad_x - x * dilation_x;
if(w_tmp % stride_x == 0)
{
long_index_t wo = w_tmp / stride_x;
if(wo >= 0 && wo < Wo)
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
}
}
}
}
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
in_val;
}
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t num_in = G * N * C * Di * Hi * Wi;
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
const long_index_t out_stride_n = K * Do * Ho * Wo;
const long_index_t out_stride_k = Do * Ho * Wo;
const long_index_t out_stride_d = Ho * Wo;
const long_index_t out_stride_h = Wo;
const long_index_t wei_stride_g = K * C * Z * Y * X;
const long_index_t wei_stride_k = C * Z * Y * X;
const long_index_t wei_stride_c = Z * Y * X;
const long_index_t wei_stride_z = Y * X;
const long_index_t wei_stride_y = X;
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
const long_index_t in_stride_n = C * Di * Hi * Wi;
const long_index_t in_stride_c = Di * Hi * Wi;
const long_index_t in_stride_d = Hi * Wi;
const long_index_t in_stride_h = Wi;
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
{
index_t remaining = idx;
const index_t wi = remaining % Wi;
remaining /= Wi;
const index_t hi = remaining % Hi;
remaining /= Hi;
const index_t di = remaining % Di;
remaining /= Di;
const index_t c = remaining % C;
remaining /= C;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
for(index_t z = 0; z < Z; ++z)
{
long_index_t d_tmp = di + pad_z - z * dilation_z;
if(d_tmp % stride_z == 0)
{
long_index_t do_idx = d_tmp / stride_z;
if(do_idx >= 0 && do_idx < Do)
{
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
long_index_t h_tmp = hi + pad_y - y * dilation_y;
if(h_tmp % stride_y == 0)
{
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t w_tmp = wi + pad_x - x * dilation_x;
if(w_tmp % stride_x == 0)
{
long_index_t wo = w_tmp / stride_x;
if(wo >= 0 && wo < Wo)
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val,
out_gnkdh[k * out_stride_k + wo]);
wei_op(wei_val,
wei_gkczy[k * wei_stride_k + x]);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
}
}
}
}
}
}
}
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
hi * in_stride_h + wi] = in_val;
}
}
}
// GPU reference backward data convolution - takes ConvParam directly
template <typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
typename TWei,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_data(TIn* p_in,
const TWei* p_wei,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
const index_t G = conv_param.G_;
const index_t N = conv_param.N_;
const index_t C = conv_param.C_;
const index_t K = conv_param.K_;
std::vector<index_t> in_lengths = {G, N, C};
std::vector<index_t> wei_lengths = {G, K, C};
std::vector<index_t> out_lengths = {G, N, K};
for(index_t i = 0; i < ndim; ++i)
{
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
}
// Calculate total elements for buffer allocation
long_index_t in_total = 1, wei_total = 1, out_total = 1;
for(auto l : in_lengths)
in_total *= l;
for(auto l : wei_lengths)
wei_total *= l;
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
const size_t dim_count = in_lengths.size();
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
constexpr int block_size = 256;
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
std::vector<index_t> conv_dilations(ndim);
std::vector<index_t> input_pads(ndim);
for(index_t i = 0; i < ndim; ++i)
{
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
}
// Run backward data convolution kernel on packed data
const int in_grid = (in_total + block_size - 1) / block_size;
if(ndim == 1)
{
naive_conv_bwd_data_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
1,
1,
in_lengths[3],
1,
1,
wei_lengths[3],
1,
1,
out_lengths[3],
1,
1,
conv_strides[0],
1,
1,
conv_dilations[0],
0,
0,
input_pads[0],
in_element_op,
wei_element_op,
out_element_op);
}
else if(ndim == 2)
{
naive_conv_bwd_data_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
1,
in_lengths[3],
in_lengths[4],
1,
wei_lengths[3],
wei_lengths[4],
1,
out_lengths[3],
out_lengths[4],
1,
conv_strides[0],
conv_strides[1],
1,
conv_dilations[0],
conv_dilations[1],
0,
input_pads[0],
input_pads[1],
in_element_op,
wei_element_op,
out_element_op);
}
else // 3D
{
naive_conv_bwd_data_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
in_lengths[3],
in_lengths[4],
in_lengths[5],
wei_lengths[3],
wei_lengths[4],
wei_lengths[5],
out_lengths[3],
out_lengths[4],
out_lengths[5],
conv_strides[0],
conv_strides[1],
conv_strides[2],
conv_dilations[0],
conv_dilations[1],
conv_dilations[2],
input_pads[0],
input_pads[1],
input_pads[2],
in_element_op,
wei_element_op,
out_element_op);
}
// Unpack result back to strided layout
strided_copy_kernel<TIn, true><<<in_grid, block_size, 0, stream>>>(
p_in_packed, p_in, d_in_lengths, d_in_strides, dim_count, in_total);
HIP_CHECK_ERROR(hipGetLastError());
// Memory automatically freed by SimpleDeviceMem destructors
}
} // namespace ref
} // namespace ck
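A note on the index inversion used in the backward-data kernels above: the forward pass reads the input at \( w_i = w_o \cdot \text{stride}_x + x \cdot \text{dilation}_x - \text{pad}_x \), so inverting for a fixed input/filter position gives

\[
w_o = \frac{w_i + \text{pad}_x - x \cdot \text{dilation}_x}{\text{stride}_x},
\qquad
(w_i + \text{pad}_x - x \cdot \text{dilation}_x) \bmod \text{stride}_x = 0,
\quad 0 \le w_o < W_o ,
\]

which is exactly the `w_tmp % stride_x == 0` divisibility check and range test each branch performs before accumulating.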

View File

@@ -4,136 +4,497 @@
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace ref {
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
// weight_grad[G][K][C][filter]
// Computes gradient with respect to weights
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* __restrict__ p_out_grad,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
InDataType in_val = InDataType{0};
WeiDataType wei_val = WeiDataType{0};
OutDataType out_val = OutDataType{0};
if constexpr(NDimSpatial == 1)
{
const long_index_t num_wei = G * K * C * X;
const long_index_t in_stride_g = N * C * Wi;
const long_index_t in_stride_n = C * Wi;
const long_index_t in_stride_c = Wi;
const long_index_t out_stride_g = N * K * Wo;
const long_index_t out_stride_n = K * Wo;
const long_index_t out_stride_k = Wo;
const long_index_t wei_stride_g = K * C * X;
const long_index_t wei_stride_k = C * X;
const long_index_t wei_stride_c = X;
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
{
index_t remaining = idx;
const index_t x = remaining % X;
remaining /= X;
const index_t c = remaining % C;
remaining /= C;
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gn[wi]);
out_op(out_val, out_gn_k[wo]);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
}
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t num_wei = G * K * C * Y * X;
const long_index_t in_stride_g = N * C * Hi * Wi;
const long_index_t in_stride_n = C * Hi * Wi;
const long_index_t in_stride_c = Hi * Wi;
const long_index_t in_stride_h = Wi;
const long_index_t out_stride_g = N * K * Ho * Wo;
const long_index_t out_stride_n = K * Ho * Wo;
const long_index_t out_stride_k = Ho * Wo;
const long_index_t out_stride_h = Wo;
const long_index_t wei_stride_g = K * C * Y * X;
const long_index_t wei_stride_k = C * Y * X;
const long_index_t wei_stride_c = Y * X;
const long_index_t wei_stride_y = X;
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
{
index_t remaining = idx;
const index_t x = remaining % X;
remaining /= X;
const index_t y = remaining % Y;
remaining /= Y;
const index_t c = remaining % C;
remaining /= C;
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
out_op(out_val, out_gn_kh[wo]);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
}
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
x] = wei_val;
}
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t num_wei = G * K * C * Z * Y * X;
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
const long_index_t in_stride_n = C * Di * Hi * Wi;
const long_index_t in_stride_c = Di * Hi * Wi;
const long_index_t in_stride_d = Hi * Wi;
const long_index_t in_stride_h = Wi;
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
const long_index_t out_stride_n = K * Do * Ho * Wo;
const long_index_t out_stride_k = Do * Ho * Wo;
const long_index_t out_stride_d = Ho * Wo;
const long_index_t out_stride_h = Wo;
const long_index_t wei_stride_g = K * C * Z * Y * X;
const long_index_t wei_stride_k = C * Z * Y * X;
const long_index_t wei_stride_c = Z * Y * X;
const long_index_t wei_stride_z = Y * X;
const long_index_t wei_stride_y = X;
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
{
index_t remaining = idx;
const index_t x = remaining % X;
remaining /= X;
const index_t y = remaining % Y;
remaining /= Y;
const index_t z = remaining % Z;
remaining /= Z;
const index_t c = remaining % C;
remaining /= C;
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
out_op(out_val, out_gn_kdh[wo]);
acc += type_convert<float>(out_val) *
type_convert<float>(in_val);
}
}
}
}
}
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y + x] = wei_val;
}
}
}
// GPU reference backward weight convolution - takes ConvParam directly
template <typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
typename TWei,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_weight(const TIn* p_in,
TWei* p_wei_grad,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
const index_t G = conv_param.G_;
const index_t N = conv_param.N_;
const index_t C = conv_param.C_;
const index_t K = conv_param.K_;
std::vector<index_t> in_lengths = {G, N, C};
std::vector<index_t> wei_lengths = {G, K, C};
std::vector<index_t> out_lengths = {G, N, K};
for(index_t i = 0; i < ndim; ++i)
{
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
}
// Calculate total elements for buffer allocation
long_index_t in_total = 1, wei_total = 1, out_total = 1;
for(auto l : in_lengths)
in_total *= l;
for(auto l : wei_lengths)
wei_total *= l;
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
const size_t dim_count = in_lengths.size();
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
std::vector<index_t> conv_dilations(ndim);
std::vector<index_t> input_pads(ndim);
for(index_t i = 0; i < ndim; ++i)
{
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
}
// Run backward weight convolution kernel on packed data
const int wei_grid = (wei_total + block_size - 1) / block_size;
if(ndim == 1)
{
naive_conv_bwd_weight_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_grad_packed,
p_out_grad_packed,
G,
N,
K,
C,
1,
1,
in_lengths[3],
1,
1,
wei_lengths[3],
1,
1,
out_lengths[3],
1,
1,
conv_strides[0],
1,
1,
conv_dilations[0],
0,
0,
input_pads[0],
in_element_op,
wei_element_op,
out_element_op);
}
else if(ndim == 2)
{
naive_conv_bwd_weight_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_grad_packed,
p_out_grad_packed,
G,
N,
K,
C,
1,
in_lengths[3],
in_lengths[4],
1,
wei_lengths[3],
wei_lengths[4],
1,
out_lengths[3],
out_lengths[4],
1,
conv_strides[0],
conv_strides[1],
1,
conv_dilations[0],
conv_dilations[1],
0,
input_pads[0],
input_pads[1],
in_element_op,
wei_element_op,
out_element_op);
}
else // 3D
{
naive_conv_bwd_weight_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_grad_packed,
p_out_grad_packed,
G,
N,
K,
C,
in_lengths[3],
in_lengths[4],
in_lengths[5],
wei_lengths[3],
wei_lengths[4],
wei_lengths[5],
out_lengths[3],
out_lengths[4],
out_lengths[5],
conv_strides[0],
conv_strides[1],
conv_strides[2],
conv_dilations[0],
conv_dilations[1],
conv_dilations[2],
input_pads[0],
input_pads[1],
input_pads[2],
in_element_op,
wei_element_op,
out_element_op);
}
// Unpack weight gradient
strided_copy_kernel<TWei, true><<<wei_grid, block_size, 0, stream>>>(
p_wei_grad_packed, p_wei_grad, d_wei_lengths, d_wei_strides, dim_count, wei_total);
HIP_CHECK_ERROR(hipGetLastError());
// Memory automatically freed by SimpleDeviceMem destructors
}
} // namespace ref
} // namespace ck
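For reference, the quantity accumulated by the 1D branch above is the standard weight gradient (the 2D/3D branches extend the same sum over the remaining spatial dimensions):

\[
\nabla W[g,k,c,x] = \sum_{n=0}^{N-1} \sum_{w_o=0}^{W_o-1}
\nabla Y[g,n,k,w_o] \cdot
X\!\left[g,n,c,\; w_o \cdot \text{stride}_x + x \cdot \text{dilation}_x - \text{pad}_x\right],
\]

with out-of-range input positions skipped by the bounds check on \( w_i \).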

View File

@@ -4,126 +4,493 @@
#pragma once
#include "ck/utility/type_convert.hpp"
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace ref {
// Optimized convolution kernel working with packed (contiguous) tensors
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
InDataType in_val = InDataType{0};
WeiDataType wei_val = WeiDataType{0};
OutDataType out_val = OutDataType{0};
if constexpr(NDimSpatial == 1)
{
const long_index_t num_out = G * N * K * Wo;
const long_index_t in_stride_g = N * C * Wi;
const long_index_t in_stride_n = C * Wi;
const long_index_t in_stride_c = Wi;
const long_index_t wei_stride_g = K * C * X;
const long_index_t wei_stride_k = C * X;
const long_index_t wei_stride_c = X;
const long_index_t out_stride_g = N * K * Wo;
const long_index_t out_stride_n = K * Wo;
const long_index_t out_stride_k = Wo;
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
{
index_t remaining = idx;
const index_t wo = remaining % Wo;
remaining /= Wo;
const index_t k = remaining % K;
remaining /= K;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gc = in_g + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gc[wi]);
wei_op(wei_val, wei_gkc[x]);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
}
}
else if constexpr(NDimSpatial == 2)
{
const long_index_t num_out = G * N * K * Ho * Wo;
const long_index_t in_stride_g = N * C * Hi * Wi;
const long_index_t in_stride_n = C * Hi * Wi;
const long_index_t in_stride_c = Hi * Wi;
const long_index_t in_stride_h = Wi;
const long_index_t wei_stride_g = K * C * Y * X;
const long_index_t wei_stride_k = C * Y * X;
const long_index_t wei_stride_c = Y * X;
const long_index_t wei_stride_y = X;
const long_index_t out_stride_g = N * K * Ho * Wo;
const long_index_t out_stride_n = K * Ho * Wo;
const long_index_t out_stride_k = Ho * Wo;
const long_index_t out_stride_h = Wo;
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
{
index_t remaining = idx;
const index_t wo = remaining % Wo;
remaining /= Wo;
const index_t ho = remaining % Ho;
remaining /= Ho;
const index_t k = remaining % K;
remaining /= K;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
wei_op(wei_val, wei_gkcy[x]);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
out_val;
}
}
else if constexpr(NDimSpatial == 3)
{
const long_index_t num_out = G * N * K * Do * Ho * Wo;
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
const long_index_t in_stride_n = C * Di * Hi * Wi;
const long_index_t in_stride_c = Di * Hi * Wi;
const long_index_t in_stride_d = Hi * Wi;
const long_index_t in_stride_h = Wi;
const long_index_t wei_stride_g = K * C * Z * Y * X;
const long_index_t wei_stride_k = C * Z * Y * X;
const long_index_t wei_stride_c = Z * Y * X;
const long_index_t wei_stride_z = Y * X;
const long_index_t wei_stride_y = X;
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
const long_index_t out_stride_n = K * Do * Ho * Wo;
const long_index_t out_stride_k = Do * Ho * Wo;
const long_index_t out_stride_d = Ho * Wo;
const long_index_t out_stride_h = Wo;
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
{
index_t remaining = idx;
const index_t wo = remaining % Wo;
remaining /= Wo;
const index_t ho = remaining % Ho;
remaining /= Ho;
const index_t do_idx = remaining % Do;
remaining /= Do;
const index_t k = remaining % K;
remaining /= K;
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
for(index_t z = 0; z < Z; ++z)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
wei_op(wei_val, wei_gkczy[x]);
acc += type_convert<float>(in_val) *
type_convert<float>(wei_val);
}
}
}
}
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
ho * out_stride_h + wo] = out_val;
}
}
}
// GPU reference convolution - takes ConvParam directly
template <typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
typename TWei,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_fwd(const TIn* p_in,
const TWei* p_wei,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
const index_t G = conv_param.G_;
const index_t N = conv_param.N_;
const index_t C = conv_param.C_;
const index_t K = conv_param.K_;
std::vector<index_t> in_lengths = {G, N, C};
std::vector<index_t> wei_lengths = {G, K, C};
std::vector<index_t> out_lengths = {G, N, K};
for(index_t i = 0; i < ndim; ++i)
{
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
}
// Calculate total elements for buffer allocation
long_index_t in_total = 1, wei_total = 1, out_total = 1;
for(auto l : in_lengths)
in_total *= l;
for(auto l : wei_lengths)
wei_total *= l;
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
const size_t dim_count = in_lengths.size();
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
// Pack input and weight tensors to contiguous layout
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
std::vector<index_t> conv_dilations(ndim);
std::vector<index_t> input_pads(ndim);
for(index_t i = 0; i < ndim; ++i)
{
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
}
// Run convolution kernel on packed data
const int out_grid = (out_total + block_size - 1) / block_size;
if(ndim == 1)
{
naive_conv_fwd_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
1,
1,
in_lengths[3],
1,
1,
wei_lengths[3],
1,
1,
out_lengths[3],
1,
1,
conv_strides[0],
1,
1,
conv_dilations[0],
0,
0,
input_pads[0],
in_element_op,
wei_element_op,
out_element_op);
}
else if(ndim == 2)
{
naive_conv_fwd_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
1,
in_lengths[3],
in_lengths[4],
1,
wei_lengths[3],
wei_lengths[4],
1,
out_lengths[3],
out_lengths[4],
1,
conv_strides[0],
conv_strides[1],
1,
conv_dilations[0],
conv_dilations[1],
0,
input_pads[0],
input_pads[1],
in_element_op,
wei_element_op,
out_element_op);
}
else // 3D
{
naive_conv_fwd_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
G,
N,
K,
C,
in_lengths[3],
in_lengths[4],
in_lengths[5],
wei_lengths[3],
wei_lengths[4],
wei_lengths[5],
out_lengths[3],
out_lengths[4],
out_lengths[5],
conv_strides[0],
conv_strides[1],
conv_strides[2],
conv_dilations[0],
conv_dilations[1],
conv_dilations[2],
input_pads[0],
input_pads[1],
input_pads[2],
in_element_op,
wei_element_op,
out_element_op);
}
// Unpack
strided_copy_kernel<TOut, true><<<out_grid, block_size, 0, stream>>>(
p_out_packed, p_out, d_out_lengths, d_out_strides, dim_count, out_total);
HIP_CHECK_ERROR(hipGetLastError());
// Memory automatically freed by SimpleDeviceMem destructors
}
} // namespace ref
} // namespace ck
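For context, a minimal usage sketch of the new forward entry point (illustrative only, not part of this diff; it assumes CK's GNHWC/GKYXC/GNHWK layout tags, the PassThrough element op, and ConvParam's (ndim, G, N, K, C, filter, image, strides, dilations, left_pads, right_pads) constructor):

using namespace ck::tensor_operation::device::tensor_layout::convolution;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

void run_gpu_reference_fwd()
{
    // 2D conv: G=1, N=4, K=16, C=8, 3x3 filter, 28x28 image,
    // stride 1, dilation 1, pad 1 -> 28x28 output.
    ck::utils::conv::ConvParam param(
        2, 1, 4, 16, 8, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1});

    const ck::long_index_t in_elems  = 1 * 4 * 8 * 28 * 28;
    const ck::long_index_t wei_elems = 1 * 16 * 8 * 3 * 3;
    const ck::long_index_t out_elems = 1 * 4 * 16 * 28 * 28;

    ck::ref::SimpleDeviceMem in(in_elems * sizeof(float));
    ck::ref::SimpleDeviceMem wei(wei_elems * sizeof(float));
    ck::ref::SimpleDeviceMem out(out_elems * sizeof(float));
    // ... upload input/weight data into `in` and `wei` here ...

    ck::ref::naive_conv_fwd<GNHWC, GKYXC, GNHWK>(
        static_cast<const float*>(in.GetDeviceBuffer()),
        static_cast<const float*>(wei.GetDeviceBuffer()),
        static_cast<float*>(out.GetDeviceBuffer()),
        param,
        PassThrough{},
        PassThrough{},
        PassThrough{});
    HIP_CHECK_ERROR(hipDeviceSynchronize());
}

The bwd-data and bwd-weight wrappers follow the same call shape, differing only in which of the three pointers is written.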

View File

@@ -0,0 +1,177 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include <hip/hip_runtime.h>
#include <algorithm>
#include <vector>
namespace ck {
namespace ref {
// RAII wrapper for device memory to prevent leaks
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
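// Usage sketch: SimpleDeviceMem buf(n * sizeof(float));
//              float* p = static_cast<float*>(buf.GetDeviceBuffer());
// The allocation is released when `buf` goes out of scope.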
// Helper function to map layout dimension character to index in lengths array
// lengths array structure: [G, N/K, C/K, spatial...]
inline int map_dim_char_to_index(char dim_char, index_t ndim_spatial, bool is_weight)
{
// G dimension
if(dim_char == 'G')
return 0;
// Batch/output channels dimension (N for input/output, K for weight's first non-G dim)
if(dim_char == 'N')
return 1;
if(dim_char == 'K' && is_weight)
return 1;
// Channel dimension (C for input/weight, K for output)
if(dim_char == 'C')
return 2;
if(dim_char == 'K' && !is_weight)
return 2;
// Spatial dimensions - map based on ndim_spatial
// Input/Output use: D/H/W, Weight uses: Z/Y/X
if(ndim_spatial == 1)
{
if(dim_char == 'W' || dim_char == 'X')
return 3;
}
else if(ndim_spatial == 2)
{
if(dim_char == 'H' || dim_char == 'Y')
return 3;
if(dim_char == 'W' || dim_char == 'X')
return 4;
}
else if(ndim_spatial == 3)
{
if(dim_char == 'D' || dim_char == 'Z')
return 3;
if(dim_char == 'H' || dim_char == 'Y')
return 4;
if(dim_char == 'W' || dim_char == 'X')
return 5;
}
// Should not reach here
return -1;
}
// Template function to compute layout-aware strides based on layout name
// The layout name directly encodes memory ordering from left to right
template <typename Layout>
inline std::vector<index_t> compute_conv_tensor_strides(const std::vector<index_t>& lengths,
index_t ndim_spatial)
{
constexpr const char* layout_name = Layout::name;
const int num_dims = static_cast<int>(lengths.size());
std::vector<index_t> strides(num_dims, 0);
// Determine if this is a weight tensor (has 'K' but not 'N')
bool has_k = false;
bool has_n = false;
for(const char* p = layout_name; *p != '\0'; ++p)
{
if(*p == 'K')
has_k = true;
if(*p == 'N')
has_n = true;
}
bool is_weight = has_k && !has_n;
// Build dimension ordering from layout name (parse string)
std::vector<char> dim_order;
const char dim_chars[] = {'G', 'N', 'K', 'C', 'D', 'H', 'W', 'X', 'Y', 'Z'};
for(const char* p = layout_name; *p != '\0'; ++p)
{
char c = *p;
// Skip underscores (strided layouts)
if(c == '_')
continue;
// Valid dimension characters
if(std::find(std::begin(dim_chars), std::end(dim_chars), c) != std::end(dim_chars))
{
dim_order.push_back(c);
}
}
// Compute strides: process from right to left (innermost to outermost)
index_t stride = 1;
for(int i = static_cast<int>(dim_order.size()) - 1; i >= 0; --i)
{
char dim_char = dim_order[i];
int length_idx = map_dim_char_to_index(dim_char, ndim_spatial, is_weight);
if(length_idx >= 0 && length_idx < num_dims)
{
strides[length_idx] = stride;
stride *= lengths[length_idx];
}
}
return strides;
}
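// Worked example (illustrative): Layout name "GNHWC", ndim_spatial = 2,
// lengths = [G, N, C, H, W] = [2, 4, 8, 28, 28].
// Memory order right-to-left is C, W, H, N, G, so
//   strides[C] = 1, strides[W] = 8, strides[H] = 28 * 8 = 224,
//   strides[N] = 28 * 224 = 6272, strides[G] = 4 * 6272 = 25088,
// i.e. compute_conv_tensor_strides returns {25088, 6272, 1, 224, 8}.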
// Unified kernel for strided tensor copy operations
// IsUnpack=false: Pack strided -> contiguous
// IsUnpack=true: Unpack contiguous -> strided
template <typename DataType, bool IsUnpack>
__global__ void strided_copy_kernel(const DataType* __restrict__ src,
DataType* __restrict__ dst,
const index_t* tensor_lengths,
const index_t* strided_strides,
int num_dims,
long_index_t total_elements)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
for(long_index_t linear_idx = tid; linear_idx < total_elements; linear_idx += num_threads)
{
// Compute strided index from linear index
long_index_t remaining = linear_idx;
long_index_t strided_idx = 0;
for(int dim = num_dims - 1; dim >= 0; --dim)
{
index_t coord = remaining % tensor_lengths[dim];
remaining /= tensor_lengths[dim];
strided_idx += coord * strided_strides[dim];
}
// Direction determines which is src and which is dst
if constexpr(IsUnpack)
{
// Unpack: src is contiguous (linear_idx), dst is strided (strided_idx)
dst[strided_idx] = src[linear_idx];
}
else
{
// Pack: src is strided (strided_idx), dst is contiguous (linear_idx)
dst[linear_idx] = src[strided_idx];
}
}
}
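// Illustrative pack/unpack round trip (grid = (total + 255) / 256, assuming
// device arrays d_lengths/d_strides of size num_dims are already uploaded):
//   strided_copy_kernel<float, false><<<grid, 256>>>(
//       p_strided, p_packed, d_lengths, d_strides, num_dims, total); // pack
//   strided_copy_kernel<float, true><<<grid, 256>>>(
//       p_packed, p_strided, d_lengths, d_strides, num_dims, total); // unpack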
} // namespace ref
} // namespace ck