mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK tests] Extend conv GPU reference (#3539)
* test_convnd_fwd
* test_convnd_bwd_data
* test_conv_bwd_data_scale
* test_grouped_convnd_fwd_clamp
* test_grouped_convnd_fwd_scale
* multiple A/B tensors and D tensor for fwd GPU ref
* test_grouped_convnd_fwd_scaleadd_ab
* test_grouped_convnd_fwd_bias_clamp
* test_grouped_convnd_fwd_bilinear
* test_grouped_convnd_fwd_gk_bias_clamp
* Extend GPU reference to enable batchnorm epilogue
* test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp
* test_grouped_conv_bwd_data_bilinear
* test_grouped_convnd_bwd_weight_bilinear
* Add missing template instantiation
* Perform operations in float in reference
* Slightly increase tolerance for batchnorm profiler
* Revert "Slightly increase tolerance for batchnorm profiler"
This reverts commit a3b2475229.
* Revert "test_grouped_convnd_fwd{,_gk}_bias_bnorm_clamp"
This reverts commit 6da4576060.
* Revert "Extend GPU reference to enable batchnorm epilogue"
This reverts commit e2f75fa10e.
* Clarify variable names
* Refactor elementwise ops into helper functions
* Make helpers C++17-compatible
[ROCm/composable_kernel commit: c190d8d61f]
This commit is contained in:
@@ -10,49 +10,55 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized backward data convolution kernel working with packed (contiguous) tensors
|
||||
// Computes gradients w.r.t. input from output gradients and weights
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
// Optimized backward data convolution kernel working with packed (contiguous) tensors with
|
||||
// multi-ABD support Computes gradients w.r.t. input from output gradients and weights Assumes
|
||||
// row-major packing: input[G][N][C][spatial], weight[G][K][C][filter], output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (output gradient) tensors
|
||||
index_t NumBExtra, // Number of extra B (weight) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
const OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void naive_conv_bwd_data_packed_multi_abd(InDataType* __restrict__ p_in,
|
||||
const WeiDataType* const* __restrict__ p_weis,
|
||||
const OutDataType* const* __restrict__ p_outs,
|
||||
const DDataType* const* __restrict__ p_ds,
|
||||
const index_t* const* __restrict__ p_d_strides,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -84,9 +90,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -96,21 +103,39 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t wo = w_tmp / stride_x;
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
const OutDataType* out_gnk = out_gn;
|
||||
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
|
||||
// Pointers at current filter position
|
||||
const OutDataType* output_grad_g_n_k = output_grad_g_n;
|
||||
const WeiDataType* weight_g_k_c = weight_g + c * wei_stride_c;
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnk[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_g_n_k,
|
||||
p_outs + 1,
|
||||
g * out_stride_g + n * out_stride_n,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_g_k_c,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
in_val, in_op, acc, p_ds, p_d_strides, g, n, c, wi);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
|
||||
}
|
||||
}
|
||||
@@ -142,9 +167,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
@@ -154,8 +180,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
|
||||
// Pointers at current spatial height and filter Y position
|
||||
const OutDataType* output_grad_at_h = output_grad_g_n + ho * out_stride_h;
|
||||
const WeiDataType* weight_at_c_y =
|
||||
weight_g + c * wei_stride_c + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -167,8 +195,25 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation with extra
|
||||
// A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_h,
|
||||
p_outs + 1,
|
||||
g * out_stride_g + n * out_stride_n + ho * out_stride_h,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c + y * wei_stride_y,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -179,8 +224,17 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(in_val,
|
||||
in_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
c,
|
||||
hi * p_d_strides[0][3] +
|
||||
wi * p_d_strides[0][4]);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
|
||||
in_val;
|
||||
}
|
||||
@@ -218,9 +272,10 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group and batch
|
||||
const OutDataType* output_grad_g_n = p_outs[0] + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* weight_g = p_weis[0] + g * wei_stride_g;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
@@ -230,8 +285,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t do_idx = d_tmp / stride_z;
|
||||
if(do_idx >= 0 && do_idx < Do)
|
||||
{
|
||||
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
|
||||
// Pointers at current spatial depth
|
||||
const OutDataType* output_grad_at_d =
|
||||
output_grad_g_n + do_idx * out_stride_d;
|
||||
const WeiDataType* weight_at_c_z =
|
||||
weight_g + c * wei_stride_c + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
@@ -241,8 +299,11 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
// Pointers at current spatial depth and height
|
||||
const OutDataType* output_grad_at_d_h =
|
||||
output_grad_at_d + ho * out_stride_h;
|
||||
const WeiDataType* weight_at_c_z_y =
|
||||
weight_at_c_z + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
@@ -254,10 +315,31 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val,
|
||||
out_gnkdh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val,
|
||||
wei_gkczy[k * wei_stride_k + x]);
|
||||
// Handle output gradient element-wise operation
|
||||
// with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<
|
||||
NumAExtra>(out_val,
|
||||
out_op,
|
||||
output_grad_at_d_h,
|
||||
p_outs + 1,
|
||||
g * out_stride_g +
|
||||
n * out_stride_n +
|
||||
do_idx * out_stride_d +
|
||||
ho * out_stride_h,
|
||||
k * out_stride_k + wo);
|
||||
|
||||
// Handle weight element-wise operation with
|
||||
// extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<
|
||||
NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c_z_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + c * wei_stride_c +
|
||||
z * wei_stride_z + y * wei_stride_y,
|
||||
k * wei_stride_k + x);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -271,16 +353,28 @@ __global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
in_val,
|
||||
in_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
c,
|
||||
di * p_d_strides[0][3] + hi * p_d_strides[0][4] + wi * p_d_strides[0][5]);
|
||||
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
|
||||
hi * in_stride_h + wi] = in_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward data convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference backward data convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -288,15 +382,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_data(TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TIn> // D tensor type, defaults to TIn for backward compatibility
|
||||
void naive_conv_bwd_data_multi_abd(
|
||||
TIn* p_in,
|
||||
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
|
||||
const std::array<const TOut*, NumAElementwise + 1>& p_outs,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -327,12 +426,34 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
std::vector<SimpleDeviceMem> wei_packed_bufs;
|
||||
wei_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
|
||||
}
|
||||
|
||||
std::vector<SimpleDeviceMem> out_packed_bufs;
|
||||
out_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
out_packed_bufs.emplace_back(out_total * sizeof(TOut));
|
||||
}
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
|
||||
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
std::array<TOut*, NumAElementwise + 1> p_outs_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_outs_packed[i] = static_cast<TOut*>(out_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
@@ -369,12 +490,76 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_outs[i], p_outs_packed[i], d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
|
||||
SimpleDeviceMem outs_ptrs_buf((NumAElementwise + 1) * sizeof(TOut*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
|
||||
TOut** d_outs_ptrs = static_cast<TOut**>(outs_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
|
||||
p_weis_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TWei*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_outs_ptrs,
|
||||
p_outs_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TOut*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -392,16 +577,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_data_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -430,16 +621,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_data_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -468,16 +665,22 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_data_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
naive_conv_bwd_data_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
d_weis_ptrs,
|
||||
d_outs_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -514,5 +717,43 @@ void naive_conv_bwd_data(TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_bwd_data - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void naive_conv_bwd_data(TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TWei*, 1> p_weis = {p_wei};
|
||||
std::array<const TOut*, 1> p_outs = {p_out};
|
||||
std::array<const TIn*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_bwd_data_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_in,
|
||||
p_weis,
|
||||
p_outs,
|
||||
p_ds,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -10,49 +10,58 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
|
||||
// Optimized backward weight convolution kernel working with packed (contiguous) tensors with
|
||||
// multi-ABD support
|
||||
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
|
||||
// weight_grad[G][K][C][filter]
|
||||
// Computes gradient with respect to weights
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (input) tensors
|
||||
index_t NumBExtra, // Number of extra B (output gradient) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* __restrict__ p_out_grad,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void
|
||||
naive_conv_bwd_weight_packed_multi_abd(const InDataType* const* __restrict__ p_ins,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* const* __restrict__ p_out_grads,
|
||||
const DDataType* const* __restrict__ p_ds,
|
||||
const index_t* const* __restrict__ p_d_strides,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -84,30 +93,50 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gn[wi]);
|
||||
out_op(out_val, out_gn_k[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_n_c,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_n_k,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
wei_val, wei_op, acc, p_ds, p_d_strides, g, k, c, x);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
|
||||
}
|
||||
}
|
||||
@@ -139,31 +168,55 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
|
||||
// Pointers at current spatial height
|
||||
const InDataType* input_at_h = input_at_n_c + hi * in_stride_h;
|
||||
const OutDataType* output_grad_at_h =
|
||||
output_grad_at_n_k + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
out_op(out_val, out_gn_kh[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra B
|
||||
// tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_h,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k +
|
||||
ho * out_stride_h,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
@@ -171,8 +224,17 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(wei_val,
|
||||
wei_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
k,
|
||||
c,
|
||||
y * p_d_strides[0][3] +
|
||||
x * p_d_strides[0][4]);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
|
||||
x] = wei_val;
|
||||
}
|
||||
@@ -210,39 +272,65 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group
|
||||
const InDataType* input_g = p_ins[0] + g * in_stride_g;
|
||||
const OutDataType* output_grad_g = p_out_grads[0] + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
// Pointers at current batch and input channel
|
||||
const InDataType* input_at_n_c = input_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* output_grad_at_n_k =
|
||||
output_grad_g + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
|
||||
// Pointers at current spatial depth
|
||||
const InDataType* input_at_d = input_at_n_c + di * in_stride_d;
|
||||
const OutDataType* output_grad_at_d =
|
||||
output_grad_at_n_k + do_idx * out_stride_d;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
|
||||
// Pointers at current spatial depth and height
|
||||
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
|
||||
const OutDataType* output_grad_at_d_h =
|
||||
output_grad_at_d + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
out_op(out_val, out_gn_kdh[wo]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_d_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
di * in_stride_d + hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle output gradient element-wise operation with extra
|
||||
// B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
out_val,
|
||||
out_op,
|
||||
output_grad_at_d_h,
|
||||
p_out_grads + 1,
|
||||
g * out_stride_g + n * out_stride_n + k * out_stride_k +
|
||||
do_idx * out_stride_d + ho * out_stride_h,
|
||||
wo);
|
||||
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(in_val);
|
||||
}
|
||||
@@ -253,16 +341,28 @@ __global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
k,
|
||||
c,
|
||||
z * p_d_strides[0][3] + y * p_d_strides[0][4] + x * p_d_strides[0][5]);
|
||||
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
|
||||
y * wei_stride_y + x] = wei_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward weight convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference backward weight convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -270,15 +370,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_weight(const TIn* p_in,
|
||||
TWei* p_wei_grad,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TWei> // D tensor type, defaults to TWei for backward compatibility
|
||||
void naive_conv_bwd_weight_multi_abd(
|
||||
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
|
||||
TWei* p_wei_grad,
|
||||
const std::array<const TOut*, NumBElementwise + 1>& p_outs,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -308,13 +413,35 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
|
||||
std::vector<SimpleDeviceMem> in_packed_bufs;
|
||||
in_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
|
||||
}
|
||||
|
||||
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
|
||||
|
||||
std::vector<SimpleDeviceMem> out_grad_packed_bufs;
|
||||
out_grad_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
out_grad_packed_bufs.emplace_back(out_total * sizeof(TOut));
|
||||
}
|
||||
|
||||
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
|
||||
|
||||
std::array<TOut*, NumBElementwise + 1> p_out_grads_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_out_grads_packed[i] = static_cast<TOut*>(out_grad_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
@@ -351,12 +478,81 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
|
||||
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
}
|
||||
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_outs[i],
|
||||
p_out_grads_packed[i],
|
||||
d_out_lengths,
|
||||
d_out_strides,
|
||||
dim_count,
|
||||
out_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
|
||||
SimpleDeviceMem out_grads_ptrs_buf((NumBElementwise + 1) * sizeof(TOut*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
|
||||
TOut** d_out_grads_ptrs = static_cast<TOut**>(out_grads_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
|
||||
p_ins_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TIn*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_out_grads_ptrs,
|
||||
p_out_grads_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TOut*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -374,16 +570,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -412,16 +614,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -450,16 +658,22 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_weight_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
naive_conv_bwd_weight_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
d_out_grads_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -496,5 +710,44 @@ void naive_conv_bwd_weight(const TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_bwd_weight - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void
|
||||
naive_conv_bwd_weight(const TIn* p_in,
|
||||
TWei* p_wei_grad,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TIn*, 1> p_ins = {p_in};
|
||||
std::array<const TOut*, 1> p_outs = {p_out};
|
||||
std::array<const TWei*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_bwd_weight_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins,
|
||||
p_wei_grad,
|
||||
p_outs,
|
||||
p_ds,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -10,48 +10,56 @@
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include <array>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Optimized convolution kernel working with packed (contiguous) tensors
|
||||
// Optimized convolution kernel working with packed (contiguous) tensors with multi-ABD support
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
index_t NumAExtra, // Number of extra A (input) tensors
|
||||
index_t NumBExtra, // Number of extra B (weight) tensors
|
||||
index_t NumD, // Number of D tensors
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename DDataType, // D tensor data type
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
__global__ void naive_conv_fwd_packed_multi_abd(
|
||||
const InDataType* const* __restrict__ p_ins, // Array of input pointers (1 + NumAExtra)
|
||||
const WeiDataType* const* __restrict__ p_weis, // Array of weight pointers (1 + NumBExtra)
|
||||
const DDataType* const* __restrict__ p_ds, // Array of D tensor pointers
|
||||
const index_t* const* __restrict__ p_d_strides, // Array of D tensor stride arrays
|
||||
OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
@@ -83,29 +91,48 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gc = in_g + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gc[wi]);
|
||||
wei_op(wei_val, wei_gkc[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_c,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_c,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
out_val, out_op, acc, p_ds, p_d_strides, g, n, k, wo);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
|
||||
}
|
||||
}
|
||||
@@ -137,30 +164,51 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
|
||||
// Pointers at current spatial height and filter Y position
|
||||
const InDataType* input_at_h = input_at_c + hi * in_stride_h;
|
||||
const WeiDataType* weight_at_y = weight_at_c + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
wei_op(wei_val, wei_gkcy[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
|
||||
y * wei_stride_y,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
@@ -168,8 +216,17 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(out_val,
|
||||
out_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
k,
|
||||
ho * p_d_strides[0][3] +
|
||||
wo * p_d_strides[0][4]);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
|
||||
out_val;
|
||||
}
|
||||
@@ -207,38 +264,60 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
float acc = 0.0f;
|
||||
// Base pointers for current group, batch, and output channel
|
||||
const InDataType* input_g_n = p_ins[0] + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* weight_g_k = p_weis[0] + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
// Pointers at current input channel
|
||||
const InDataType* input_at_c = input_g_n + c * in_stride_c;
|
||||
const WeiDataType* weight_at_c = weight_g_k + c * wei_stride_c;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
|
||||
// Pointers at current spatial depth
|
||||
const InDataType* input_at_d = input_at_c + di * in_stride_d;
|
||||
const WeiDataType* weight_at_z = weight_at_c + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
// Pointers at current spatial depth and height
|
||||
const InDataType* input_at_d_h = input_at_d + hi * in_stride_h;
|
||||
const WeiDataType* weight_at_z_y = weight_at_z + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
wei_op(wei_val, wei_gkczy[x]);
|
||||
// Handle input element-wise operation with extra A tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumAExtra>(
|
||||
in_val,
|
||||
in_op,
|
||||
input_at_d_h,
|
||||
p_ins + 1,
|
||||
g * in_stride_g + n * in_stride_n + c * in_stride_c +
|
||||
di * in_stride_d + hi * in_stride_h,
|
||||
wi);
|
||||
|
||||
// Handle weight element-wise operation with extra B tensors
|
||||
detail::apply_multi_tensor_elementwise_op<NumBExtra>(
|
||||
wei_val,
|
||||
wei_op,
|
||||
weight_at_z_y,
|
||||
p_weis + 1,
|
||||
g * wei_stride_g + k * wei_stride_k + c * wei_stride_c +
|
||||
z * wei_stride_z + y * wei_stride_y,
|
||||
x);
|
||||
|
||||
acc += type_convert<float>(in_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
@@ -249,16 +328,28 @@ __global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
detail::apply_d_tensor_elementwise_op<NumD>(
|
||||
out_val,
|
||||
out_op,
|
||||
acc,
|
||||
p_ds,
|
||||
p_d_strides,
|
||||
g,
|
||||
n,
|
||||
k,
|
||||
do_idx * p_d_strides[0][3] + ho * p_d_strides[0][4] + wo * p_d_strides[0][5]);
|
||||
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
|
||||
ho * out_stride_h + wo] = out_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
// GPU reference convolution with multi-ABD support - takes ConvParam directly
|
||||
template <ck::index_t NumAElementwise = 0,
|
||||
ck::index_t NumBElementwise = 0,
|
||||
ck::index_t NumDElementwise = 0,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
@@ -266,15 +357,20 @@ template <typename InLayout,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_fwd(const TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
typename OutElementwiseOperation,
|
||||
typename TD = TOut> // D tensor type, defaults to TOut for backward compatibility
|
||||
void naive_conv_fwd_multi_abd(
|
||||
const std::array<const TIn*, NumAElementwise + 1>& p_ins,
|
||||
const std::array<const TWei*, NumBElementwise + 1>& p_weis,
|
||||
const std::array<const TD*, NumDElementwise>& p_ds,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
[[maybe_unused]] const std::array<std::vector<index_t>, NumDElementwise>& d_lengths,
|
||||
const std::array<std::vector<index_t>, NumDElementwise>& d_strides,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
@@ -303,13 +399,37 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
for(auto l : out_lengths)
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
// Allocate packed buffers for all A and B tensors
|
||||
// Use separate allocations to avoid copy assignment issues with RAII wrapper
|
||||
std::vector<SimpleDeviceMem> in_packed_bufs;
|
||||
in_packed_bufs.reserve(NumAElementwise + 1);
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
in_packed_bufs.emplace_back(in_total * sizeof(TIn));
|
||||
}
|
||||
|
||||
std::vector<SimpleDeviceMem> wei_packed_bufs;
|
||||
wei_packed_bufs.reserve(NumBElementwise + 1);
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
wei_packed_bufs.emplace_back(wei_total * sizeof(TWei));
|
||||
}
|
||||
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
// Get packed buffer pointers
|
||||
std::array<TIn*, NumAElementwise + 1> p_ins_packed;
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
p_ins_packed[i] = static_cast<TIn*>(in_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
std::array<TWei*, NumBElementwise + 1> p_weis_packed;
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
p_weis_packed[i] = static_cast<TWei*>(wei_packed_bufs[i].GetDeviceBuffer());
|
||||
}
|
||||
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
@@ -347,12 +467,82 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
|
||||
// Pack input and weight tensors to contiguous layout
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
// Pack all A tensors
|
||||
for(index_t i = 0; i <= NumAElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_ins[i], p_ins_packed[i], d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
}
|
||||
|
||||
// Pack all B tensors
|
||||
for(index_t i = 0; i <= NumBElementwise; ++i)
|
||||
{
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_weis[i], p_weis_packed[i], d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
}
|
||||
|
||||
// Prepare D tensor stride arrays on device
|
||||
// NOTE: D tensors are NOT packed - they are used directly with their original strides
|
||||
// to support broadcasting (e.g., BiasGK layout with zero strides)
|
||||
std::vector<SimpleDeviceMem> d_stride_bufs;
|
||||
std::array<index_t*, NumDElementwise> p_d_strides_dev = {};
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
d_stride_bufs.reserve(NumDElementwise);
|
||||
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
// Allocate and copy strides to device
|
||||
d_stride_bufs.emplace_back(d_strides[i].size() * sizeof(index_t));
|
||||
p_d_strides_dev[i] = static_cast<index_t*>(d_stride_bufs[i].GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(p_d_strides_dev[i],
|
||||
d_strides[i].data(),
|
||||
d_strides[i].size() * sizeof(index_t),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
}
|
||||
|
||||
// Create device arrays of pointers
|
||||
SimpleDeviceMem ins_ptrs_buf((NumAElementwise + 1) * sizeof(TIn*));
|
||||
SimpleDeviceMem weis_ptrs_buf((NumBElementwise + 1) * sizeof(TWei*));
|
||||
SimpleDeviceMem ds_ptrs_buf(NumDElementwise * sizeof(TD*));
|
||||
SimpleDeviceMem d_strides_ptrs_buf(NumDElementwise * sizeof(index_t*));
|
||||
|
||||
TIn** d_ins_ptrs = static_cast<TIn**>(ins_ptrs_buf.GetDeviceBuffer());
|
||||
TWei** d_weis_ptrs = static_cast<TWei**>(weis_ptrs_buf.GetDeviceBuffer());
|
||||
TD** d_ds_ptrs = static_cast<TD**>(ds_ptrs_buf.GetDeviceBuffer());
|
||||
index_t** d_d_strides_ptrs = static_cast<index_t**>(d_strides_ptrs_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_ins_ptrs,
|
||||
p_ins_packed.data(),
|
||||
(NumAElementwise + 1) * sizeof(TIn*),
|
||||
hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_weis_ptrs,
|
||||
p_weis_packed.data(),
|
||||
(NumBElementwise + 1) * sizeof(TWei*),
|
||||
hipMemcpyHostToDevice));
|
||||
|
||||
if constexpr(NumDElementwise > 0)
|
||||
{
|
||||
// D tensors use original pointers (not packed) to support broadcasting
|
||||
std::array<const TD*, NumDElementwise> p_ds_dev;
|
||||
for(index_t i = 0; i < NumDElementwise; ++i)
|
||||
{
|
||||
p_ds_dev[i] = p_ds[i];
|
||||
}
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_ds_ptrs, p_ds_dev.data(), NumDElementwise * sizeof(TD*), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(d_d_strides_ptrs,
|
||||
p_d_strides_dev.data(),
|
||||
NumDElementwise * sizeof(index_t*),
|
||||
hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
@@ -370,15 +560,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_fwd_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<1,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -408,15 +604,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_fwd_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<2,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -446,15 +648,21 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_fwd_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
naive_conv_fwd_packed_multi_abd<3,
|
||||
NumAElementwise,
|
||||
NumBElementwise,
|
||||
NumDElementwise,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
TD,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(d_ins_ptrs,
|
||||
d_weis_ptrs,
|
||||
d_ds_ptrs,
|
||||
d_d_strides_ptrs,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
@@ -492,5 +700,43 @@ void naive_conv_fwd(const TIn* p_in,
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
// Original naive_conv_fwd - now a zero-overhead wrapper
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
inline void naive_conv_fwd(const TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
std::array<const TIn*, 1> p_ins = {p_in};
|
||||
std::array<const TWei*, 1> p_weis = {p_wei};
|
||||
std::array<const TOut*, 0> p_ds = {};
|
||||
std::array<std::vector<index_t>, 0> d_lengths = {};
|
||||
std::array<std::vector<index_t>, 0> d_strides = {};
|
||||
|
||||
naive_conv_fwd_multi_abd<0, 0, 0, InLayout, WeiLayout, OutLayout>(p_ins,
|
||||
p_weis,
|
||||
p_ds,
|
||||
p_out,
|
||||
conv_param,
|
||||
d_lengths,
|
||||
d_strides,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
stream);
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -22,9 +22,39 @@ struct SimpleDeviceMem
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
|
||||
}
|
||||
|
||||
// Delete copy operations (resource should not be copied)
|
||||
SimpleDeviceMem(const SimpleDeviceMem&) = delete;
|
||||
SimpleDeviceMem& operator=(const SimpleDeviceMem&) = delete;
|
||||
|
||||
// Define move operations
|
||||
SimpleDeviceMem(SimpleDeviceMem&& other) noexcept : p_mem_(other.p_mem_)
|
||||
{
|
||||
other.p_mem_ = nullptr;
|
||||
}
|
||||
|
||||
SimpleDeviceMem& operator=(SimpleDeviceMem&& other) noexcept
|
||||
{
|
||||
if(this != &other)
|
||||
{
|
||||
if(p_mem_)
|
||||
{
|
||||
(void)hipFree(p_mem_);
|
||||
}
|
||||
p_mem_ = other.p_mem_;
|
||||
other.p_mem_ = nullptr;
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
~SimpleDeviceMem()
|
||||
{
|
||||
if(p_mem_)
|
||||
{
|
||||
(void)hipFree(p_mem_);
|
||||
}
|
||||
}
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
@@ -173,5 +203,90 @@ __global__ void strided_copy_kernel(const DataType* __restrict__ src,
|
||||
}
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Helper for parameter pack expansion (D tensors)
|
||||
template <typename ResultType, typename Op, typename DataType, std::size_t... Is>
|
||||
__device__ __forceinline__ void apply_multi_tensor_impl(ResultType& result,
|
||||
Op&& element_op,
|
||||
const DataType* const* tensor_ptrs,
|
||||
long_index_t element_offset,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
element_op(result, tensor_ptrs[Is][element_offset]...);
|
||||
}
|
||||
|
||||
// Generic helper for A and B tensors (works in all directions)
|
||||
template <index_t NumExtraTensors, typename DataType, typename ResultType, typename Op>
|
||||
__device__ __forceinline__ void apply_multi_tensor_elementwise_op(ResultType& result,
|
||||
Op&& element_op,
|
||||
const DataType* primary_ptr,
|
||||
const DataType* const* extra_ptrs,
|
||||
long_index_t extra_base_offset,
|
||||
long_index_t element_offset)
|
||||
{
|
||||
const DataType* tensor_ptrs[NumExtraTensors + 1];
|
||||
tensor_ptrs[0] = primary_ptr;
|
||||
|
||||
static_for<1, NumExtraTensors + 1, 1>{}(
|
||||
[&](auto i) { tensor_ptrs[i] = extra_ptrs[i - 1] + extra_base_offset; });
|
||||
|
||||
apply_multi_tensor_impl(result,
|
||||
element_op,
|
||||
tensor_ptrs,
|
||||
element_offset,
|
||||
std::make_index_sequence<NumExtraTensors + 1>{});
|
||||
}
|
||||
|
||||
// Helper for parameter pack expansion (D tensors)
|
||||
template <typename OutDataType, typename Op, std::size_t... Is>
|
||||
__device__ __forceinline__ void apply_d_tensor_impl(OutDataType& result_out,
|
||||
Op&& element_op,
|
||||
float computed_value,
|
||||
const float* d_values,
|
||||
std::index_sequence<Is...>)
|
||||
{
|
||||
float temp_out;
|
||||
element_op(temp_out, computed_value, d_values[Is]...);
|
||||
result_out = type_convert<OutDataType>(temp_out);
|
||||
}
|
||||
|
||||
// Specialized helper for D tensors with stride calculations and float conversion
|
||||
template <index_t NumDTensors, typename DDataType, typename OutDataType, typename Op>
|
||||
__device__ __forceinline__ void apply_d_tensor_elementwise_op(OutDataType& result_out,
|
||||
Op&& element_op,
|
||||
float computed_value,
|
||||
const DDataType* const* p_ds,
|
||||
const index_t* const* p_d_strides,
|
||||
index_t g,
|
||||
index_t n,
|
||||
index_t c_or_k,
|
||||
long_index_t spatial_linear_index)
|
||||
{
|
||||
if constexpr(NumDTensors == 0)
|
||||
{
|
||||
element_op(result_out, computed_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
float d_values[NumDTensors];
|
||||
|
||||
// Compute all D tensor indices and convert to float
|
||||
static_for<0, NumDTensors, 1>{}([&](auto i) {
|
||||
const long_index_t d_idx = g * p_d_strides[i][0] + n * p_d_strides[i][1] +
|
||||
c_or_k * p_d_strides[i][2] + spatial_linear_index;
|
||||
d_values[i] = type_convert<float>(p_ds[i][d_idx]);
|
||||
});
|
||||
|
||||
apply_d_tensor_impl(result_out,
|
||||
element_op,
|
||||
computed_value,
|
||||
d_values,
|
||||
std::make_index_sequence<NumDTensors>{});
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user