multiple A/B tensors and D tensor for fwd GPU ref

This commit is contained in:
Graner, Johannes
2026-01-08 06:35:57 -05:00
parent 0c106d2870
commit 9e95a2a62a
5 changed files with 1039 additions and 94 deletions

View File

@@ -10,48 +10,56 @@
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <array>
namespace ck {
namespace ref {
// Optimized convolution kernel working with packed (contiguous) tensors
// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
// output[G][N][K][spatial]
template <index_t NDimSpatial,
index_t NumAExtra, // Number of extra A (input) tensors
index_t NumBExtra, // Number of extra B (weight) tensors
index_t NumD, // Number of D tensors
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename DDataType, // D tensor data type
typename InElementOp,
typename WeiElementOp,
typename OutElementOp>
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const WeiDataType* __restrict__ p_wei,
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
__global__ void naive_conv_fwd_packed_multi_abd(
const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra)
const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra)
const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers
const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays
OutDataType* __restrict__ p_out,
index_t G,
index_t N,
index_t K,
index_t C,
index_t Di,
index_t Hi,
index_t Wi,
index_t Z,
index_t Y,
index_t X,
index_t Do,
index_t Ho,
index_t Wo,
index_t stride_z,
index_t stride_y,
index_t stride_x,
index_t dilation_z,
index_t dilation_y,
index_t dilation_x,
index_t pad_z,
index_t pad_y,
index_t pad_x,
InElementOp in_op,
WeiElementOp wei_op,
OutElementOp out_op)
{
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const long_index_t num_threads = blockDim.x * gridDim.x;
@@ -84,8 +92,8 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
const InDataType* in_g = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
@@ -97,15 +105,73 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gc[wi]);
wei_op(wei_val, wei_gkc[x]);
// Handle input element-wise operation with extra A tensors
if constexpr(NumAExtra == 0)
{
in_op(in_val, in_gc[wi]);
}
else if constexpr(NumAExtra == 1)
{
const InDataType* in_extra =
p_ins[1] + g * in_stride_g + n * in_stride_n + c * in_stride_c;
in_op(in_val, in_gc[wi], in_extra[wi]);
}
else if constexpr(NumAExtra == 2)
{
const InDataType* in_extra0 =
p_ins[1] + g * in_stride_g + n * in_stride_n + c * in_stride_c;
const InDataType* in_extra1 =
p_ins[2] + g * in_stride_g + n * in_stride_n + c * in_stride_c;
in_op(in_val, in_gc[wi], in_extra0[wi], in_extra1[wi]);
}
// Handle weight element-wise operation with extra B tensors
if constexpr(NumBExtra == 0)
{
wei_op(wei_val, wei_gkc[x]);
}
else if constexpr(NumBExtra == 1)
{
const WeiDataType* wei_extra =
p_weis[1] + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c;
wei_op(wei_val, wei_gkc[x], wei_extra[x]);
}
else if constexpr(NumBExtra == 2)
{
const WeiDataType* wei_extra0 =
p_weis[1] + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c;
const WeiDataType* wei_extra1 =
p_weis[2] + g * wei_stride_g + k * wei_stride_k + c * wei_stride_c;
wei_op(wei_val, wei_gkc[x], wei_extra0[x], wei_extra1[x]);
}
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
// Handle output element-wise operation with D tensors
if constexpr(NumD == 0)
{
out_op(out_val, result);
}
else if constexpr(NumD == 1)
{
const long_index_t d_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + wo * p_d_strides[0][3];
out_op(out_val, result, p_ds[0][d_idx]);
}
else if constexpr(NumD == 2)
{
const long_index_t d0_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + wo * p_d_strides[0][3];
const long_index_t d1_idx = g * p_d_strides[1][0] + n * p_d_strides[1][1] +
k * p_d_strides[1][2] + wo * p_d_strides[1][3];
out_op(out_val, result, p_ds[0][d0_idx], p_ds[1][d1_idx]);
}
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
}
}
@@ -138,8 +204,8 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
const InDataType* in_gn = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
@@ -159,8 +225,52 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gnch[wi]);
wei_op(wei_val, wei_gkcy[x]);
// Handle input element-wise operation with extra A tensors
if constexpr(NumAExtra == 0)
{
in_op(in_val, in_gnch[wi]);
}
else if constexpr(NumAExtra == 1)
{
const InDataType* in_extra = p_ins[1] + g * in_stride_g +
n * in_stride_n + c * in_stride_c +
hi * in_stride_h;
in_op(in_val, in_gnch[wi], in_extra[wi]);
}
else if constexpr(NumAExtra == 2)
{
const InDataType* in_extra0 =
p_ins[1] + g * in_stride_g + n * in_stride_n +
c * in_stride_c + hi * in_stride_h;
const InDataType* in_extra1 =
p_ins[2] + g * in_stride_g + n * in_stride_n +
c * in_stride_c + hi * in_stride_h;
in_op(in_val, in_gnch[wi], in_extra0[wi], in_extra1[wi]);
}
// Handle weight element-wise operation with extra B tensors
if constexpr(NumBExtra == 0)
{
wei_op(wei_val, wei_gkcy[x]);
}
else if constexpr(NumBExtra == 1)
{
const WeiDataType* wei_extra =
p_weis[1] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + y * wei_stride_y;
wei_op(wei_val, wei_gkcy[x], wei_extra[x]);
}
else if constexpr(NumBExtra == 2)
{
const WeiDataType* wei_extra0 =
p_weis[1] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + y * wei_stride_y;
const WeiDataType* wei_extra1 =
p_weis[2] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + y * wei_stride_y;
wei_op(wei_val, wei_gkcy[x], wei_extra0[x], wei_extra1[x]);
}
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
}
}
@@ -169,7 +279,30 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
// Handle output element-wise operation with D tensors
if constexpr(NumD == 0)
{
out_op(out_val, result);
}
else if constexpr(NumD == 1)
{
const long_index_t d_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + ho * p_d_strides[0][3] +
wo * p_d_strides[0][4];
out_op(out_val, result, p_ds[0][d_idx]);
}
else if constexpr(NumD == 2)
{
const long_index_t d0_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + ho * p_d_strides[0][3] +
wo * p_d_strides[0][4];
const long_index_t d1_idx = g * p_d_strides[1][0] + n * p_d_strides[1][1] +
k * p_d_strides[1][2] + ho * p_d_strides[1][3] +
wo * p_d_strides[1][4];
out_op(out_val, result, p_ds[0][d0_idx], p_ds[1][d1_idx]);
}
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
out_val;
}
@@ -208,8 +341,8 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
const index_t g = remaining / N;
float acc = 0.0f;
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
const InDataType* in_gn = p_ins[0] + g * in_stride_g + n * in_stride_n;
const WeiDataType* wei_gk = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
for(index_t c = 0; c < C; ++c)
{
@@ -237,8 +370,62 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
if(wi >= 0 && wi < Wi)
{
in_op(in_val, in_gncdh[wi]);
wei_op(wei_val, wei_gkczy[x]);
// Handle input element-wise operation with extra A tensors
if constexpr(NumAExtra == 0)
{
in_op(in_val, in_gncdh[wi]);
}
else if constexpr(NumAExtra == 1)
{
const InDataType* in_extra =
p_ins[1] + g * in_stride_g + n * in_stride_n +
c * in_stride_c + di * in_stride_d +
hi * in_stride_h;
in_op(in_val, in_gncdh[wi], in_extra[wi]);
}
else if constexpr(NumAExtra == 2)
{
const InDataType* in_extra0 =
p_ins[1] + g * in_stride_g + n * in_stride_n +
c * in_stride_c + di * in_stride_d +
hi * in_stride_h;
const InDataType* in_extra1 =
p_ins[2] + g * in_stride_g + n * in_stride_n +
c * in_stride_c + di * in_stride_d +
hi * in_stride_h;
in_op(
in_val, in_gncdh[wi], in_extra0[wi], in_extra1[wi]);
}
// Handle weight element-wise operation with extra B tensors
if constexpr(NumBExtra == 0)
{
wei_op(wei_val, wei_gkczy[x]);
}
else if constexpr(NumBExtra == 1)
{
const WeiDataType* wei_extra =
p_weis[1] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y;
wei_op(wei_val, wei_gkczy[x], wei_extra[x]);
}
else if constexpr(NumBExtra == 2)
{
const WeiDataType* wei_extra0 =
p_weis[1] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y;
const WeiDataType* wei_extra1 =
p_weis[2] + g * wei_stride_g + k * wei_stride_k +
c * wei_stride_c + z * wei_stride_z +
y * wei_stride_y;
wei_op(wei_val,
wei_gkczy[x],
wei_extra0[x],
wei_extra1[x]);
}
acc += type_convert<float>(in_val) *
type_convert<float>(wei_val);
}
@@ -250,15 +437,41 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
}
OutDataType result = type_convert<OutDataType>(acc);
out_op(out_val, result);
// Handle output element-wise operation with D tensors
if constexpr(NumD == 0)
{
out_op(out_val, result);
}
else if constexpr(NumD == 1)
{
const long_index_t d_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + do_idx * p_d_strides[0][3] +
ho * p_d_strides[0][4] + wo * p_d_strides[0][5];
out_op(out_val, result, p_ds[0][d_idx]);
}
else if constexpr(NumD == 2)
{
const long_index_t d0_idx = g * p_d_strides[0][0] + n * p_d_strides[0][1] +
k * p_d_strides[0][2] + do_idx * p_d_strides[0][3] +
ho * p_d_strides[0][4] + wo * p_d_strides[0][5];
const long_index_t d1_idx = g * p_d_strides[1][0] + n * p_d_strides[1][1] +
k * p_d_strides[1][2] + do_idx * p_d_strides[1][3] +
ho * p_d_strides[1][4] + wo * p_d_strides[1][5];
out_op(out_val, result, p_ds[0][d0_idx], p_ds[1][d1_idx]);
}
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
ho * out_stride_h + wo] = out_val;
}
}
}
// GPU reference convolution - takes ConvParam directly
template <typename InLayout,
// GPU reference convolution with multi-ABD support - takes ConvParam directly
template <ck::index_t NumAElementwise = 0,
ck::index_t NumBElementwise = 0,
ck::index_t NumDElementwise = 0,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
@@ -266,15 +479,20 @@ template <typename InLayout,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
void naive_conv_fwd(const TIn* p_in,
const TWei* p_wei,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
typename OutElementwiseOperation,
typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility
void naive_conv_fwd_multi_abd(
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
const std::array<const TD*, NumDElementwise>& p_ds,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
const auto ndim = conv_param.num_dim_spatial_;
@@ -303,13 +521,37 @@ void naive_conv_fwd(const TIn* p_in,
for(auto l : out_lengths)
out_total *= l;
// Allocate packed buffers
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
// Allocate packed buffers for all A and B tensors
// Use separate allocations to avoid copy assignment issues with RAII wrapper
std::vector<SimpleDeviceMem> in_packed_bufs;
in_packed_bufs.reserve(NumAElementwise + 1);
for(index_t i = 0; i <= NumAElementwise; ++i)
{
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
}
std::vector<SimpleDeviceMem> wei_packed_bufs;
wei_packed_bufs.reserve(NumBElementwise + 1);
for(index_t i = 0; i <= NumBElementwise; ++i)
{
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
}
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
// Get packed buffer pointers
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
for(index_t i = 0; i <= NumAElementwise; ++i)
{
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
}
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
for(index_t i = 0; i <= NumBElementwise; ++i)
{
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
}
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
// Compute strides and allocate device arrays for pack/unpack
@@ -347,12 +589,82 @@ void naive_conv_fwd(const TIn* p_in,
// Pack input and weight tensors to contiguous layout
constexpr int block_size = 256;
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
// Pack all A tensors
for(index_t i = 0; i <= NumAElementwise; ++i)
{
strided_copy_kernel<TIn, false>
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
}
// Pack all B tensors
for(index_t i = 0; i <= NumBElementwise; ++i)
{
strided_copy_kernel<TWei, false>
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
}
// Prepare D tensor stride arrays on device
// NOTE: D tensors are NOT packed - they are used directly with their original strides
// to support broadcasting (e.g., BiasGK layout with zero strides)
std::vector<SimpleDeviceMem> d_stride_bufs;
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
if constexpr(NumDElementwise > 0)
{
d_stride_bufs.reserve(NumDElementwise);
for(index_t i = 0; i < NumDElementwise; ++i)
{
// Allocate and copy strides to device
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
d_strides[i].data(),
d_strides[i].size() * sizeof(index_t),
hipMemcpyHostToDevice));
}
}
// Create device arrays of pointers
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
p_ins_packed.data(),
(NumAElementwise + 1) * sizeof(TIn*),
hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
p_weis_packed.data(),
(NumBElementwise + 1) * sizeof(TWei*),
hipMemcpyHostToDevice));
if constexpr(NumDElementwise > 0)
{
// D tensors use original pointers (not packed) to support broadcasting
std::array<const TD*, NumDElementwise> p_ds_dev;
for(index_t i = 0; i < NumDElementwise; ++i)
{
p_ds_dev[i] = p_ds[i];
}
HIP_CHECK_ERROR(hipMemcpy(
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
p_d_strides_dev.data(),
NumDElementwise * sizeof(index_t*),
hipMemcpyHostToDevice));
}
// Build conv parameter vectors for kernel invocation
std::vector<index_t> conv_strides(ndim);
@@ -370,15 +682,21 @@ void naive_conv_fwd(const TIn* p_in,
if(ndim == 1)
{
naive_conv_fwd_packed<1,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<1,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -408,15 +726,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else if(ndim == 2)
{
naive_conv_fwd_packed<2,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<2,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -446,15 +770,21 @@ void naive_conv_fwd(const TIn* p_in,
}
else // 3D
{
naive_conv_fwd_packed<3,
TIn,
TWei,
TOut,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
p_wei_packed,
naive_conv_fwd_packed_multi_abd<3,
NumAElementwise,
NumBElementwise,
NumDElementwise,
TIn,
TWei,
TOut,
TD,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
d_weis_ptrs,
d_ds_ptrs,
d_d_strides_ptrs,
p_out_packed,
G,
N,
@@ -492,5 +822,43 @@ void naive_conv_fwd(const TIn* p_in,
// Memory automatically freed by SimpleDeviceMem destructors
}
// Original naive_conv_fwd - now a zero-overhead wrapper
template <typename InLayout,
typename WeiLayout,
typename OutLayout,
typename TIn,
typename TWei,
typename TOut,
typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
inline void naive_conv_fwd(const TIn* p_in,
const TWei* p_wei,
TOut* p_out,
const ck::utils::conv::ConvParam& conv_param,
InElementwiseOperation in_element_op = InElementwiseOperation{},
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
hipStream_t stream = nullptr)
{
std::array<const TIn*, 1> p_ins = {p_in};
std::array<const TWei*, 1> p_weis = {p_wei};
std::array<const TOut*, 0> p_ds = {};
std::array<std::vector<index_t>, 0> d_lengths = {};
std::array<std::vector<index_t>, 0> d_strides = {};
naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins,
p_weis,
p_ds,
p_out,
conv_param,
d_lengths,
d_strides,
in_element_op,
wei_element_op,
out_element_op,
stream);
}
} // namespace ref
} // namespace ck

View File

@@ -22,9 +22,39 @@ struct SimpleDeviceMem
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
}
// Delete copy operations (resource should not be copied)
SimpleDeviceMem(const SimpleDeviceMem&) = delete;
SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;
// Define move operations
SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_)
{
other.p_mem_ = nullptr;
}
SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept
{
if(this != &other)
{
if(p_mem_)
{
(void)hipFree(p_mem_);
}
p_mem_ = other.p_mem_;
other.p_mem_ = nullptr;
}
return *this;
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
~SimpleDeviceMem()
{
if(p_mem_)
{
(void)hipFree(p_mem_);
}
}
void* p_mem_;
};

