mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK] Integrate GPU reference into ckProfiler for convolutions (#3379)
Refactor and integrate CK GPU references into ckProfiler. - All convolution layouts and groupings supported for all three directions - Unit tests verifying GPU and CPU reference is the same - Support added to profiler (do_verification = 2 enables GPU reference) - One profiler-based test per direction changed to GPU reference to demonstrate usag Closes AICK-427
This commit is contained in:
@@ -1,73 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#ifndef CONV_COMMON_HPP
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// Device-compatible dimension structure for GPU reference kernels
|
||||
// Replaces passing 24 individual parameters
|
||||
struct ConvDims
|
||||
{
|
||||
index_t N, K, C;
|
||||
index_t Di, Hi, Wi;
|
||||
index_t Z, Y, X;
|
||||
index_t Do, Ho, Wo;
|
||||
index_t stride_z, stride_y, stride_x;
|
||||
index_t dilation_z, dilation_y, dilation_x;
|
||||
index_t pad_z, pad_y, pad_x;
|
||||
};
|
||||
|
||||
} // namespace ref
|
||||
|
||||
// Helper function to extract dimensions from ConvParam for GPU kernels
|
||||
// Defined in ck::utils::conv namespace for convenience
|
||||
namespace utils {
|
||||
namespace conv {
|
||||
|
||||
inline ck::ref::ConvDims
|
||||
extract_conv_dims(const ConvParam& conv_param, ck::index_t NDimSpatial, bool apply_group = true)
|
||||
{
|
||||
ck::ref::ConvDims dims;
|
||||
dims.N = conv_param.N_;
|
||||
dims.K = conv_param.K_;
|
||||
dims.C = apply_group ? (conv_param.C_ * conv_param.G_) : conv_param.C_;
|
||||
|
||||
dims.Di = (NDimSpatial >= 3) ? conv_param.input_spatial_lengths_[0] : 1;
|
||||
dims.Hi = (NDimSpatial >= 2) ? conv_param.input_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
|
||||
dims.Wi = conv_param.input_spatial_lengths_[NDimSpatial - 1];
|
||||
|
||||
dims.Z = (NDimSpatial >= 3) ? conv_param.filter_spatial_lengths_[0] : 1;
|
||||
dims.Y = (NDimSpatial >= 2) ? conv_param.filter_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
|
||||
dims.X = conv_param.filter_spatial_lengths_[NDimSpatial - 1];
|
||||
|
||||
dims.Do = (NDimSpatial >= 3) ? conv_param.output_spatial_lengths_[0] : 1;
|
||||
dims.Ho = (NDimSpatial >= 2) ? conv_param.output_spatial_lengths_[NDimSpatial >= 3 ? 1 : 0] : 1;
|
||||
dims.Wo = conv_param.output_spatial_lengths_[NDimSpatial - 1];
|
||||
|
||||
dims.stride_z = (NDimSpatial >= 3) ? conv_param.conv_filter_strides_[0] : 1;
|
||||
dims.stride_y =
|
||||
(NDimSpatial >= 2) ? conv_param.conv_filter_strides_[NDimSpatial >= 3 ? 1 : 0] : 1;
|
||||
dims.stride_x = conv_param.conv_filter_strides_[NDimSpatial - 1];
|
||||
|
||||
dims.dilation_z = (NDimSpatial >= 3) ? conv_param.conv_filter_dilations_[0] : 1;
|
||||
dims.dilation_y =
|
||||
(NDimSpatial >= 2) ? conv_param.conv_filter_dilations_[NDimSpatial >= 3 ? 1 : 0] : 1;
|
||||
dims.dilation_x = conv_param.conv_filter_dilations_[NDimSpatial - 1];
|
||||
|
||||
dims.pad_z = (NDimSpatial >= 3) ? conv_param.input_left_pads_[0] : 0;
|
||||
dims.pad_y = (NDimSpatial >= 2) ? conv_param.input_left_pads_[NDimSpatial >= 3 ? 1 : 0] : 0;
|
||||
dims.pad_x = conv_param.input_left_pads_[NDimSpatial - 1];
|
||||
|
||||
return dims;
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -4,146 +4,515 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
/*
|
||||
* \brief naive implementation of 3D convolution backward data.
|
||||
* Layout is (NDHWC, KZYXC, NDHWK).
|
||||
* Computes gradient with respect to input.
|
||||
*
|
||||
* \param N number of batches
|
||||
* \param K number of filters (output channels)
|
||||
* \param C number of input channels
|
||||
* \param (Di, Hi, Wi) depth, height and width dimension of input
|
||||
* \param (Z, Y, X) depth, height and width dimensions of filter
|
||||
* \param (Do, Ho, Wo) depth, height and width dimension of output
|
||||
* \param (stride_z, stride_y, stride_x) strides
|
||||
* \param (dilation_z, dilation_y, dilation_x) dilations
|
||||
* \param (pad_z, pad_y, pad_x) pads
|
||||
*/
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename TAcc,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
__global__ void naive_conv_bwd_data_ndhwc_kzyxc_ndhwk(TIn* __restrict__ p_in_grad,
|
||||
const TWei* __restrict__ p_wei,
|
||||
const TOut* __restrict__ p_out_grad,
|
||||
const ConvDims dims)
|
||||
// Optimized backward data convolution kernel working with packed (contiguous) tensors
|
||||
// Computes gradients w.r.t. input from output gradients and weights
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_data_packed(InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
const OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const index_t num_threads = blockDim.x * gridDim.x;
|
||||
const long_index_t input_length = dims.N * dims.Di * dims.Hi * dims.Wi * dims.C;
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
const index_t in_strides[] = {
|
||||
dims.Di * dims.Hi * dims.Wi * dims.C, dims.Hi * dims.Wi * dims.C, dims.Wi * dims.C, dims.C};
|
||||
const index_t out_strides[] = {
|
||||
dims.Do * dims.Ho * dims.Wo * dims.K, dims.Ho * dims.Wo * dims.K, dims.Wo * dims.K, dims.K};
|
||||
const index_t wei_strides[] = {
|
||||
dims.Z * dims.Y * dims.X * dims.C, dims.Y * dims.X * dims.C, dims.X * dims.C, dims.C};
|
||||
InDataType in_val = InDataType{0};
|
||||
WeiDataType wei_val = WeiDataType{0};
|
||||
OutDataType out_val = OutDataType{0};
|
||||
|
||||
constexpr auto in_op = InElementwiseOperation{};
|
||||
constexpr auto wei_op = WeiElementwiseOperation{};
|
||||
constexpr auto out_op = OutElementwiseOperation{};
|
||||
|
||||
TIn in_val = TIn{0};
|
||||
TWei wei_val = TWei{0};
|
||||
TOut out_val = TOut{0};
|
||||
|
||||
for(long_index_t ii = tid; ii < input_length; ii += num_threads)
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
// Decode linear index to (n, di, hi, wi, c)
|
||||
const index_t n = ii / in_strides[0];
|
||||
index_t tmp = ii - n * in_strides[0];
|
||||
const index_t di = tmp / in_strides[1];
|
||||
tmp -= di * in_strides[1];
|
||||
const index_t hi = tmp / in_strides[2];
|
||||
tmp -= hi * in_strides[2];
|
||||
const index_t wi = tmp / in_strides[3];
|
||||
tmp -= wi * in_strides[3];
|
||||
const index_t c = tmp;
|
||||
const long_index_t num_in = G * N * C * Wi;
|
||||
const long_index_t out_stride_g = N * K * Wo;
|
||||
const long_index_t out_stride_n = K * Wo;
|
||||
const long_index_t out_stride_k = Wo;
|
||||
const long_index_t wei_stride_g = K * C * X;
|
||||
const long_index_t wei_stride_k = C * X;
|
||||
const long_index_t wei_stride_c = X;
|
||||
const long_index_t in_stride_g = N * C * Wi;
|
||||
const long_index_t in_stride_n = C * Wi;
|
||||
const long_index_t in_stride_c = Wi;
|
||||
|
||||
// Always accumulate in float
|
||||
float acc_float = 0.0f;
|
||||
|
||||
const TOut* out_n = p_out_grad + static_cast<long_index_t>(n) * out_strides[0];
|
||||
|
||||
// Loop over output channels
|
||||
for(index_t k = 0; k < dims.K; ++k)
|
||||
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
|
||||
{
|
||||
const TWei* wei_k = p_wei + static_cast<long_index_t>(k) * wei_strides[0];
|
||||
index_t remaining = idx;
|
||||
const index_t wi = remaining % Wi;
|
||||
remaining /= Wi;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
// Loop over filter dimensions
|
||||
for(index_t z = 0; z < dims.Z; ++z)
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Calculate output position from input position (inverse of forward)
|
||||
index_t d_tmp = di + dims.pad_z - z * dims.dilation_z;
|
||||
if(d_tmp % dims.stride_z != 0)
|
||||
continue;
|
||||
index_t d_o = d_tmp / dims.stride_z;
|
||||
if(d_o < 0 || d_o >= dims.Do)
|
||||
continue;
|
||||
|
||||
const TOut* out_n_do = out_n + d_o * out_strides[1];
|
||||
const TWei* wei_k_z = wei_k + z * wei_strides[1];
|
||||
|
||||
for(index_t y = 0; y < dims.Y; ++y)
|
||||
long_index_t w_tmp = wi + pad_x - x * dilation_x;
|
||||
if(w_tmp % stride_x == 0)
|
||||
{
|
||||
index_t h_tmp = hi + dims.pad_y - y * dims.dilation_y;
|
||||
if(h_tmp % dims.stride_y != 0)
|
||||
continue;
|
||||
index_t ho = h_tmp / dims.stride_y;
|
||||
if(ho < 0 || ho >= dims.Ho)
|
||||
continue;
|
||||
|
||||
const TOut* out_n_do_ho = out_n_do + ho * out_strides[2];
|
||||
const TWei* wei_k_z_y = wei_k_z + y * wei_strides[2];
|
||||
|
||||
for(index_t x = 0; x < dims.X; ++x)
|
||||
long_index_t wo = w_tmp / stride_x;
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
index_t w_tmp = wi + dims.pad_x - x * dims.dilation_x;
|
||||
if(w_tmp % dims.stride_x != 0)
|
||||
continue;
|
||||
index_t wo = w_tmp / dims.stride_x;
|
||||
if(wo < 0 || wo >= dims.Wo)
|
||||
continue;
|
||||
const OutDataType* out_gnk = out_gn;
|
||||
const WeiDataType* wei_gkc = wei_g + c * wei_stride_c;
|
||||
|
||||
const TOut* out_n_do_ho_wo = out_n_do_ho + wo * out_strides[3];
|
||||
const TWei* wei_k_z_y_x = wei_k_z_y + x * wei_strides[3];
|
||||
|
||||
// Load values from memory
|
||||
TOut out_loaded = out_n_do_ho_wo[k];
|
||||
TWei wei_loaded = wei_k_z_y_x[c];
|
||||
|
||||
// Apply element-wise operations (like forward does)
|
||||
out_op(out_val, out_loaded);
|
||||
wei_op(wei_val, wei_loaded);
|
||||
|
||||
// Convert to float for multiplication
|
||||
float out_f = type_convert<float>(out_val);
|
||||
float wei_f = type_convert<float>(wei_val);
|
||||
|
||||
acc_float += out_f * wei_f;
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnk[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkc[k * wei_stride_k + x]);
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + wi] = in_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const long_index_t num_in = G * N * C * Hi * Wi;
|
||||
const long_index_t out_stride_g = N * K * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Ho * Wo;
|
||||
const long_index_t out_stride_k = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
const long_index_t wei_stride_g = K * C * Y * X;
|
||||
const long_index_t wei_stride_k = C * Y * X;
|
||||
const long_index_t wei_stride_c = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
const long_index_t in_stride_g = N * C * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Hi * Wi;
|
||||
const long_index_t in_stride_c = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
|
||||
// Convert float accumulator to TAcc, then to input type
|
||||
TAcc acc = type_convert<TAcc>(acc_float);
|
||||
TIn result = type_convert<TIn>(acc);
|
||||
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t wi = remaining % Wi;
|
||||
remaining /= Wi;
|
||||
const index_t hi = remaining % Hi;
|
||||
remaining /= Hi;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
// Apply input element-wise operation (if any)
|
||||
in_op(in_val, result);
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
|
||||
// Write transformed result
|
||||
p_in_grad[ii] = in_val;
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t h_tmp = hi + pad_y - y * dilation_y;
|
||||
if(h_tmp % stride_y == 0)
|
||||
{
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkh = out_gn + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_g + c * wei_stride_c + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t w_tmp = wi + pad_x - x * dilation_x;
|
||||
if(w_tmp % stride_x == 0)
|
||||
{
|
||||
long_index_t wo = w_tmp / stride_x;
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val, out_gnkh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val, wei_gkcy[k * wei_stride_k + x]);
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + hi * in_stride_h + wi] =
|
||||
in_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const long_index_t num_in = G * N * C * Di * Hi * Wi;
|
||||
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_k = Do * Ho * Wo;
|
||||
const long_index_t out_stride_d = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
const long_index_t wei_stride_g = K * C * Z * Y * X;
|
||||
const long_index_t wei_stride_k = C * Z * Y * X;
|
||||
const long_index_t wei_stride_c = Z * Y * X;
|
||||
const long_index_t wei_stride_z = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_c = Di * Hi * Wi;
|
||||
const long_index_t in_stride_d = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
|
||||
for(long_index_t idx = tid; idx < num_in; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t wi = remaining % Wi;
|
||||
remaining /= Wi;
|
||||
const index_t hi = remaining % Hi;
|
||||
remaining /= Hi;
|
||||
const index_t di = remaining % Di;
|
||||
remaining /= Di;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const OutDataType* out_gn = p_out + g * out_stride_g + n * out_stride_n;
|
||||
const WeiDataType* wei_g = p_wei + g * wei_stride_g;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
long_index_t d_tmp = di + pad_z - z * dilation_z;
|
||||
if(d_tmp % stride_z == 0)
|
||||
{
|
||||
long_index_t do_idx = d_tmp / stride_z;
|
||||
if(do_idx >= 0 && do_idx < Do)
|
||||
{
|
||||
const OutDataType* out_gnkd = out_gn + do_idx * out_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_g + c * wei_stride_c + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t h_tmp = hi + pad_y - y * dilation_y;
|
||||
if(h_tmp % stride_y == 0)
|
||||
{
|
||||
long_index_t ho = h_tmp / stride_y;
|
||||
if(ho >= 0 && ho < Ho)
|
||||
{
|
||||
const OutDataType* out_gnkdh = out_gnkd + ho * out_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t w_tmp = wi + pad_x - x * dilation_x;
|
||||
if(w_tmp % stride_x == 0)
|
||||
{
|
||||
long_index_t wo = w_tmp / stride_x;
|
||||
if(wo >= 0 && wo < Wo)
|
||||
{
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
out_op(out_val,
|
||||
out_gnkdh[k * out_stride_k + wo]);
|
||||
wei_op(wei_val,
|
||||
wei_gkczy[k * wei_stride_k + x]);
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
InDataType result = type_convert<InDataType>(acc);
|
||||
in_op(in_val, result);
|
||||
p_in[g * in_stride_g + n * in_stride_n + c * in_stride_c + di * in_stride_d +
|
||||
hi * in_stride_h + wi] = in_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward data convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_data(TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
const index_t G = conv_param.G_;
|
||||
const index_t N = conv_param.N_;
|
||||
const index_t C = conv_param.C_;
|
||||
const index_t K = conv_param.K_;
|
||||
|
||||
std::vector<index_t> in_lengths = {G, N, C};
|
||||
std::vector<index_t> wei_lengths = {G, K, C};
|
||||
std::vector<index_t> out_lengths = {G, N, K};
|
||||
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
|
||||
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
|
||||
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
|
||||
}
|
||||
|
||||
// Calculate total elements for buffer allocation
|
||||
long_index_t in_total = 1, wei_total = 1, out_total = 1;
|
||||
for(auto l : in_lengths)
|
||||
in_total *= l;
|
||||
for(auto l : wei_lengths)
|
||||
wei_total *= l;
|
||||
for(auto l : out_lengths)
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
|
||||
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
|
||||
|
||||
const size_t dim_count = in_lengths.size();
|
||||
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
|
||||
|
||||
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
|
||||
// Pack output and weight tensors to contiguous layout (inputs to bwd data)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
std::vector<index_t> conv_dilations(ndim);
|
||||
std::vector<index_t> input_pads(ndim);
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
|
||||
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
|
||||
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
|
||||
}
|
||||
|
||||
// Run backward data convolution kernel on packed data
|
||||
const int in_grid = (in_total + block_size - 1) / block_size;
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_data_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
1,
|
||||
in_lengths[3],
|
||||
1,
|
||||
1,
|
||||
wei_lengths[3],
|
||||
1,
|
||||
1,
|
||||
out_lengths[3],
|
||||
1,
|
||||
1,
|
||||
conv_strides[0],
|
||||
1,
|
||||
1,
|
||||
conv_dilations[0],
|
||||
0,
|
||||
0,
|
||||
input_pads[0],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_data_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
1,
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
1,
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
1,
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
1,
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
0,
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_data_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<in_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
in_lengths[5],
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
wei_lengths[5],
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
out_lengths[5],
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
conv_strides[2],
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
conv_dilations[2],
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
input_pads[2],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
// Unpack result back to strided layout
|
||||
strided_copy_kernel<TIn, true><<<in_grid, block_size, 0, stream>>>(
|
||||
p_in_packed, p_in, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -4,136 +4,497 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
/*
|
||||
* \brief naive implementation of 3D convolution backward weight.
|
||||
* Layout is (NDHWC, KZYXC, NDHWK).
|
||||
* Computes gradient with respect to weights.
|
||||
*
|
||||
* \param N number of batches
|
||||
* \param K number of filters (output channels)
|
||||
* \param C number of input channels
|
||||
* \param (Di, Hi, Wi) depth, height and width dimension of input
|
||||
* \param (Z, Y, X) depth, height and width dimensions of filter
|
||||
* \param (Do, Ho, Wo) depth, height and width dimension of output
|
||||
* \param (stride_z, stride_y, stride_x) strides
|
||||
* \param (dilation_z, dilation_y, dilation_x) dilations
|
||||
* \param (pad_z, pad_y, pad_x) pads
|
||||
*/
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename TAcc,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
__global__ void naive_conv_bwd_weight_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in,
|
||||
TWei* __restrict__ p_wei_grad,
|
||||
const TOut* __restrict__ p_out_grad,
|
||||
const ConvDims dims)
|
||||
// Optimized backward weight convolution kernel working with packed (contiguous) tensors
|
||||
// Assumes row-major packing: input[G][N][C][spatial], output_grad[G][N][K][spatial],
|
||||
// weight_grad[G][K][C][filter]
|
||||
// Computes gradient with respect to weights
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_bwd_weight_packed(const InDataType* __restrict__ p_in,
|
||||
WeiDataType* __restrict__ p_wei_grad,
|
||||
const OutDataType* __restrict__ p_out_grad,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const index_t num_threads = blockDim.x * gridDim.x;
|
||||
const long_index_t weight_length = dims.K * dims.Z * dims.Y * dims.X * dims.C;
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
const index_t in_strides[] = {
|
||||
dims.Di * dims.Hi * dims.Wi * dims.C, dims.Hi * dims.Wi * dims.C, dims.Wi * dims.C, dims.C};
|
||||
const index_t out_strides[] = {
|
||||
dims.Do * dims.Ho * dims.Wo * dims.K, dims.Ho * dims.Wo * dims.K, dims.Wo * dims.K, dims.K};
|
||||
const index_t wei_strides[] = {
|
||||
dims.Z * dims.Y * dims.X * dims.C, dims.Y * dims.X * dims.C, dims.X * dims.C, dims.C};
|
||||
InDataType in_val = InDataType{0};
|
||||
WeiDataType wei_val = WeiDataType{0};
|
||||
OutDataType out_val = OutDataType{0};
|
||||
|
||||
constexpr auto in_op = InElementwiseOperation{};
|
||||
constexpr auto wei_op = WeiElementwiseOperation{};
|
||||
constexpr auto out_op = OutElementwiseOperation{};
|
||||
|
||||
TIn in_val = TIn{0};
|
||||
TWei wei_val = TWei{0};
|
||||
TOut out_val = TOut{0};
|
||||
|
||||
for(long_index_t ii = tid; ii < weight_length; ii += num_threads)
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
// Decode linear index to (k, z, y, x, c)
|
||||
const index_t k = ii / wei_strides[0];
|
||||
index_t tmp = ii - k * wei_strides[0];
|
||||
const index_t z = tmp / wei_strides[1];
|
||||
tmp -= z * wei_strides[1];
|
||||
const index_t y = tmp / wei_strides[2];
|
||||
tmp -= y * wei_strides[2];
|
||||
const index_t x = tmp / wei_strides[3];
|
||||
tmp -= x * wei_strides[3];
|
||||
const index_t c = tmp;
|
||||
const long_index_t num_wei = G * K * C * X;
|
||||
const long_index_t in_stride_g = N * C * Wi;
|
||||
const long_index_t in_stride_n = C * Wi;
|
||||
const long_index_t in_stride_c = Wi;
|
||||
const long_index_t out_stride_g = N * K * Wo;
|
||||
const long_index_t out_stride_n = K * Wo;
|
||||
const long_index_t out_stride_k = Wo;
|
||||
const long_index_t wei_stride_g = K * C * X;
|
||||
const long_index_t wei_stride_k = C * X;
|
||||
const long_index_t wei_stride_c = X;
|
||||
|
||||
// Always accumulate in float
|
||||
float acc_float = 0.0f;
|
||||
|
||||
// Loop over batch
|
||||
for(index_t n = 0; n < dims.N; ++n)
|
||||
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
|
||||
{
|
||||
const TIn* in_n = p_in + static_cast<long_index_t>(n) * in_strides[0];
|
||||
const TOut* out_n = p_out_grad + static_cast<long_index_t>(n) * out_strides[0];
|
||||
index_t remaining = idx;
|
||||
const index_t x = remaining % X;
|
||||
remaining /= X;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
// Loop over output spatial dimensions
|
||||
for(index_t d_o = 0; d_o < dims.Do; ++d_o)
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
// Calculate input position from output position
|
||||
index_t di = d_o * dims.stride_z - dims.pad_z + z * dims.dilation_z;
|
||||
if(di < 0 || di >= dims.Di)
|
||||
continue;
|
||||
const InDataType* in_gn = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
const TIn* in_n_di = in_n + di * in_strides[1];
|
||||
const TOut* out_n_do = out_n + d_o * out_strides[1];
|
||||
|
||||
for(index_t ho = 0; ho < dims.Ho; ++ho)
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
index_t hi = ho * dims.stride_y - dims.pad_y + y * dims.dilation_y;
|
||||
if(hi < 0 || hi >= dims.Hi)
|
||||
continue;
|
||||
|
||||
const TIn* in_n_di_hi = in_n_di + hi * in_strides[2];
|
||||
const TOut* out_n_do_ho = out_n_do + ho * out_strides[2];
|
||||
|
||||
for(index_t wo = 0; wo < dims.Wo; ++wo)
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
index_t wi = wo * dims.stride_x - dims.pad_x + x * dims.dilation_x;
|
||||
if(wi < 0 || wi >= dims.Wi)
|
||||
continue;
|
||||
|
||||
// Load values from memory (like forward does)
|
||||
const TIn* in_ptr = in_n_di_hi + wi * in_strides[3];
|
||||
const TOut* out_ptr = out_n_do_ho + wo * out_strides[3];
|
||||
|
||||
TIn in_loaded = in_ptr[c];
|
||||
TOut out_loaded = out_ptr[k];
|
||||
|
||||
// Apply element-wise operations
|
||||
in_op(in_val, in_loaded);
|
||||
out_op(out_val, out_loaded);
|
||||
|
||||
// Convert to float for multiplication
|
||||
float in_f = type_convert<float>(in_val);
|
||||
float out_f = type_convert<float>(out_val);
|
||||
|
||||
acc_float += out_f * in_f;
|
||||
in_op(in_val, in_gn[wi]);
|
||||
out_op(out_val, out_gn_k[wo]);
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + x] = wei_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const long_index_t num_wei = G * K * C * Y * X;
|
||||
const long_index_t in_stride_g = N * C * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Hi * Wi;
|
||||
const long_index_t in_stride_c = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
const long_index_t out_stride_g = N * K * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Ho * Wo;
|
||||
const long_index_t out_stride_k = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
const long_index_t wei_stride_g = K * C * Y * X;
|
||||
const long_index_t wei_stride_k = C * Y * X;
|
||||
const long_index_t wei_stride_c = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
|
||||
// Convert float accumulator to TAcc, then to weight type
|
||||
TAcc acc = type_convert<TAcc>(acc_float);
|
||||
TWei result = type_convert<TWei>(acc);
|
||||
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t x = remaining % X;
|
||||
remaining /= X;
|
||||
const index_t y = remaining % Y;
|
||||
remaining /= Y;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
// Apply weight element-wise operation (if any)
|
||||
wei_op(wei_val, result);
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
|
||||
// Write transformed result
|
||||
p_wei_grad[ii] = wei_val;
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kh = out_gn_k + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
out_op(out_val, out_gn_kh[wo]);
|
||||
acc += type_convert<float>(out_val) * type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + y * wei_stride_y +
|
||||
x] = wei_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const long_index_t num_wei = G * K * C * Z * Y * X;
|
||||
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_c = Di * Hi * Wi;
|
||||
const long_index_t in_stride_d = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_k = Do * Ho * Wo;
|
||||
const long_index_t out_stride_d = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
const long_index_t wei_stride_g = K * C * Z * Y * X;
|
||||
const long_index_t wei_stride_k = C * Z * Y * X;
|
||||
const long_index_t wei_stride_c = Z * Y * X;
|
||||
const long_index_t wei_stride_z = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
|
||||
for(long_index_t idx = tid; idx < num_wei; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t x = remaining % X;
|
||||
remaining /= X;
|
||||
const index_t y = remaining % Y;
|
||||
remaining /= Y;
|
||||
const index_t z = remaining % Z;
|
||||
remaining /= Z;
|
||||
const index_t c = remaining % C;
|
||||
remaining /= C;
|
||||
const index_t k = remaining % K;
|
||||
const index_t g = remaining / K;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g;
|
||||
const OutDataType* out_grad = p_out_grad + g * out_stride_g;
|
||||
|
||||
// Loop over batch and output positions
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const InDataType* in_gnc = in_g + n * in_stride_n + c * in_stride_c;
|
||||
const OutDataType* out_gn_k = out_grad + n * out_stride_n + k * out_stride_k;
|
||||
|
||||
for(index_t do_idx = 0; do_idx < Do; ++do_idx)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const OutDataType* out_gn_kd = out_gn_k + do_idx * out_stride_d;
|
||||
|
||||
for(index_t ho = 0; ho < Ho; ++ho)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const OutDataType* out_gn_kdh = out_gn_kd + ho * out_stride_h;
|
||||
|
||||
for(index_t wo = 0; wo < Wo; ++wo)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
out_op(out_val, out_gn_kdh[wo]);
|
||||
acc += type_convert<float>(out_val) *
|
||||
type_convert<float>(in_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WeiDataType result = type_convert<WeiDataType>(acc);
|
||||
wei_op(wei_val, result);
|
||||
p_wei_grad[g * wei_stride_g + k * wei_stride_k + c * wei_stride_c + z * wei_stride_z +
|
||||
y * wei_stride_y + x] = wei_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference backward weight convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_bwd_weight(const TIn* p_in,
|
||||
TWei* p_wei_grad,
|
||||
const TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
const index_t G = conv_param.G_;
|
||||
const index_t N = conv_param.N_;
|
||||
const index_t C = conv_param.C_;
|
||||
const index_t K = conv_param.K_;
|
||||
|
||||
std::vector<index_t> in_lengths = {G, N, C};
|
||||
std::vector<index_t> wei_lengths = {G, K, C};
|
||||
std::vector<index_t> out_lengths = {G, N, K};
|
||||
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
|
||||
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
|
||||
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
|
||||
}
|
||||
|
||||
// Calculate total elements for buffer allocation
|
||||
long_index_t in_total = 1, wei_total = 1, out_total = 1;
|
||||
for(auto l : in_lengths)
|
||||
in_total *= l;
|
||||
for(auto l : wei_lengths)
|
||||
wei_total *= l;
|
||||
for(auto l : out_lengths)
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_grad_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_grad_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_grad_packed = static_cast<TWei*>(wei_grad_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_grad_packed = static_cast<TOut*>(out_grad_packed_buf.GetDeviceBuffer());
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
|
||||
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
|
||||
|
||||
const size_t dim_count = in_lengths.size();
|
||||
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
|
||||
|
||||
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
|
||||
// Pack input and output_grad tensors to contiguous layout (inputs to bwd weight)
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TOut, false>
|
||||
<<<(out_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_out, p_out_grad_packed, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
std::vector<index_t> conv_dilations(ndim);
|
||||
std::vector<index_t> input_pads(ndim);
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
|
||||
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
|
||||
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
|
||||
}
|
||||
|
||||
// Run backward weight convolution kernel on packed data
|
||||
const int wei_grid = (wei_total + block_size - 1) / block_size;
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
1,
|
||||
in_lengths[3],
|
||||
1,
|
||||
1,
|
||||
wei_lengths[3],
|
||||
1,
|
||||
1,
|
||||
out_lengths[3],
|
||||
1,
|
||||
1,
|
||||
conv_strides[0],
|
||||
1,
|
||||
1,
|
||||
conv_dilations[0],
|
||||
0,
|
||||
0,
|
||||
input_pads[0],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_bwd_weight_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
1,
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
1,
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
1,
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
1,
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
0,
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_bwd_weight_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<wei_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_grad_packed,
|
||||
p_out_grad_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
in_lengths[5],
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
wei_lengths[5],
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
out_lengths[5],
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
conv_strides[2],
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
conv_dilations[2],
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
input_pads[2],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
// Unpack weight gradient
|
||||
strided_copy_kernel<TWei, true><<<wei_grid, block_size, 0, stream>>>(
|
||||
p_wei_grad_packed, p_wei_grad, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -4,126 +4,493 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/conv_common.hpp"
|
||||
#include "ck/host_utility/hip_check_error.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
/*
|
||||
* \brief naive implementation of 3D convolution. Layout is (NDHWC, KZYXC, NDHWK).
|
||||
*
|
||||
* \param N number of batches
|
||||
* \param K number of filters
|
||||
* \param C number of channels of weight
|
||||
* \param (Di, Hi, Wi) depth, height and width dimension of data
|
||||
* \param (Z, Y, X) depth, height and width dimensions of weights
|
||||
* \param (Do, Ho, Wo) depth, height and width dimension of output
|
||||
* \param (stride_z, stride_y, stride_x) strides
|
||||
* \param (dilation_z, dilation_y, dilation_x) dilations
|
||||
* \param (pad_z, pad_y, pad_x) pads
|
||||
*/
|
||||
template <typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename TAcc,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
__global__ void naive_conv_fwd_ndhwc_kzyxc_ndhwk(const TIn* __restrict__ p_in,
|
||||
const TWei* __restrict__ p_wei,
|
||||
TOut* __restrict__ p_out,
|
||||
const ConvDims dims)
|
||||
// Optimized convolution kernel working with packed (contiguous) tensors
|
||||
// Assumes row-major packing: input[G][N][C][spatial], weight[G][K][C][filter],
|
||||
// output[G][N][K][spatial]
|
||||
template <index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InElementOp,
|
||||
typename WeiElementOp,
|
||||
typename OutElementOp>
|
||||
__global__ void naive_conv_fwd_packed(const InDataType* __restrict__ p_in,
|
||||
const WeiDataType* __restrict__ p_wei,
|
||||
OutDataType* __restrict__ p_out,
|
||||
index_t G,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t C,
|
||||
index_t Di,
|
||||
index_t Hi,
|
||||
index_t Wi,
|
||||
index_t Z,
|
||||
index_t Y,
|
||||
index_t X,
|
||||
index_t Do,
|
||||
index_t Ho,
|
||||
index_t Wo,
|
||||
index_t stride_z,
|
||||
index_t stride_y,
|
||||
index_t stride_x,
|
||||
index_t dilation_z,
|
||||
index_t dilation_y,
|
||||
index_t dilation_x,
|
||||
index_t pad_z,
|
||||
index_t pad_y,
|
||||
index_t pad_x,
|
||||
InElementOp in_op,
|
||||
WeiElementOp wei_op,
|
||||
OutElementOp out_op)
|
||||
{
|
||||
const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const index_t num_threads = blockDim.x * gridDim.x;
|
||||
const long_index_t output_length = dims.N * dims.Do * dims.Ho * dims.Wo * dims.K;
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
const index_t out_strides[] = {
|
||||
dims.Do * dims.Ho * dims.Wo * dims.K, dims.Ho * dims.Wo * dims.K, dims.Wo * dims.K, dims.K};
|
||||
const index_t in_strides[] = {
|
||||
dims.Di * dims.Hi * dims.Wi * dims.C, dims.Hi * dims.Wi * dims.C, dims.Wi * dims.C, dims.C};
|
||||
const index_t wei_strides[] = {
|
||||
dims.Z * dims.Y * dims.X * dims.C, dims.Y * dims.X * dims.C, dims.X * dims.C, dims.C};
|
||||
InDataType in_val = InDataType{0};
|
||||
WeiDataType wei_val = WeiDataType{0};
|
||||
OutDataType out_val = OutDataType{0};
|
||||
|
||||
constexpr auto in_op = InElementwiseOperation{};
|
||||
constexpr auto wei_op = WeiElementwiseOperation{};
|
||||
constexpr auto out_op = OutElementwiseOperation{};
|
||||
|
||||
TIn in_val = TIn{0};
|
||||
TWei wei_val = TWei{0};
|
||||
TOut out_val = TOut{0};
|
||||
|
||||
for(long_index_t ii = tid; ii < output_length; ii += num_threads)
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
const index_t n = ii / out_strides[0];
|
||||
index_t k = ii - n * out_strides[0];
|
||||
const index_t dO = k / out_strides[1];
|
||||
k -= dO * out_strides[1];
|
||||
const index_t ho = k / out_strides[2];
|
||||
k -= ho * out_strides[2];
|
||||
const index_t wo = k / out_strides[3];
|
||||
k -= wo * out_strides[3];
|
||||
const long_index_t num_out = G * N * K * Wo;
|
||||
const long_index_t in_stride_g = N * C * Wi;
|
||||
const long_index_t in_stride_n = C * Wi;
|
||||
const long_index_t in_stride_c = Wi;
|
||||
const long_index_t wei_stride_g = K * C * X;
|
||||
const long_index_t wei_stride_k = C * X;
|
||||
const long_index_t wei_stride_c = X;
|
||||
const long_index_t out_stride_g = N * K * Wo;
|
||||
const long_index_t out_stride_n = K * Wo;
|
||||
const long_index_t out_stride_k = Wo;
|
||||
|
||||
// Always accumulate in float (FP8/BF8 don't support arithmetic)
|
||||
float acc_float = 0.0f;
|
||||
|
||||
const TIn* in_n = p_in + static_cast<long_index_t>(n) * in_strides[0];
|
||||
const TWei* wei_k = p_wei + static_cast<long_index_t>(k) * wei_strides[0];
|
||||
|
||||
for(index_t z = 0; z < dims.Z; ++z)
|
||||
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
|
||||
{
|
||||
index_t di = dims.stride_z * dO - dims.pad_z + dims.dilation_z * z;
|
||||
const TIn* in_n_di = in_n + di * in_strides[1];
|
||||
const TWei* wei_k_z = wei_k + z * wei_strides[1];
|
||||
index_t remaining = idx;
|
||||
const index_t wo = remaining % Wo;
|
||||
remaining /= Wo;
|
||||
const index_t k = remaining % K;
|
||||
remaining /= K;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
for(index_t y = 0; y < dims.Y; ++y)
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_g = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
index_t hi = dims.stride_y * ho - dims.pad_y + dims.dilation_y * y;
|
||||
const TIn* in_n_di_hi = in_n_di + hi * in_strides[2];
|
||||
const TWei* wei_k_z_y = wei_k_z + y * wei_strides[2];
|
||||
const InDataType* in_gc = in_g + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
|
||||
for(index_t x = 0; x < dims.X; ++x)
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
index_t wi = dims.stride_x * wo - dims.pad_x + dims.dilation_x * x;
|
||||
const TIn* in_n_di_hi_wi = in_n_di_hi + wi * in_strides[3];
|
||||
const TWei* wei_k_z_y_x = wei_k_z_y + x * wei_strides[3];
|
||||
|
||||
if(di >= 0 && di < dims.Di && hi >= 0 && hi < dims.Hi && wi >= 0 &&
|
||||
wi < dims.Wi)
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
for(index_t c = 0; c < dims.C; ++c)
|
||||
in_op(in_val, in_gc[wi]);
|
||||
wei_op(wei_val, wei_gkc[x]);
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + wo] = out_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const long_index_t num_out = G * N * K * Ho * Wo;
|
||||
const long_index_t in_stride_g = N * C * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Hi * Wi;
|
||||
const long_index_t in_stride_c = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
const long_index_t wei_stride_g = K * C * Y * X;
|
||||
const long_index_t wei_stride_k = C * Y * X;
|
||||
const long_index_t wei_stride_c = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
const long_index_t out_stride_g = N * K * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Ho * Wo;
|
||||
const long_index_t out_stride_k = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
|
||||
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t wo = remaining % Wo;
|
||||
remaining /= Wo;
|
||||
const index_t ho = remaining % Ho;
|
||||
remaining /= Ho;
|
||||
const index_t k = remaining % K;
|
||||
remaining /= K;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gnch = in_gnc + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkcy = wei_gkc + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
// Load values from memory
|
||||
TIn in_loaded = in_n_di_hi_wi[c];
|
||||
TWei wei_loaded = wei_k_z_y_x[c];
|
||||
|
||||
// Apply element-wise operations
|
||||
in_op(in_val, in_loaded);
|
||||
wei_op(wei_val, wei_loaded);
|
||||
|
||||
// Always convert to float for multiplication (FP8/BF8 don't support
|
||||
// direct arithmetic)
|
||||
float in_f = type_convert<float>(in_val);
|
||||
float wei_f = type_convert<float>(wei_val);
|
||||
|
||||
// Accumulate in float
|
||||
acc_float += in_f * wei_f;
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gnch[wi]);
|
||||
wei_op(wei_val, wei_gkcy[x]);
|
||||
acc += type_convert<float>(in_val) * type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + ho * out_stride_h + wo] =
|
||||
out_val;
|
||||
}
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const long_index_t num_out = G * N * K * Do * Ho * Wo;
|
||||
const long_index_t in_stride_g = N * C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_n = C * Di * Hi * Wi;
|
||||
const long_index_t in_stride_c = Di * Hi * Wi;
|
||||
const long_index_t in_stride_d = Hi * Wi;
|
||||
const long_index_t in_stride_h = Wi;
|
||||
const long_index_t wei_stride_g = K * C * Z * Y * X;
|
||||
const long_index_t wei_stride_k = C * Z * Y * X;
|
||||
const long_index_t wei_stride_c = Z * Y * X;
|
||||
const long_index_t wei_stride_z = Y * X;
|
||||
const long_index_t wei_stride_y = X;
|
||||
const long_index_t out_stride_g = N * K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_n = K * Do * Ho * Wo;
|
||||
const long_index_t out_stride_k = Do * Ho * Wo;
|
||||
const long_index_t out_stride_d = Ho * Wo;
|
||||
const long_index_t out_stride_h = Wo;
|
||||
|
||||
// Convert float accumulator to TAcc, then to output type
|
||||
TAcc acc = type_convert<TAcc>(acc_float);
|
||||
TOut result = type_convert<TOut>(acc);
|
||||
for(long_index_t idx = tid; idx < num_out; idx += num_threads)
|
||||
{
|
||||
index_t remaining = idx;
|
||||
const index_t wo = remaining % Wo;
|
||||
remaining /= Wo;
|
||||
const index_t ho = remaining % Ho;
|
||||
remaining /= Ho;
|
||||
const index_t do_idx = remaining % Do;
|
||||
remaining /= Do;
|
||||
const index_t k = remaining % K;
|
||||
remaining /= K;
|
||||
const index_t n = remaining % N;
|
||||
const index_t g = remaining / N;
|
||||
|
||||
// Apply output element-wise operation (if any)
|
||||
out_op(out_val, result);
|
||||
float acc = 0.0f;
|
||||
const InDataType* in_gn = p_in + g * in_stride_g + n * in_stride_n;
|
||||
const WeiDataType* wei_gk = p_wei + g * wei_stride_g + k * wei_stride_k;
|
||||
|
||||
// Write transformed result
|
||||
p_out[ii] = out_val;
|
||||
for(index_t c = 0; c < C; ++c)
|
||||
{
|
||||
const InDataType* in_gnc = in_gn + c * in_stride_c;
|
||||
const WeiDataType* wei_gkc = wei_gk + c * wei_stride_c;
|
||||
|
||||
for(index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
long_index_t di = do_idx * stride_z + z * dilation_z - pad_z;
|
||||
if(di >= 0 && di < Di)
|
||||
{
|
||||
const InDataType* in_gncd = in_gnc + di * in_stride_d;
|
||||
const WeiDataType* wei_gkcz = wei_gkc + z * wei_stride_z;
|
||||
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
long_index_t hi = ho * stride_y + y * dilation_y - pad_y;
|
||||
if(hi >= 0 && hi < Hi)
|
||||
{
|
||||
const InDataType* in_gncdh = in_gncd + hi * in_stride_h;
|
||||
const WeiDataType* wei_gkczy = wei_gkcz + y * wei_stride_y;
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
long_index_t wi = wo * stride_x + x * dilation_x - pad_x;
|
||||
if(wi >= 0 && wi < Wi)
|
||||
{
|
||||
in_op(in_val, in_gncdh[wi]);
|
||||
wei_op(wei_val, wei_gkczy[x]);
|
||||
acc += type_convert<float>(in_val) *
|
||||
type_convert<float>(wei_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
OutDataType result = type_convert<OutDataType>(acc);
|
||||
out_op(out_val, result);
|
||||
p_out[g * out_stride_g + n * out_stride_n + k * out_stride_k + do_idx * out_stride_d +
|
||||
ho * out_stride_h + wo] = out_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GPU reference convolution - takes ConvParam directly
|
||||
template <typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename TIn,
|
||||
typename TWei,
|
||||
typename TOut,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
void naive_conv_fwd(const TIn* p_in,
|
||||
const TWei* p_wei,
|
||||
TOut* p_out,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
InElementwiseOperation in_element_op = InElementwiseOperation{},
|
||||
WeiElementwiseOperation wei_element_op = WeiElementwiseOperation{},
|
||||
OutElementwiseOperation out_element_op = OutElementwiseOperation{},
|
||||
hipStream_t stream = nullptr)
|
||||
{
|
||||
const auto ndim = conv_param.num_dim_spatial_;
|
||||
|
||||
const index_t G = conv_param.G_;
|
||||
const index_t N = conv_param.N_;
|
||||
const index_t C = conv_param.C_;
|
||||
const index_t K = conv_param.K_;
|
||||
|
||||
std::vector<index_t> in_lengths = {G, N, C};
|
||||
std::vector<index_t> wei_lengths = {G, K, C};
|
||||
std::vector<index_t> out_lengths = {G, N, K};
|
||||
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
in_lengths.push_back(static_cast<index_t>(conv_param.input_spatial_lengths_[i]));
|
||||
wei_lengths.push_back(static_cast<index_t>(conv_param.filter_spatial_lengths_[i]));
|
||||
out_lengths.push_back(static_cast<index_t>(conv_param.output_spatial_lengths_[i]));
|
||||
}
|
||||
|
||||
// Calculate total elements for buffer allocation
|
||||
long_index_t in_total = 1, wei_total = 1, out_total = 1;
|
||||
for(auto l : in_lengths)
|
||||
in_total *= l;
|
||||
for(auto l : wei_lengths)
|
||||
wei_total *= l;
|
||||
for(auto l : out_lengths)
|
||||
out_total *= l;
|
||||
|
||||
// Allocate packed buffers
|
||||
SimpleDeviceMem in_packed_buf(in_total * sizeof(TIn));
|
||||
SimpleDeviceMem wei_packed_buf(wei_total * sizeof(TWei));
|
||||
SimpleDeviceMem out_packed_buf(out_total * sizeof(TOut));
|
||||
|
||||
TIn* p_in_packed = static_cast<TIn*>(in_packed_buf.GetDeviceBuffer());
|
||||
TWei* p_wei_packed = static_cast<TWei*>(wei_packed_buf.GetDeviceBuffer());
|
||||
TOut* p_out_packed = static_cast<TOut*>(out_packed_buf.GetDeviceBuffer());
|
||||
|
||||
// Compute strides and allocate device arrays for pack/unpack
|
||||
std::vector<index_t> in_strides = compute_conv_tensor_strides<InLayout>(in_lengths, ndim);
|
||||
std::vector<index_t> wei_strides = compute_conv_tensor_strides<WeiLayout>(wei_lengths, ndim);
|
||||
std::vector<index_t> out_strides = compute_conv_tensor_strides<OutLayout>(out_lengths, ndim);
|
||||
|
||||
const size_t dim_count = in_lengths.size();
|
||||
SimpleDeviceMem in_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem in_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem wei_strides_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_lengths_buf(dim_count * sizeof(index_t));
|
||||
SimpleDeviceMem out_strides_buf(dim_count * sizeof(index_t));
|
||||
|
||||
index_t* d_in_lengths = static_cast<index_t*>(in_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_in_strides = static_cast<index_t*>(in_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_lengths = static_cast<index_t*>(wei_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_wei_strides = static_cast<index_t*>(wei_strides_buf.GetDeviceBuffer());
|
||||
index_t* d_out_lengths = static_cast<index_t*>(out_lengths_buf.GetDeviceBuffer());
|
||||
index_t* d_out_strides = static_cast<index_t*>(out_strides_buf.GetDeviceBuffer());
|
||||
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_lengths, in_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_in_strides, in_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_lengths, wei_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_wei_strides, wei_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_lengths, out_lengths.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
d_out_strides, out_strides.data(), dim_count * sizeof(index_t), hipMemcpyHostToDevice));
|
||||
|
||||
// Pack input and weight tensors to contiguous layout
|
||||
constexpr int block_size = 256;
|
||||
strided_copy_kernel<TIn, false>
|
||||
<<<(in_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_in, p_in_packed, d_in_lengths, d_in_strides, dim_count, in_total);
|
||||
strided_copy_kernel<TWei, false>
|
||||
<<<(wei_total + block_size - 1) / block_size, block_size, 0, stream>>>(
|
||||
p_wei, p_wei_packed, d_wei_lengths, d_wei_strides, dim_count, wei_total);
|
||||
|
||||
// Build conv parameter vectors for kernel invocation
|
||||
std::vector<index_t> conv_strides(ndim);
|
||||
std::vector<index_t> conv_dilations(ndim);
|
||||
std::vector<index_t> input_pads(ndim);
|
||||
for(index_t i = 0; i < ndim; ++i)
|
||||
{
|
||||
conv_strides[i] = static_cast<index_t>(conv_param.conv_filter_strides_[i]);
|
||||
conv_dilations[i] = static_cast<index_t>(conv_param.conv_filter_dilations_[i]);
|
||||
input_pads[i] = static_cast<index_t>(conv_param.input_left_pads_[i]);
|
||||
}
|
||||
|
||||
// Run convolution kernel on packed data
|
||||
const int out_grid = (out_total + block_size - 1) / block_size;
|
||||
|
||||
if(ndim == 1)
|
||||
{
|
||||
naive_conv_fwd_packed<1,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
1,
|
||||
in_lengths[3],
|
||||
1,
|
||||
1,
|
||||
wei_lengths[3],
|
||||
1,
|
||||
1,
|
||||
out_lengths[3],
|
||||
1,
|
||||
1,
|
||||
conv_strides[0],
|
||||
1,
|
||||
1,
|
||||
conv_dilations[0],
|
||||
0,
|
||||
0,
|
||||
input_pads[0],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else if(ndim == 2)
|
||||
{
|
||||
naive_conv_fwd_packed<2,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
1,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
1,
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
1,
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
1,
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
1,
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
0,
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
else // 3D
|
||||
{
|
||||
naive_conv_fwd_packed<3,
|
||||
TIn,
|
||||
TWei,
|
||||
TOut,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
<<<out_grid, block_size, 0, stream>>>(p_in_packed,
|
||||
p_wei_packed,
|
||||
p_out_packed,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
in_lengths[3],
|
||||
in_lengths[4],
|
||||
in_lengths[5],
|
||||
wei_lengths[3],
|
||||
wei_lengths[4],
|
||||
wei_lengths[5],
|
||||
out_lengths[3],
|
||||
out_lengths[4],
|
||||
out_lengths[5],
|
||||
conv_strides[0],
|
||||
conv_strides[1],
|
||||
conv_strides[2],
|
||||
conv_dilations[0],
|
||||
conv_dilations[1],
|
||||
conv_dilations[2],
|
||||
input_pads[0],
|
||||
input_pads[1],
|
||||
input_pads[2],
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
// Unpack
|
||||
strided_copy_kernel<TOut, true><<<out_grid, block_size, 0, stream>>>(
|
||||
p_out_packed, p_out, d_out_lengths, d_out_strides, dim_count, out_total);
|
||||
|
||||
HIP_CHECK_ERROR(hipGetLastError());
|
||||
|
||||
// Memory automatically freed by SimpleDeviceMem destructors
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,177 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

#include <hip/hip_runtime.h>

#include <algorithm>
#include <cstddef>
#include <vector>
|
||||
|
||||
namespace ck {
|
||||
namespace ref {
|
||||
|
||||
// RAII wrapper for device memory to prevent leaks
|
||||
struct SimpleDeviceMem
|
||||
{
|
||||
SimpleDeviceMem() = delete;
|
||||
|
||||
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
|
||||
{
|
||||
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&p_mem_), mem_size));
|
||||
}
|
||||
|
||||
void* GetDeviceBuffer() { return p_mem_; }
|
||||
|
||||
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
|
||||
|
||||
void* p_mem_;
|
||||
};
|
||||
|
||||
// Helper function to map layout dimension character to index in lengths array
|
||||
// lengths array structure: [G, N/K, C/K, spatial...]
|
||||
inline int map_dim_char_to_index(char dim_char, index_t ndim_spatial, bool is_weight)
|
||||
{
|
||||
// G dimension
|
||||
if(dim_char == 'G')
|
||||
return 0;
|
||||
|
||||
// Batch/output channels dimension (N for input/output, K for weight's first non-G dim)
|
||||
if(dim_char == 'N')
|
||||
return 1;
|
||||
if(dim_char == 'K' && is_weight)
|
||||
return 1;
|
||||
|
||||
// Channel dimension (C for input/weight, K for output)
|
||||
if(dim_char == 'C')
|
||||
return 2;
|
||||
if(dim_char == 'K' && !is_weight)
|
||||
return 2;
|
||||
|
||||
// Spatial dimensions - map based on ndim_spatial
|
||||
// Input/Output use: D/H/W, Weight uses: Z/Y/X
|
||||
if(ndim_spatial == 1)
|
||||
{
|
||||
if(dim_char == 'W' || dim_char == 'X')
|
||||
return 3;
|
||||
}
|
||||
else if(ndim_spatial == 2)
|
||||
{
|
||||
if(dim_char == 'H' || dim_char == 'Y')
|
||||
return 3;
|
||||
if(dim_char == 'W' || dim_char == 'X')
|
||||
return 4;
|
||||
}
|
||||
else if(ndim_spatial == 3)
|
||||
{
|
||||
if(dim_char == 'D' || dim_char == 'Z')
|
||||
return 3;
|
||||
if(dim_char == 'H' || dim_char == 'Y')
|
||||
return 4;
|
||||
if(dim_char == 'W' || dim_char == 'X')
|
||||
return 5;
|
||||
}
|
||||
|
||||
// Should not reach here
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Template function to compute layout-aware strides based on layout name
|
||||
// The layout name directly encodes memory ordering from left to right
|
||||
template <typename Layout>
|
||||
inline std::vector<index_t> compute_conv_tensor_strides(const std::vector<index_t>& lengths,
|
||||
index_t ndim_spatial)
|
||||
{
|
||||
constexpr const char* layout_name = Layout::name;
|
||||
const int num_dims = static_cast<int>(lengths.size());
|
||||
std::vector<index_t> strides(num_dims, 0);
|
||||
|
||||
// Determine if this is a weight tensor (has 'K' but not 'N')
|
||||
bool has_k = false;
|
||||
bool has_n = false;
|
||||
for(const char* p = layout_name; *p != '\0'; ++p)
|
||||
{
|
||||
if(*p == 'K')
|
||||
has_k = true;
|
||||
if(*p == 'N')
|
||||
has_n = true;
|
||||
}
|
||||
bool is_weight = has_k && !has_n;
|
||||
|
||||
// Build dimension ordering from layout name (parse string)
|
||||
std::vector<char> dim_order;
|
||||
const char dim_chars[] = {'G', 'N', 'K', 'C', 'D', 'H', 'W', 'X', 'Y', 'Z'};
|
||||
for(const char* p = layout_name; *p != '\0'; ++p)
|
||||
{
|
||||
char c = *p;
|
||||
// Skip underscores (strided layouts)
|
||||
if(c == '_')
|
||||
continue;
|
||||
// Valid dimension characters
|
||||
if(std::find(std::begin(dim_chars), std::end(dim_chars), c) != std::end(dim_chars))
|
||||
{
|
||||
dim_order.push_back(c);
|
||||
}
|
||||
}
|
||||
|
||||
// Compute strides: process from right to left (innermost to outermost)
|
||||
index_t stride = 1;
|
||||
for(int i = static_cast<int>(dim_order.size()) - 1; i >= 0; --i)
|
||||
{
|
||||
char dim_char = dim_order[i];
|
||||
int length_idx = map_dim_char_to_index(dim_char, ndim_spatial, is_weight);
|
||||
|
||||
if(length_idx >= 0 && length_idx < num_dims)
|
||||
{
|
||||
strides[length_idx] = stride;
|
||||
stride *= lengths[length_idx];
|
||||
}
|
||||
}
|
||||
|
||||
return strides;
|
||||
}
|
||||
|
||||
// Unified kernel for strided tensor copy operations
|
||||
// IsUnpack=false: Pack strided -> contiguous
|
||||
// IsUnpack=true: Unpack contiguous -> strided
|
||||
template <typename DataType, bool IsUnpack>
|
||||
__global__ void strided_copy_kernel(const DataType* __restrict__ src,
|
||||
DataType* __restrict__ dst,
|
||||
const index_t* tensor_lengths,
|
||||
const index_t* strided_strides,
|
||||
int num_dims,
|
||||
long_index_t total_elements)
|
||||
{
|
||||
const long_index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const long_index_t num_threads = blockDim.x * gridDim.x;
|
||||
|
||||
for(long_index_t linear_idx = tid; linear_idx < total_elements; linear_idx += num_threads)
|
||||
{
|
||||
// Compute strided index from linear index
|
||||
long_index_t remaining = linear_idx;
|
||||
long_index_t strided_idx = 0;
|
||||
|
||||
for(int dim = num_dims - 1; dim >= 0; --dim)
|
||||
{
|
||||
index_t coord = remaining % tensor_lengths[dim];
|
||||
remaining /= tensor_lengths[dim];
|
||||
strided_idx += coord * strided_strides[dim];
|
||||
}
|
||||
|
||||
// Direction determines which is src and which is dst
|
||||
if constexpr(IsUnpack)
|
||||
{
|
||||
// Unpack: src is contiguous (linear_idx), dst is strided (strided_idx)
|
||||
dst[strided_idx] = src[linear_idx];
|
||||
}
|
||||
else
|
||||
{
|
||||
// Pack: src is strided (strided_idx), dst is contiguous (linear_idx)
|
||||
dst[linear_idx] = src[strided_idx];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ref
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user