[CK tests] Extend conv GPU reference (#3539)

* test_convnd_fwd

* test_convnd_bwd_data

* test_conv_bwd_data_scale

* test_grouped_convnd_fwd_clamp

* test_grouped_convnd_fwd_scale

* multiple A/B tensors and D tensor for fwd GPU ref

* test_grouped_convnd_fwd_scaleadd_ab

* test_grouped_convnd_fwd_bias_clamp

* test_grouped_convnd_fwd_bilinear

* test_grouped_convnd_fwd_gk_bias_clamp

* Extend GPU reference to enable batchnorm epilogue

* test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp

* test_grouped_conv_bwd_data_bilinear

* test_grouped_convnd_bwd_weight_bilinear

* Add missing template instantiation

* Perform operations in float in reference

* Slightly increase tolerance for batchnorm profiler

* Revert "Slightly increase tolerance for batchnorm profiler"

This reverts commit a3b2475229.

* Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp"

This reverts commit 6da4576060.

* Revert "Extend GPU reference to enable batchnorm epilogue"

This reverts commit e2f75fa10e.

* Clarify variable names

* Refactor elementwise ops into helper functions

* Make helpers C++17-compatible

[ROCm/composable_kernel commit: c190d8d61f]
This commit is contained in:
Johannes Graner
2026-01-27 09:49:42 +01:00
committed by GitHub
parent a7b7eae2a1
commit eb72f85509
24 changed files with 2217 additions and 473 deletions

View File

@@ -10,49 +10,55 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized backward data convolution kernel working with packed (contiguous) tensors
// Computes gradients w.r.t. input from output gradients and weights
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
// Optimized backward data convolution kernel working with packed (contiguous) tensors with
// multi-ABD support.
// Computes gradients w.r.t. input from output gradients and weights.
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (output gradient) tensors
index_t NumBExtra, // Number of extra B (weight) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
const OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in,
const WeiDataType* const* __restrict__ p_weis,
const OutDataType* const* __restrict__ p_outs,
const DDataType* const* __restrict__ p_ds,
const index_t* const* __restrict__ p_d_strides,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t x = 0; x < X; ++x)
{
@@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t wo = w_tmp / stride_x;
if(wo >= 0 && wo < Wo)
{
const OutDataType* out_gnk = out_gn;
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
// Pointers at current filter position
const OutDataType* output_grad_g_n_k = output_grad_g_n;
const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c;
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnk[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
// Handle output gradient element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
out_val,
out_op,
output_grad_g_n_k,
p_outs + 1,
g * out_stride_g + n * out_stride_n,
k * out_stride_k + wo);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_g_k_c,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
}
}
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
}
}
@@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t y = 0; y < Y; ++y)
{
@@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
// Pointers at current spatial height and filter Y position
const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h;
const WeiDataType* weight_at_c_y =
weight_g + c * wei_stride_c + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
@@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
// Handle output gradient element-wise operation with extra
// A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
out_val,
out_op,
output_grad_at_h,
p_outs + 1,
g * out_stride_g + n * out_stride_n + ho * out_stride_h,
k * out_stride_k + wo);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_c_y,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c + y * wei_stride_y,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
@@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
// Apply the input element-wise op to the float accumulator together with the D tensors.
// The last argument is the spatial offset into a D tensor, built from D tensor 0's strides
// (indices 3 and 4 are presumably the H and W strides of a [G][N][C][H][W] stride vector —
// TODO confirm against the caller that fills d_strides).
// NOTE(review): this argument expression reads p_d_strides[0][...] before the helper runs,
// regardless of NumD. The host launcher sizes the stride-pointer array as
// NumDElementwise * sizeof(index_t*), i.e. zero bytes when NumD == 0 — confirm this read is
// safe (or reliably elided) in that case. The 1D branch passes the bare index `wi` instead.
// NOTE(review): only D tensor 0's strides are used for the offset — presumably every D tensor
// shares the same spatial strides; verify.
detail::apply_d_tensor_elementwise_op<NumD>(in_val,
in_op,
acc,
p_ds,
p_d_strides,
g,
n,
c,
hi * p_d_strides[0][3] +
wi * p_d_strides[0][4]);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
in_val;
}
@@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
float acc = 0.0f;
// Base pointers for current group and batch
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
for(index_t z = 0; z < Z; ++z)
{
@@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t do_idx = d_tmp / stride_z;
if(do_idx >= 0 && do_idx < Do)
{
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
// Pointers at current spatial depth
const OutDataType* output_grad_at_d =
output_grad_g_n + do_idx * out_stride_d;
const WeiDataType* weight_at_c_z =
weight_g + c * wei_stride_c + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
@@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
long_index_t ho = h_tmp / stride_y;
if(ho >= 0 && ho < Ho)
{
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
// Pointers at current spatial depth and height
const OutDataType* output_grad_at_d_h =
output_grad_at_d + ho * out_stride_h;
const WeiDataType* weight_at_c_z_y =
weight_at_c_z + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
@@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
{
for(index_t k = 0; k < K; ++k)
{
out_op(out_val,
out_gnkdh[k * out_stride_k + wo]);
wei_op(wei_val,
wei_gkczy[k * wei_stride_k + x]);
// Handle output gradient element-wise operation
// with extra A tensors
detail::apply_multi_tensor_elementwise_op<
NumAExtra>(out_val,
out_op,
output_grad_at_d_h,
p_outs + 1,
g * out_stride_g +
n * out_stride_n +
do_idx * out_stride_d +
ho * out_stride_h,
k * out_stride_k + wo);
// Handle weight element-wise operation with
// extra B tensors
detail::apply_multi_tensor_elementwise_op<
NumBExtra>(
wei_val,
wei_op,
weight_at_c_z_y,
p_weis + 1,
g * wei_stride_g + c * wei_stride_c +
z * wei_stride_z + y * wei_stride_y,
k * wei_stride_k + x);
acc += type_convert<float>(out_val) *
type_convert<float>(wei_val);
}
@@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
}
}
InDataType result = type_convert<InDataType>(acc);
in_op(in_val, result);
// Apply the input element-wise op to the float accumulator together with the D tensors.
// The last argument is the spatial offset into a D tensor, built from D tensor 0's strides
// (indices 3, 4 and 5 are presumably the D, H and W strides of a [G][N][C][D][H][W] stride
// vector — TODO confirm against the caller that fills d_strides).
// NOTE(review): this argument expression reads p_d_strides[0][...] before the helper runs,
// regardless of NumD. The host launcher sizes the stride-pointer array as
// NumDElementwise * sizeof(index_t*), i.e. zero bytes when NumD == 0 — confirm this read is
// safe (or reliably elided) in that case. The 1D branch passes the bare index `wi` instead.
// NOTE(review): only D tensor 0's strides are used for the offset — presumably every D tensor
// shares the same spatial strides; verify.
detail::apply_d_tensor_elementwise_op<NumD>(
in_val,
in_op,
acc,
p_ds,
p_d_strides,
g,
n,
c,
di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]);
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
hi * in_stride_h + wi] = in_val;
}
}
}
// GPU reference backward data convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference backward data convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -288,15 +382,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_data(TIn* p_in,
const TWei* p_wei,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility
void naive_conv_bwd_data_multi_abd(
TIn* p_in,
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
const std::array<const TOut*, NumAElementwise + 1>& p_outs,
const std::array<const TD*, NumDElementwise>& p_ds,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in,
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
std::vector<SimpleDeviceMem> wei_packed_bufs;
wei_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
}
std::vector<SimpleDeviceMem> out_packed_bufs;
out_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
out_packed_bufs.emplace_back(out_total * sizeof(TOut));
}
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
}
std::array<TOut*, NumAElementwise + 1> p_outs_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_outs_packed[i] = static_cast<TOut*>(out_packed_bufs[i].GetDeviceBuffer());
}
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
@@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in,
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
constexpr int block_size = 256;
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total);
}
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
}
// Prepare D tensor stride arrays on device
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
TOut** d_outs_ptrs = static_cast<TOut**>(outs_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
p_weis_packed.data(),
(NumBElementwise + 1) * sizeof(TWei*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs,
p_outs_packed.data(),
(NumAElementwise + 1) * sizeof(TOut*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in,
if(ndim == 1)
{
naive_conv_bwd_data_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_bwd_data_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in,
}
else // 3D
{
naive_conv_bwd_data_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
naive_conv_bwd_data_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
p_out_packed,
d_weis_ptrs,
d_outs_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-tensor entry point for GPU reference backward data convolution.
// Wraps the weight and output-gradient pointers in one-element arrays, supplies empty
// D-tensor pointer/length/stride arrays, and forwards everything to
// naive_conv_bwd_data_multi_abd with zero extra A, B and D tensors.
//
// p_in           - destination for the computed input gradient
// p_wei          - weight tensor
// p_out          - output gradient tensor
// conv_param     - convolution problem description
// *_element_op   - element-wise operations, forwarded unchanged
// stream         - HIP stream, forwarded unchanged
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void naive_conv_bwd_data(TIn* p_in,
                                const TWei* p_wei,
                                const TOut* p_out,
                                const ck::utils::conv::ConvParam& conv_param,
                                InElementwiseOperation in_element_op   = InElementwiseOperation{},
                                WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                                OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                                hipStream_t stream                     = nullptr)
{
    // No extra A/B tensors: each array carries only the primary pointer.
    const std::array<const TWei*, 1> weight_ptrs{p_wei};
    const std::array<const TOut*, 1> output_grad_ptrs{p_out};

    // No D tensors: zero-length pointer, length and stride arrays.
    const std::array<const TIn*, 0> no_d_ptrs{};
    const std::array<std::vector<index_t>, 0> no_d_lengths{};
    const std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in,
                                                                           weight_ptrs,
                                                                           output_grad_ptrs,
                                                                           no_d_ptrs,
                                                                           conv_param,
                                                                           no_d_lengths,
                                                                           no_d_strides,
                                                                           in_element_op,
                                                                           wei_element_op,
                                                                           out_element_op,
                                                                           stream);
}
} // namespace ref
} // namespace ck

View File

@@ -10,49 +10,58 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
// Optimized backward weight convolution kernel working with packed (contiguous) tensors with
// multi-ABD support
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
// weight_grad[G][K][C][filter]
// Computes gradient with respect to weights
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (input) tensors
index_t NumBExtra, // Number of extra B (output gradient) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* __restrict__ p_out_grad,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void
naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins,
WeiDataType* __restrict__ p_wei_grad,
const OutDataType* const* __restrict__ p_out_grads,
const DDataType* const* __restrict__ p_ds,
const index_t* const* __restrict__ p_d_strides,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gn[wi]);
out_op(out_val, out_gn_k[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_n_c,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c,
wi);
// Handle output gradient element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_n_k,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k,
wo);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
}
}
@@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
// Pointers at current spatial height
const InDataType* input_at_h = input_at_n_c + hi * in_stride_h;
const OutDataType* output_grad_at_h =
output_grad_at_n_k + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
out_op(out_val, out_gn_kh[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
hi * in_stride_h,
wi);
// Handle output gradient element-wise operation with extra B
// tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_h,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k +
ho * out_stride_h,
wo);
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
}
}
@@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
// Apply the weight element-wise op to the float accumulator together with the D tensors.
// The last argument is the filter-position offset into a D tensor, built from D tensor 0's
// strides (indices 3 and 4 are presumably the Y and X strides of a [G][K][C][Y][X] stride
// vector — TODO confirm against the caller that fills d_strides).
// NOTE(review): this argument expression reads p_d_strides[0][...] before the helper runs,
// regardless of NumD. The host launcher sizes the stride-pointer array as
// NumDElementwise * sizeof(index_t*), i.e. zero bytes when NumD == 0 — confirm this read is
// safe (or reliably elided) in that case. The 1D branch passes the bare index `x` instead.
// NOTE(review): only D tensor 0's strides are used for the offset — presumably every D tensor
// shares the same filter strides; verify.
detail::apply_d_tensor_elementwise_op<NumD>(wei_val,
wei_op,
acc,
p_ds,
p_d_strides,
g,
k,
c,
y * p_d_strides[0][3] +
x * p_d_strides[0][4]);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
x] = wei_val;
}
@@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
const index_t k = remaining % K;
const index_t g = remaining / K;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g;
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
float acc = 0.0f;
// Base pointers for current group
const InDataType* input_g = p_ins[0] + g * in_stride_g;
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
// Loop over batch and output positions
for(index_t n = 0; n < N; ++n)
{
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
// Pointers at current batch and input channel
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
const OutDataType* output_grad_at_n_k =
output_grad_g + n * out_stride_n + k * out_stride_k;
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
// Pointers at current spatial depth
const InDataType* input_at_d = input_at_n_c + di * in_stride_d;
const OutDataType* output_grad_at_d =
output_grad_at_n_k + do_idx * out_stride_d;
for(index_t ho = 0; ho < Ho; ++ho)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
// Pointers at current spatial depth and height
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
const OutDataType* output_grad_at_d_h =
output_grad_at_d + ho * out_stride_h;
for(index_t wo = 0; wo < Wo; ++wo)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
out_op(out_val, out_gn_kdh[wo]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_d_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
di * in_stride_d + hi * in_stride_h,
wi);
// Handle output gradient element-wise operation with extra
// B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
out_val,
out_op,
output_grad_at_d_h,
p_out_grads + 1,
g * out_stride_g + n * out_stride_n + k * out_stride_k +
do_idx * out_stride_d + ho * out_stride_h,
wo);
acc += type_convert<float>(out_val) *
type_convert<float>(in_val);
}
@@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
}
}
WeiDataType result = type_convert<WeiDataType>(acc);
wei_op(wei_val, result);
// Apply the weight element-wise op to the float accumulator together with the D tensors.
// The last argument is the filter-position offset into a D tensor, built from D tensor 0's
// strides (indices 3, 4 and 5 are presumably the Z, Y and X strides of a [G][K][C][Z][Y][X]
// stride vector — TODO confirm against the caller that fills d_strides).
// NOTE(review): this argument expression reads p_d_strides[0][...] before the helper runs,
// regardless of NumD. The host launcher sizes the stride-pointer array as
// NumDElementwise * sizeof(index_t*), i.e. zero bytes when NumD == 0 — confirm this read is
// safe (or reliably elided) in that case. The 1D branch passes the bare index `x` instead.
// NOTE(review): only D tensor 0's strides are used for the offset — presumably every D tensor
// shares the same filter strides; verify.
detail::apply_d_tensor_elementwise_op<NumD>(
wei_val,
wei_op,
acc,
p_ds,
p_d_strides,
g,
k,
c,
z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]);
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y + x] = wei_val;
}
}
}
// GPU reference backward weight convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference backward weight convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -270,15 +370,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_bwd_weight(const TIn* p_in,
TWei* p_wei_grad,
const TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility
void naive_conv_bwd_weight_multi_abd(
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
TWei* p_wei_grad,
const std::array<const TOut*, NumBElementwise + 1>& p_outs,
const std::array<const TD*, NumDElementwise>& p_ds,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in,
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
std::vector<SimpleDeviceMem> in_packed_bufs;
in_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
}
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
std::vector<SimpleDeviceMem> out_grad_packed_bufs;
out_grad_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut));
}
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
}
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
std::array<TOut*, NumBElementwise + 1> p_out_grads_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_out_grads_packed[i] = static_cast<TOut*>(out_grad_packed_bufs[i].GetDeviceBuffer());
}
// Compute strides and allocate device arrays for pack/unpack
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
@@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in,
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
}
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TOut, false>
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_outs[i],
p_out_grads_packed[i],
d_out_lengths,
d_out_strides,
dim_count,
out_total);
}
// Prepare D tensor stride arrays on device
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
TOut** d_out_grads_ptrs = static_cast<TOut**>(out_grads_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
p_ins_packed.data(),
(NumAElementwise + 1) * sizeof(TIn*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs,
p_out_grads_packed.data(),
(NumBElementwise + 1) * sizeof(TOut*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
if(ndim == 1)
{
naive_conv_bwd_weight_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_bwd_weight_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
}
else // 3D
{
naive_conv_bwd_weight_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
naive_conv_bwd_weight_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
p_wei_grad_packed,
p_out_grad_packed,
d_out_grads_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
G,
N,
K,
@@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-tensor entry point of the backward-weight GPU reference.
//
// Keeps the original naive_conv_bwd_weight signature: the lone input and output-gradient
// pointers are wrapped in one-element arrays, the D-tensor pointer / length / stride arrays are
// left empty, and everything is forwarded to naive_conv_bwd_weight_multi_abd<0, 0, 0>.
//
//   p_in           - input tensor
//   p_wei_grad     - weight-gradient tensor written by the reference
//   p_out          - output-gradient tensor
//   conv_param     - convolution problem description
//   *_element_op   - element-wise operations, forwarded unchanged
//   stream         - HIP stream, forwarded unchanged
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void
naive_conv_bwd_weight(const TIn* p_in,
                      TWei* p_wei_grad,
                      const TOut* p_out,
                      const ck::utils::conv::ConvParam& conv_param,
                      InElementwiseOperation in_element_op   = InElementwiseOperation{},
                      WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                      OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                      hipStream_t stream                     = nullptr)
{
    // Exactly one tensor on the A side and one on the B side (no extra A/B tensors).
    const std::array<const TIn*, 1> single_in{p_in};
    const std::array<const TOut*, 1> single_out_grad{p_out};

    // No D tensors: zero-length pointer, length and stride arrays.
    // The element type of the empty pointer array fixes the D data type to TWei.
    const std::array<const TWei*, 0> no_ds{};
    const std::array<std::vector<index_t>, 0> no_d_lengths{};
    const std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(single_in,
                                                                             p_wei_grad,
                                                                             single_out_grad,
                                                                             no_ds,
                                                                             conv_param,
                                                                             no_d_lengths,
                                                                             no_d_strides,
                                                                             in_element_op,
                                                                             wei_element_op,
                                                                             out_element_op,
                                                                             stream);
}
} // namespace ref
} // namespace ck

View File

@@ -10,48 +10,56 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized convolution kernel working with packed (contiguous) tensors
// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (input) tensors
index_t NumBExtra, // Number of extra B (weight) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void naive_conv_fwd_packed_multi_abd(
const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra)
const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra)
const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers
const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gc = in_g + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gc[wi]);
wei_op(wei_val, wei_gkc[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_c,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_c,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c,
x);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
}
}
@@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
// Pointers at current spatial height and filter Y position
const InDataType* input_at_h = input_at_c + hi * in_stride_h;
const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
wei_op(wei_val, wei_gkcy[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
hi * in_stride_h,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_y,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
y * wei_stride_y,
x);
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
@@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(out_val,
out_op,
acc,
p_ds,
p_d_strides,
g,
n,
k,
ho * p_d_strides[0][3] +
wo * p_d_strides[0][4]);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
out_val;
}
@@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t n = remaining % N;
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
float acc = 0.0f;
// Base pointers for current group, batch, and output channel
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
const InDataType* in_gnc = in_gn + c * in_stride_c;
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
// Pointers at current input channel
const InDataType* input_at_c = input_g_n + c * in_stride_c;
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
for(index_t z = 0; z < Z; ++z)
{
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
if(di >= 0 && di < Di)
{
const InDataType* in_gncd = in_gnc + di * in_stride_d;
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
// Pointers at current spatial depth
const InDataType* input_at_d = input_at_c + di * in_stride_d;
const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z;
for(index_t y = 0; y < Y; ++y)
{
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
if(hi >= 0 && hi < Hi)
{
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
// Pointers at current spatial depth and height
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y;
for(index_t x = 0; x < X; ++x)
{
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
wei_op(wei_val, wei_gkczy[x]);
// Handle input element-wise operation with extra A tensors
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
in_val,
in_op,
input_at_d_h,
p_ins + 1,
g * in_stride_g + n * in_stride_n + c * in_stride_c +
di * in_stride_d + hi * in_stride_h,
wi);
// Handle weight element-wise operation with extra B tensors
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
wei_val,
wei_op,
weight_at_z_y,
p_weis + 1,
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
z * wei_stride_z + y * wei_stride_y,
x);
acc += type_convert<float>(in_val) *
type_convert<float>(wei_val);
}
@@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
detail::apply_d_tensor_elementwise_op<NumD>(
out_val,
out_op,
acc,
p_ds,
p_d_strides,
g,
n,
k,
do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]);
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
ho * out_stride_h + wo] = out_val;
}
}
}
// GPU reference convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -266,15 +357,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_fwd(const TIn* p_in,
const TWei* p_wei,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility
void naive_conv_fwd_multi_abd(
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
const std::array<const TD*, NumDElementwise>& p_ds,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in,
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
// Allocate packed buffers for all A and B tensors
// Use separate allocations to avoid copy assignment issues with RAII wrapper
std::vector<SimpleDeviceMem> in_packed_bufs;
in_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
}
std::vector<SimpleDeviceMem> wei_packed_bufs;
wei_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
}
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
// Get packed buffer pointers
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
}
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
}
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
@@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in,
// Pack input and weight tensors to contiguous layout
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
// Pack all A tensors
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
}
// Pack all B tensors
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
}
// Prepare D tensor stride arrays on device
// NOTE: D tensors are NOT packed - they are used directly with their original strides
// to support broadcasting (e.g., BiasGK layout with zero strides)
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
// Allocate and copy strides to device
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
p_ins_packed.data(),
(NumAElementwise + 1) * sizeof(TIn*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
p_weis_packed.data(),
(NumBElementwise + 1) * sizeof(TWei*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
// D tensors use original pointers (not packed) to support broadcasting
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in,
if(ndim == 1)
{
naive_conv_fwd_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_fwd_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else // 3D
{
naive_conv_fwd_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Single-tensor entry point of the forward GPU reference.
//
// Keeps the original naive_conv_fwd signature: the lone input and weight pointers are wrapped
// in one-element arrays, the D-tensor pointer / length / stride arrays are left empty, and
// everything is forwarded to naive_conv_fwd_multi_abd<0, 0, 0>.
//
//   p_in           - input tensor
//   p_wei          - weight tensor
//   p_out          - output tensor written by the reference
//   conv_param     - convolution problem description
//   *_element_op   - element-wise operations, forwarded unchanged
//   stream         - HIP stream, forwarded unchanged
template <typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          typename TIn,
          typename TWei,
          typename TOut,
          typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
inline void naive_conv_fwd(const TIn* p_in,
                           const TWei* p_wei,
                           TOut* p_out,
                           const ck::utils::conv::ConvParam& conv_param,
                           InElementwiseOperation in_element_op   = InElementwiseOperation{},
                           WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
                           OutElementwiseOperation out_element_op = OutElementwiseOperation{},
                           hipStream_t stream                     = nullptr)
{
    // Exactly one tensor on the A side and one on the B side (no extra A/B tensors).
    const std::array<const TIn*, 1> single_in{p_in};
    const std::array<const TWei*, 1> single_wei{p_wei};

    // No D tensors: zero-length pointer, length and stride arrays.
    // The element type of the empty pointer array fixes the D data type to TOut.
    const std::array<const TOut*, 0> no_ds{};
    const std::array<std::vector<index_t>, 0> no_d_lengths{};
    const std::array<std::vector<index_t>, 0> no_d_strides{};

    naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(single_in,
                                                                      single_wei,
                                                                      no_ds,
                                                                      p_out,
                                                                      conv_param,
                                                                      no_d_lengths,
                                                                      no_d_strides,
                                                                      in_element_op,
                                                                      wei_element_op,
                                                                      out_element_op,
                                                                      stream);
}
} // namespace ref
} // namespace ck

View File

@@ -22,9 +22,39 @@ struct SimpleDeviceMem
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
}
// Delete copy operations (resource should not be copied)
SimpleDeviceMem(const SimpleDeviceMem&) = delete;
SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;
// Define move operations
// Move constructor: adopts other's device pointer and resets other.p_mem_ to nullptr, so the
// moved-from object's destructor (which only frees a non-null pointer) releases nothing.
SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_)
{
// Ownership has been transferred; leave the source empty.
other.p_mem_ = nullptr;
}
// Move assignment: frees the allocation currently held (if any), then takes over other's
// device pointer and leaves other holding nullptr. Self-assignment is a no-op.
SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept
{
    if(this == &other)
    {
        return *this;
    }

    // Drop our own buffer before adopting the incoming one; the hipFree status is
    // deliberately discarded.
    if(p_mem_ != nullptr)
    {
        (void)hipFree(p_mem_);
    }

    void* const adopted = other.p_mem_;
    other.p_mem_        = nullptr;
    p_mem_              = adopted;
    return *this;
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
// Releases the owned device allocation. An empty object (p_mem_ == nullptr, e.g. after being
// moved from) is skipped; the hipFree status is deliberately discarded.
~SimpleDeviceMem()
{
    if(p_mem_ == nullptr)
    {
        return;
    }
    (void)hipFree(p_mem_);
}
void* p_mem_;
};
@@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src,
}
}
namespace detail {
// Helper for parameter pack expansion over a list of tensor pointers.
// Calls element_op(result, v0, v1, ...) where v_k = tensor_ptrs[k][element_offset] for every
// index k in the index sequence, i.e. one value is read from each tensor at the same offset.
// The index_sequence argument is unnamed; it only carries the indices Is... for the expansion.
template <typename ResultType, typename Op, typename DataType, std::size_t... Is>
__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result,
Op&& element_op,
const DataType* const* tensor_ptrs,
long_index_t element_offset,
std::index_sequence<Is...>)
{
element_op(result, tensor_ptrs[Is][element_offset]...);
}
// Applies an element-wise operation to one element of a primary tensor plus the matching
// element of each of NumExtraTensors extra tensors.
//
//   result            - receives the output of element_op
//   element_op        - called as element_op(result, primary, extra_0, ..., extra_{N-1})
//   primary_ptr       - primary tensor pointer, used as given
//   extra_ptrs        - array of NumExtraTensors extra tensor base pointers
//   extra_base_offset - offset added to every extra base pointer before indexing
//   element_offset    - index applied to the primary pointer and to every offset extra pointer
template <index_t NumExtraTensors, typename DataType, typename ResultType, typename Op>
__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result,
                                                                  Op&& element_op,
                                                                  const DataType* primary_ptr,
                                                                  const DataType* const* extra_ptrs,
                                                                  long_index_t extra_base_offset,
                                                                  long_index_t element_offset)
{
    constexpr index_t num_operands = NumExtraTensors + 1;

    // Slot 0 holds the primary tensor; slots 1..NumExtraTensors hold the shifted extras.
    const DataType* operand_ptrs[num_operands];
    operand_ptrs[0] = primary_ptr;
    static_for<0, NumExtraTensors, 1>{}([&](auto extra_idx) {
        operand_ptrs[extra_idx + 1] = extra_ptrs[extra_idx] + extra_base_offset;
    });

    apply_multi_tensor_impl(result,
                            element_op,
                            operand_ptrs,
                            element_offset,
                            std::make_index_sequence<num_operands>{});
}
// Parameter-pack helper for the D-tensor epilogue.
// Calls element_op(tmp, computed_value, d_values[0], d_values[1], ...) with a float
// temporary as destination, then converts that temporary to OutDataType and stores it in
// result_out. The index_sequence argument is unnamed; it only supplies the indices Is....
template <typename OutDataType, typename Op, std::size_t... Is>
__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out,
                                                    Op&& element_op,
                                                    float computed_value,
                                                    const float* d_values,
                                                    std::index_sequence<Is...>)
{
    // The operation runs entirely in float; only the final store narrows to OutDataType.
    float op_result_f32;
    element_op(op_result_f32, computed_value, d_values[Is]...);
    result_out = type_convert<OutDataType>(op_result_f32);
}
// Applies the output element-wise operation, optionally combined with D tensors.
//
// With NumDTensors == 0 the operation is called directly as
// element_op(result_out, computed_value); p_ds, p_d_strides and the indices are not read.
//
// Otherwise one value is loaded from every D tensor, converted to float, and the operation is
// evaluated through apply_d_tensor_impl (float result converted to OutDataType).
//
//   result_out           - receives the final value
//   element_op           - output element-wise operation
//   computed_value       - accumulated convolution result, in float
//   p_ds                 - array of NumDTensors D tensor base pointers
//   p_d_strides          - array of NumDTensors stride arrays; entries [0], [1], [2] are
//                          multiplied by g, n and c_or_k respectively
//   g, n, c_or_k         - group, batch and channel indices of the current element
//   spatial_linear_index - added to the offset as given (not scaled here), and the same value
//                          is used for every D tensor
//
// Each index * stride product is widened to long_index_t before multiplying, so the offset
// is no longer formed in 32-bit index_t arithmetic, which could overflow on large tensors.
template <index_t NumDTensors, typename DDataType, typename OutDataType, typename Op>
__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out,
                                                              Op&& element_op,
                                                              float computed_value,
                                                              const DDataType* const* p_ds,
                                                              const index_t* const* p_d_strides,
                                                              index_t g,
                                                              index_t n,
                                                              index_t c_or_k,
                                                              long_index_t spatial_linear_index)
{
    if constexpr(NumDTensors == 0)
    {
        element_op(result_out, computed_value);
    }
    else
    {
        float d_values[NumDTensors];

        // Gather one float value per D tensor at the element addressed by (g, n, c_or_k, spatial).
        static_for<0, NumDTensors, 1>{}([&](auto i) {
            const index_t* strides = p_d_strides[i];

            // Widen first, multiply second: every product is evaluated in 64 bits.
            const long_index_t d_idx = static_cast<long_index_t>(g) * strides[0] +
                                       static_cast<long_index_t>(n) * strides[1] +
                                       static_cast<long_index_t>(c_or_k) * strides[2] +
                                       spatial_linear_index;

            d_values[i] = type_convert<float>(p_ds[i][d_idx]);
        });

        apply_d_tensor_impl(result_out,
                            element_op,
                            computed_value,
                            d_values,
                            std::make_index_sequence<NumDTensors>{});
    }
}
} // namespace detail
} // namespace ref
} // namespace ck