View File

@@ -4,6 +4,9 @@
add_gtest_executable(test_gpu_reference_conv_fwd test_gpu_reference_conv_fwd.cpp)
target_link_libraries(test_gpu_reference_conv_fwd PRIVATE utility)
add_gtest_executable(test_gpu_reference_conv_fwd_multi_abd test_gpu_reference_conv_fwd_multi_abd.cpp)
target_link_libraries(test_gpu_reference_conv_fwd_multi_abd PRIVATE utility)
add_gtest_executable(test_gpu_reference_conv_bwd_data test_gpu_reference_conv_bwd_data.cpp)
target_link_libraries(test_gpu_reference_conv_bwd_data PRIVATE utility)

View File

@@ -381,5 +381,230 @@ bool test_conv_gpu_ref(const ck::utils::conv::ConvParam& params, ConvKernelType
}
}
// Forward convolution with D tensor support
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename OutElementOp>
bool test_conv_fwd_with_d_tensor_impl(const ck::utils::conv::ConvParam& params,
const Tensor<InDataType>& input_cpu,
const Tensor<WeiDataType>& weight_cpu,
const Tensor<OutDataType>& d_cpu,
DeviceMem& input_dev,
DeviceMem& weight_dev,
DeviceMem& d_dev,
DeviceMem& output_dev,
OutElementOp out_element_op)
{
using InElementOp = tensor_operation::element_wise::PassThrough;
using WeiElementOp = tensor_operation::element_wise::PassThrough;
// Create D tensor lengths and strides for GPU reference
std::vector<index_t> d_lengths_vec(NDimSpatial + 3);
d_lengths_vec[0] = params.G_;
d_lengths_vec[1] = params.N_;
d_lengths_vec[2] = params.K_;
for(index_t i = 0; i < NDimSpatial; ++i)
{
d_lengths_vec[3 + i] = static_cast<index_t>(params.output_spatial_lengths_[i]);
}
std::vector<index_t> d_strides_vec =
ref::compute_conv_tensor_strides<OutLayout>(d_lengths_vec, params.num_dim_spatial_);
std::array<const OutDataType*, 1> d_ptrs = {
reinterpret_cast<const OutDataType*>(d_dev.GetDeviceBuffer())};
std::array<std::vector<index_t>, 1> d_lengths = {d_lengths_vec};
std::array<std::vector<index_t>, 1> d_strides = {d_strides_vec};
// Call GPU reference with D tensor
std::array<const InDataType*, 1> in_ptrs = {
reinterpret_cast<const InDataType*>(input_dev.GetDeviceBuffer())};
std::array<const WeiDataType*, 1> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(weight_dev.GetDeviceBuffer())};
ref::naive_conv_fwd_multi_abd<0,
0,
1,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
OutDataType>( // Explicitly specify TD = OutDataType
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(output_dev.GetDeviceBuffer()),
params,
d_lengths,
d_strides,
InElementOp{},
WeiElementOp{},
out_element_op);
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Run CPU reference
std::vector<long_index_t> strides_long(params.conv_filter_strides_.begin(),
params.conv_filter_strides_.end());
std::vector<long_index_t> dilations_long(params.conv_filter_dilations_.begin(),
params.conv_filter_dilations_.end());
std::vector<long_index_t> pads_long(params.input_left_pads_.begin(),
params.input_left_pads_.end());
Tensor<InDataType> input_ref = input_cpu;
Tensor<WeiDataType> weight_ref = weight_cpu;
Tensor<OutDataType> output_ref(
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params));
std::array<Tensor<OutDataType>, 1> d_tensors_ref = {d_cpu};
auto ref_conv = tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
0, // NumA
0, // NumB
1 // NumD
>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_arg = ref_conv.MakeArgument(input_ref,
weight_ref,
output_ref,
strides_long,
dilations_long,
pads_long,
pads_long,
InElementOp{},
WeiElementOp{},
out_element_op,
{}, // A tensors
{}, // B tensors
d_tensors_ref);
ref_invoker.Run(ref_arg);
// Copy result from device and compare
Tensor<OutDataType> output_gpu(output_ref.mDesc);
output_dev.FromDevice(output_gpu.mData.data());
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Compare results
return ck::utils::check_err(output_gpu, output_ref);
}
// Forward convolution with multiple A/B tensor support
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename InElementOp,
typename WeiElementOp>
bool test_conv_fwd_with_multi_ab_impl(const ck::utils::conv::ConvParam& params,
const Tensor<InDataType>& input_cpu,
const Tensor<WeiDataType>& weight_cpu,
const Tensor<InDataType>& a_extra_cpu,
const Tensor<WeiDataType>& b_extra_cpu,
DeviceMem& input_dev,
DeviceMem& weight_dev,
DeviceMem& a_extra_dev,
DeviceMem& b_extra_dev,
DeviceMem& output_dev,
InElementOp in_element_op,
WeiElementOp wei_element_op)
{
using OutElementOp = tensor_operation::element_wise::PassThrough;
// Call GPU reference with extra A and B tensors
std::array<const InDataType*, 2> in_ptrs = {
reinterpret_cast<const InDataType*>(input_dev.GetDeviceBuffer()),
reinterpret_cast<const InDataType*>(a_extra_dev.GetDeviceBuffer())};
std::array<const WeiDataType*, 2> wei_ptrs = {
reinterpret_cast<const WeiDataType*>(weight_dev.GetDeviceBuffer()),
reinterpret_cast<const WeiDataType*>(b_extra_dev.GetDeviceBuffer())};
std::array<const OutDataType*, 0> d_ptrs = {};
std::array<std::vector<index_t>, 0> d_lengths = {};
std::array<std::vector<index_t>, 0> d_strides = {};
ref::naive_conv_fwd_multi_abd<1, 1, 0, InLayout, WeiLayout, OutLayout>(
in_ptrs,
wei_ptrs,
d_ptrs,
reinterpret_cast<OutDataType*>(output_dev.GetDeviceBuffer()),
params,
d_lengths,
d_strides,
in_element_op,
wei_element_op,
OutElementOp{});
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Run CPU reference
std::vector<long_index_t> strides_long(params.conv_filter_strides_.begin(),
params.conv_filter_strides_.end());
std::vector<long_index_t> dilations_long(params.conv_filter_dilations_.begin(),
params.conv_filter_dilations_.end());
std::vector<long_index_t> pads_long(params.input_left_pads_.begin(),
params.input_left_pads_.end());
Tensor<InDataType> input_ref = input_cpu;
Tensor<WeiDataType> weight_ref = weight_cpu;
Tensor<OutDataType> output_ref(
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params));
std::array<Tensor<InDataType>, 1> a_tensors_ref = {a_extra_cpu};
std::array<Tensor<WeiDataType>, 1> b_tensors_ref = {b_extra_cpu};
auto ref_conv = tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
1, // NumA
1, // NumB
0 // NumD
>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_arg = ref_conv.MakeArgument(input_ref,
weight_ref,
output_ref,
strides_long,
dilations_long,
pads_long,
pads_long,
in_element_op,
wei_element_op,
OutElementOp{},
a_tensors_ref,
b_tensors_ref,
{});
ref_invoker.Run(ref_arg);
// Copy result from device and compare
Tensor<OutDataType> output_gpu(output_ref.mDesc);
output_dev.FromDevice(output_gpu.mData.data());
HIP_CHECK_ERROR(hipDeviceSynchronize());
// Compare results
return ck::utils::check_err(output_gpu, output_ref);
}
} // namespace test
} // namespace ck

