mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
adding fp16 direct that reads pre-vectorized data
This commit is contained in:
@@ -13,8 +13,8 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr unsigned NVector = 1;
|
||||
using vector_type_t = vector_type<T, NVector>;
|
||||
using vector_t = typename vector_type_t::VectorType;
|
||||
using vector_t = vector_type<T, NVector>;
|
||||
using vector_mem_t = typename vector_t::MemoryType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -41,40 +41,41 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
auto in_nchw_vec_desc = make_ConstantTensorDescriptor(Sequence<N, C / NVector, Hi, Wi>{});
|
||||
ostream_ConstantTensorDescriptor(in_nchw_vec_desc, std::cout << "in_nchw_vec_desc: ");
|
||||
|
||||
Tensor<vector_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
|
||||
Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
|
||||
|
||||
auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
|
||||
#if 1
|
||||
in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
|
||||
#else
|
||||
in_nchw_vec(n, c, h, w) =
|
||||
vector_type_t::pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
|
||||
vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
|
||||
#endif
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_vectorized_nchw, N, C, Hi, Wi)(
|
||||
make_ParallelTensorFunctor(f_vectorized_nchw, N, C / NVector, Hi, Wi)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// vectorize weight
|
||||
auto wei_kcyx_vec_desc = make_ConstantTensorDescriptor(Sequence<K, C / NVector, Y, X>{});
|
||||
ostream_ConstantTensorDescriptor(wei_kcyx_vec_desc, std::cout << "wei_kcyx_vec_desc: ");
|
||||
|
||||
Tensor<vector_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
|
||||
Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
|
||||
|
||||
auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
#if 1
|
||||
wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
|
||||
#else
|
||||
wei_kcyx_vec(k, c, y, x) =
|
||||
vector_type_t::pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
|
||||
vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
|
||||
#endif
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_vectorized_kcyx, K, C, Y, X)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_vectorized_kcyx, K, C / NVector, Y, X)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
//
|
||||
DeviceMem in_nchw_vec_device_buf(sizeof(vector_t) * in_nchw_vec.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_t) * wei_kcyx_vec.mDesc.GetElementSpace());
|
||||
DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(sizeof(T) * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
|
||||
@@ -82,7 +83,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// 3x3, 34x34, 128 thread
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 1
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -96,24 +97,42 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 3x3, 34x34, 128 thread, fp16
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 2
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 3x3, 34x34, 128 thread, fp16
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
|
||||
@@ -373,7 +373,7 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
unsigned mSrcMyThreadOffset;
|
||||
unsigned mDstMyThreadOffset;
|
||||
|
||||
@@ -207,7 +207,7 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise4dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
__device__ constexpr Blockwise4dTensorCopy1()
|
||||
{
|
||||
@@ -444,7 +444,7 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise4dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
unsigned mSrcMyThreadOffset;
|
||||
unsigned mDstMyThreadOffset;
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#pragma once
|
||||
#include "data_type.hip.hpp"
|
||||
#include "constant_integral.hip.hpp"
|
||||
#include "Sequence.hip.hpp"
|
||||
#include "Array.hip.hpp"
|
||||
@@ -20,97 +21,6 @@ struct is_same<T, T>
|
||||
static const bool value = true;
|
||||
};
|
||||
|
||||
template <class T, unsigned N>
|
||||
struct vector_type
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 1>
|
||||
{
|
||||
using VectorType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 2>
|
||||
{
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 4>
|
||||
{
|
||||
using VectorType = float4;
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct vector_type<half_float::half, 1>
|
||||
{
|
||||
using VectorType = half_float::half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 2>
|
||||
{
|
||||
using VectorType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 4>
|
||||
{
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 8>
|
||||
{
|
||||
using VectorType = float4;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
template <>
|
||||
struct vector_type<half, 1>
|
||||
{
|
||||
using VectorType = half;
|
||||
|
||||
__host__ __device__ static VectorType pack(half s) { return s; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 2>
|
||||
{
|
||||
using VectorType = half2;
|
||||
|
||||
union Data
|
||||
{
|
||||
VectorType vector;
|
||||
half scalar[2];
|
||||
};
|
||||
|
||||
__host__ __device__ static VectorType pack(half s0, half s1)
|
||||
{
|
||||
Data data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 4>
|
||||
{
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 8>
|
||||
{
|
||||
using VectorType = float4;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T max(T a, T b)
|
||||
{
|
||||
|
||||
@@ -4,10 +4,8 @@
|
||||
|
||||
#if DEVICE_BACKEND_HIP
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "half.hpp"
|
||||
#elif DEVICE_BACKEND_CUDA
|
||||
#include "cuda_runtime.h"
|
||||
#include "nvToolsExt.h"
|
||||
#include "helper_cuda.h"
|
||||
#include "cuda_fp16.h"
|
||||
#endif
|
||||
|
||||
@@ -47,3 +47,11 @@ struct static_const_reduce_n<1>
|
||||
return f(Number<0>{});
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template<class F>
|
||||
__host__ __device__ constexpr auto unpacker(F f)
|
||||
{
|
||||
return [=](auto xs_array){ f(xs...); };
|
||||
}
|
||||
#endif
|
||||
@@ -27,12 +27,14 @@ template <class Float,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_in_vec_global,
|
||||
const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_wei_vec_global,
|
||||
const typename vector_type<Float,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
|
||||
const typename vector_type<Float,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
using scalar_t = Float;
|
||||
using vector_t = typename vector_type<scalar_t, ScalarPerVector>::VectorType;
|
||||
using scalar_t = Float;
|
||||
using vector_mem_t = typename vector_type<scalar_t, ScalarPerVector>::MemoryType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -69,6 +71,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
// shared mem
|
||||
constexpr unsigned in_block_size =
|
||||
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned wei_block_size =
|
||||
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
@@ -76,8 +79,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ vector_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ vector_t p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
__shared__ vector_mem_t
|
||||
p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ vector_mem_t
|
||||
p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
|
||||
// threadwise tensors
|
||||
constexpr unsigned HiPerThread = HoPerThread + Y - 1;
|
||||
@@ -150,7 +155,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
vector_t,
|
||||
vector_mem_t,
|
||||
decltype(in_nchw_vec_global_desc),
|
||||
decltype(in_nchw_vec_block_desc),
|
||||
decltype(in_nchw_vec_block_desc.GetLengths()),
|
||||
@@ -159,7 +164,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
vector_t,
|
||||
vector_mem_t,
|
||||
decltype(wei_kcyx_vec_global_desc),
|
||||
decltype(wei_kcyx_vec_block_desc),
|
||||
decltype(wei_kcyx_vec_block_desc.GetLengths()),
|
||||
@@ -167,7 +172,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
vector_t,
|
||||
vector_mem_t,
|
||||
decltype(wei_ke_vec_global_desc),
|
||||
decltype(wei_ke_vec_block_desc),
|
||||
decltype(wei_ke_vec_block_desc.GetLengths()),
|
||||
@@ -181,15 +186,16 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_vec_global + in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
blockwise_in_copy.Run(p_in_vec_global +
|
||||
in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_vec_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
|
||||
__syncthreads();
|
||||
@@ -201,9 +207,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -213,9 +219,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
|
||||
// optimized for scenario if p_in, p_wei, p_out are in register
|
||||
template <class Float, class InDesc, class WeiDesc, class OutDesc>
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_1(InDesc,
|
||||
Float* const __restrict__ p_in,
|
||||
TInWei* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
Float* const __restrict__ p_wei,
|
||||
TInWei* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
Float* __restrict__ p_out)
|
||||
TOut* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -51,25 +51,10 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
|
||||
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
|
||||
|
||||
p_out[out_index] += p_wei[wei_index] * p_in[in_index];
|
||||
|
||||
#if 0
|
||||
// if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_direct_convolution: \t"
|
||||
"threadIdx.x %u\t"
|
||||
"out_index %u, p_out[out_index] %f, \t"
|
||||
"wei_index %u, p_wei[wei_index] %f, \t"
|
||||
"in_index %u, p_in[in_index] %f\n",
|
||||
threadIdx.x,
|
||||
out_index,
|
||||
p_out[out_index],
|
||||
wei_index,
|
||||
p_wei[wei_index],
|
||||
in_index,
|
||||
p_in[in_index]);
|
||||
}
|
||||
#endif
|
||||
fused_multiply_add(p_out[out_index],
|
||||
p_wei[wei_index],
|
||||
p_in[in_index],
|
||||
p_out[out_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,13 +66,13 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
|
||||
// Optimized for scenario if p_in and p_wei are in LDS, p_out are in register
|
||||
// Copy in and wei into register before doing convolution
|
||||
template <class Float, class InDesc, class WeiDesc, class OutDesc>
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_2(InDesc,
|
||||
Float* const __restrict__ p_in,
|
||||
TInWei* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
Float* const __restrict__ p_wei,
|
||||
TInWei* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
Float* __restrict__ p_out)
|
||||
TOut* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
@@ -97,8 +82,8 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(wei_desc.GetLengths());
|
||||
|
||||
// register
|
||||
Float p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
TInWei p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
TInWei p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths());
|
||||
@@ -114,13 +99,13 @@ __device__ void threadwise_direct_convolution_2(InDesc,
|
||||
// optimized for scenario where p_in and p_wei are in LDS, p_out is in register
|
||||
// break down a non-1x1 convolution into a sequence of 1x1 convolutions,
|
||||
// load 1x1 weight into register, and do 1x1 convolution in register.
|
||||
template <class Float, class InDesc, class WeiDesc, class OutDesc>
|
||||
template <class Data, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution_3(InDesc,
|
||||
Float* const __restrict__ p_in,
|
||||
Data* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
Float* const __restrict__ p_wei,
|
||||
Data* const __restrict__ p_wei,
|
||||
OutDesc,
|
||||
Float* __restrict__ p_out)
|
||||
Data* __restrict__ p_out)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -139,8 +124,8 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
constexpr auto wei_reg_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<wei_desc.GetLength(I0), wei_desc.GetLength(I1), 1, 1>{});
|
||||
|
||||
Float p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
Float p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
Data p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
constexpr unsigned in_w_new_read = 1;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user