View File

@@ -0,0 +1,319 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "gpu_reference_utils.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
using namespace ck;
using ck::test::ConvKernelType;
// ==================== D Tensor (Bias) Tests ====================
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool test_conv_gpu_ref_with_bias(const ck::utils::conv::ConvParam& params)
{
using tensor_operation::element_wise::AddClamp;
// Create tensor descriptors
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
// Create tensors
Tensor<InDataType> input(in_g_n_c_wis_desc);
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> output(out_g_n_k_wos_desc);
Tensor<OutDataType> bias(out_g_n_k_wos_desc); // Same shape as output
// Allocate device memory
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
DeviceMem bias_dev(bias.mData.size() * sizeof(OutDataType));
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
// Initialize and copy tensors
test::initialize_and_copy_tensor(input, input_dev);
test::initialize_and_copy_tensor(weight, weight_dev);
test::initialize_and_copy_tensor(bias, bias_dev);
// Test with AddClamp (bias operation with clamping)
AddClamp out_element_op(0.0f, 6.0f); // Clamp between 0 and 6
return test::test_conv_fwd_with_d_tensor_impl<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
params, input, weight, bias, input_dev, weight_dev, bias_dev, output_dev, out_element_op);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bias)
{
auto params = test::conv_test_shapes::get_2d_small();
bool result = test_conv_gpu_ref_with_bias<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bias)
{
auto params = test::conv_test_shapes::get_2d_medium();
bool result = test_conv_gpu_ref_with_bias<2,
float,
float,
float,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv3DFP32Bias)
{
auto params = test::conv_test_shapes::get_3d_small();
bool result = test_conv_gpu_ref_with_bias<3,
float,
float,
float,
tensor_layout::convolution::GNCDHW,
tensor_layout::convolution::GKCZYX,
tensor_layout::convolution::GNKDHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bias)
{
auto params = test::conv_test_shapes::get_2d_grouped_g2();
bool result = test_conv_gpu_ref_with_bias<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32GroupedG4Bias)
{
auto params = test::conv_test_shapes::get_2d_grouped_g4();
bool result = test_conv_gpu_ref_with_bias<2,
float,
float,
float,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
// ==================== D Tensor (Bilinear) Tests ====================
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool test_conv_gpu_ref_with_bilinear(const ck::utils::conv::ConvParam& params)
{
using tensor_operation::element_wise::Bilinear;
// Create tensor descriptors
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
// Create tensors
Tensor<InDataType> input(in_g_n_c_wis_desc);
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> output(out_g_n_k_wos_desc);
Tensor<OutDataType> d_tensor(out_g_n_k_wos_desc); // Same shape as output
// Allocate device memory
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
DeviceMem d_dev(d_tensor.mData.size() * sizeof(OutDataType));
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
// Initialize and copy tensors
test::initialize_and_copy_tensor(input, input_dev);
test::initialize_and_copy_tensor(weight, weight_dev);
test::initialize_and_copy_tensor(d_tensor, d_dev);
// Test with Bilinear: y = alpha * conv_result + beta * d_tensor
Bilinear out_element_op(1.5f, 0.5f); // alpha=1.5, beta=0.5
return test::test_conv_fwd_with_d_tensor_impl<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
params, input, weight, d_tensor, input_dev, weight_dev, d_dev, output_dev, out_element_op);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16Bilinear)
{
auto params = test::conv_test_shapes::get_2d_small();
bool result = test_conv_gpu_ref_with_bilinear<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32Bilinear)
{
auto params = test::conv_test_shapes::get_2d_medium();
bool result = test_conv_gpu_ref_with_bilinear<2,
float,
float,
float,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2Bilinear)
{
auto params = test::conv_test_shapes::get_2d_grouped_g2();
bool result = test_conv_gpu_ref_with_bilinear<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
// ==================== Multiple A/B (ScaleAdd) Tests ====================
template <index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool test_conv_gpu_ref_with_scaleadd(const ck::utils::conv::ConvParam& params)
{
using tensor_operation::element_wise::ScaleAdd;
// Create tensor descriptors
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(params);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(params);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(params);
// Create tensors
Tensor<InDataType> input(in_g_n_c_wis_desc);
Tensor<WeiDataType> weight(wei_g_k_c_xs_desc);
Tensor<OutDataType> output(out_g_n_k_wos_desc);
Tensor<InDataType> a_extra(in_g_n_c_wis_desc); // Extra A tensor (same shape as input)
Tensor<WeiDataType> b_extra(wei_g_k_c_xs_desc); // Extra B tensor (same shape as weight)
// Allocate device memory
DeviceMem input_dev(input.mData.size() * sizeof(InDataType));
DeviceMem weight_dev(weight.mData.size() * sizeof(WeiDataType));
DeviceMem a_extra_dev(a_extra.mData.size() * sizeof(InDataType));
DeviceMem b_extra_dev(b_extra.mData.size() * sizeof(WeiDataType));
DeviceMem output_dev(output.mData.size() * sizeof(OutDataType));
// Initialize and copy tensors
test::initialize_and_copy_tensor(input, input_dev);
test::initialize_and_copy_tensor(weight, weight_dev);
test::initialize_and_copy_tensor(a_extra, a_extra_dev);
test::initialize_and_copy_tensor(b_extra, b_extra_dev);
// Test with ScaleAdd: in_out = scale * in_0 + in_1, wei_out = scale * wei_0 + wei_1
ScaleAdd in_element_op(2.0f); // scale factor for input
ScaleAdd wei_element_op(1.5f); // scale factor for weight
return test::test_conv_fwd_with_multi_ab_impl<NDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(params,
input,
weight,
a_extra,
b_extra,
input_dev,
weight_dev,
a_extra_dev,
b_extra_dev,
output_dev,
in_element_op,
wei_element_op);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16ScaleAdd)
{
auto params = test::conv_test_shapes::get_2d_small();
bool result = test_conv_gpu_ref_with_scaleadd<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP32ScaleAdd)
{
auto params = test::conv_test_shapes::get_2d_medium();
bool result = test_conv_gpu_ref_with_scaleadd<2,
float,
float,
float,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}
TEST(GpuReferenceConvFwdMultiABD, Conv2DFP16GroupedG2ScaleAdd)
{
auto params = test::conv_test_shapes::get_2d_grouped_g2();
bool result = test_conv_gpu_ref_with_scaleadd<2,
half_t,
half_t,
half_t,
tensor_layout::convolution::GNCHW,
tensor_layout::convolution::GKCYX,
tensor_layout::convolution::GNKHW>(params);
EXPECT_TRUE(result);
}