mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
adding int8 direct that reads pre-vectorized data
This commit is contained in:
@@ -3,17 +3,18 @@
|
||||
#include "device.hpp"
|
||||
#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
template <class TInWei, class TOut, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
const Tensor<TInWei>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
const Tensor<TInWei>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr unsigned NVector = 1;
|
||||
using vector_t = vector_type<T, NVector>;
|
||||
constexpr unsigned NVector = 4;
|
||||
using accum_t = int32_t;
|
||||
using vector_t = vector_type<TInWei, NVector>;
|
||||
using vector_mem_t = typename vector_t::MemoryType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -44,11 +45,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<vector_mem_t> in_nchw_vec(make_TensorDescriptor(in_nchw_vec_desc));
|
||||
|
||||
auto f_vectorized_nchw = [&](auto n, auto c, auto h, auto w) {
|
||||
#if 1
|
||||
#if 0
|
||||
in_nchw_vec(n, c, h, w) = in_nchw(n, c, h, w);
|
||||
#else
|
||||
#elif 0
|
||||
in_nchw_vec(n, c, h, w) =
|
||||
vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
|
||||
#elif 1
|
||||
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
|
||||
in_nchw(n, 4 * c + 1, h, w),
|
||||
in_nchw(n, 4 * c + 2, h, w),
|
||||
in_nchw(n, 4 * c + 3, h, w));
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -62,11 +68,16 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
Tensor<vector_mem_t> wei_kcyx_vec(make_TensorDescriptor(wei_kcyx_vec_desc));
|
||||
|
||||
auto f_vectorized_kcyx = [&](auto k, auto c, auto y, auto x) {
|
||||
#if 1
|
||||
#if 0
|
||||
wei_kcyx_vec(k, c, y, x) = wei_kcyx(k, c, y, x);
|
||||
#else
|
||||
#elif 0
|
||||
wei_kcyx_vec(k, c, y, x) =
|
||||
vector_t::Pack(wei_kcyx(k, 2 * c, y, x), wei_kcyx(k, 2 * c + 1, y, x));
|
||||
#elif 1
|
||||
wei_kcyx_vec(k, c, y, x) = vector_t::Pack(wei_kcyx(k, 4 * c, y, x),
|
||||
wei_kcyx(k, 4 * c + 1, y, x),
|
||||
wei_kcyx(k, 4 * c + 2, y, x),
|
||||
wei_kcyx(k, 4 * c + 3, y, x));
|
||||
#endif
|
||||
};
|
||||
|
||||
@@ -76,13 +87,13 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
//
|
||||
DeviceMem in_nchw_vec_device_buf(sizeof(vector_mem_t) * in_nchw_vec.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_vec_device_buf(sizeof(vector_mem_t) * wei_kcyx_vec.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(sizeof(T) * out_nkhw.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(sizeof(TOut) * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_vec_device_buf.ToDevice(in_nchw_vec.mData.data());
|
||||
wei_kcyx_vec_device_buf.ToDevice(wei_kcyx_vec.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 1
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
@@ -100,7 +111,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 2
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
@@ -117,9 +128,27 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, int8, vector = 4
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 3x3, 34x34, 128 thread, fp16
|
||||
// 1x1, 32x32, 128 thread, int8, vector = 4
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -128,12 +157,12 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
@@ -146,7 +175,9 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<T,
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
|
||||
TOut,
|
||||
accum_t,
|
||||
decltype(in_nchw_vec_desc),
|
||||
decltype(wei_kcyx_vec_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
@@ -167,9 +198,9 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
GridSize>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
static_cast<TInWei*>(in_nchw_vec_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(wei_kcyx_vec_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TInWei*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
|
||||
@@ -88,9 +88,12 @@ auto make_TensorDescriptor(TConstTensorDesc)
|
||||
return TensorDescriptor(lengths, strides);
|
||||
}
|
||||
|
||||
template <class T, class LowerPads, class UpperPads>
|
||||
void host_direct_convolution(
|
||||
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
|
||||
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
{
|
||||
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
@@ -116,21 +119,24 @@ void host_direct_convolution(
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = v;
|
||||
out_nkhw(n, k, ho, wo) = v;
|
||||
};
|
||||
|
||||
auto f_par = make_ParallelTensorFunctor(f,
|
||||
out.mDesc.GetLengths()[0],
|
||||
out.mDesc.GetLengths()[1],
|
||||
out.mDesc.GetLengths()[2],
|
||||
out.mDesc.GetLengths()[3]);
|
||||
out_nkhw.mDesc.GetLengths()[0],
|
||||
out_nkhw.mDesc.GetLengths()[1],
|
||||
out_nkhw.mDesc.GetLengths()[2],
|
||||
out_nkhw.mDesc.GetLengths()[3]);
|
||||
|
||||
f_par(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
template <class T, class LowerPads, class UpperPads>
|
||||
void host_winograd_3x3_convolution(
|
||||
const Tensor<T>& in_nchw, const Tensor<T>& wei_kcyx, Tensor<T>& out, LowerPads, UpperPads)
|
||||
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
|
||||
void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
{
|
||||
constexpr std::size_t HoPerTile = 2;
|
||||
constexpr std::size_t WoPerTile = 2;
|
||||
@@ -144,8 +150,8 @@ void host_winograd_3x3_convolution(
|
||||
std::size_t Y = wei_kcyx.mDesc.GetLengths()[2];
|
||||
std::size_t X = wei_kcyx.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t HO = out.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out.mDesc.GetLengths()[3];
|
||||
std::size_t HO = out_nkhw.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out_nkhw.mDesc.GetLengths()[3];
|
||||
|
||||
unsigned h_pad_low = LowerPads{}.Get(Number<0>{});
|
||||
unsigned w_pad_low = LowerPads{}.Get(Number<1>{});
|
||||
@@ -180,7 +186,7 @@ void host_winograd_3x3_convolution(
|
||||
}
|
||||
else
|
||||
{
|
||||
in_hold(n, c, htile, wtile, j, i) = T(0);
|
||||
in_hold(n, c, htile, wtile, j, i) = TIn(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -347,8 +353,8 @@ void host_winograd_3x3_convolution(
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -403,7 +409,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -502,7 +508,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 1x1 filter, 28x28 image
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -562,6 +568,18 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr unsigned HPad = 2;
|
||||
constexpr unsigned WPad = 2;
|
||||
#elif 1
|
||||
// 1x1 filter, 32x32 image
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
constexpr unsigned HI = 32;
|
||||
constexpr unsigned WI = 32;
|
||||
constexpr unsigned K = 512;
|
||||
constexpr unsigned Y = 1;
|
||||
constexpr unsigned X = 1;
|
||||
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
#endif
|
||||
|
||||
auto lower_pads = Sequence<HPad, WPad>{};
|
||||
@@ -576,11 +594,12 @@ int main(int argc, char* argv[])
|
||||
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
|
||||
using Float = float;
|
||||
Tensor<Float> in_nchw(make_TensorDescriptor(in_nchw_desc));
|
||||
Tensor<Float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
|
||||
Tensor<Float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
|
||||
Tensor<Float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
|
||||
using in_data_t = char;
|
||||
using out_data_t = int32_t;
|
||||
Tensor<in_data_t> in_nchw(make_TensorDescriptor(in_nchw_desc));
|
||||
Tensor<in_data_t> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
|
||||
Tensor<out_data_t> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
|
||||
Tensor<out_data_t> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
|
||||
@@ -10,16 +10,6 @@ namespace CUDA {
|
||||
using half = CUDA::half;
|
||||
using half2 = CUDA::half2;
|
||||
|
||||
struct half4
|
||||
{
|
||||
half data[4];
|
||||
};
|
||||
|
||||
struct half8
|
||||
{
|
||||
half data[8];
|
||||
};
|
||||
|
||||
template <class T, unsigned N>
|
||||
struct vector_type
|
||||
{
|
||||
@@ -119,39 +109,141 @@ struct vector_type<half2, 4>
|
||||
using MemoryType = float4;
|
||||
};
|
||||
|
||||
template <class TDst, class TSrc0, class TSrc1, class TSrc2>
|
||||
__device__ void fused_multiply_add(TDst& d, TSrc0 s0, TSrc1 s1, TSrc2 s2)
|
||||
template <>
|
||||
struct vector_type<char, 1>
|
||||
{
|
||||
using MemoryType = char;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s) { return s; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 2>
|
||||
{
|
||||
using MemoryType = char2;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1)
|
||||
{
|
||||
union
|
||||
{
|
||||
MemoryType vector;
|
||||
char scalar[2];
|
||||
} data;
|
||||
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 4>
|
||||
{
|
||||
using MemoryType = char4;
|
||||
|
||||
__host__ __device__ static MemoryType Pack(char s0, char s1, char s2, char s3)
|
||||
{
|
||||
union
|
||||
{
|
||||
MemoryType vector;
|
||||
char scalar[4];
|
||||
} data;
|
||||
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
data.scalar[2] = s2;
|
||||
data.scalar[3] = s3;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char, 8>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char2, 2>
|
||||
{
|
||||
using MemoryType = char4;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char2, 4>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<char4, 2>
|
||||
{
|
||||
using MemoryType = int64_t;
|
||||
};
|
||||
|
||||
template <class TDst, class TSrc0, class TSrc1>
|
||||
__device__ void fused_multiply_accumulate(TDst& d, const TSrc0& s0, const TSrc1& s1)
|
||||
{
|
||||
// static_assert(false, "should not call into base");
|
||||
printf("should not call into base");
|
||||
assert(false);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_add(float& d, float s0, float s1, float s2)
|
||||
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
|
||||
{
|
||||
d = s0 * s1 + s2;
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_add(float& d, float2 s0, float2 s1, float s2)
|
||||
__device__ void fused_multiply_accumulate(float& d, const float2& s0, const float2& s1)
|
||||
{
|
||||
d = s0.x * s1.x + s0.y * s1.y + s2;
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_add(float& d, float4 s0, float4 s1, float s2)
|
||||
__device__ void fused_multiply_accumulate(float& d, const float4& s0, const float4& s1)
|
||||
{
|
||||
d = s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w + s2;
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
d += s0.z * s1.z;
|
||||
d += s0.w * s1.w;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_add(half& d, half s0, half s1, half s2)
|
||||
__device__ void fused_multiply_accumulate(half& d, const half& s0, const half& s1)
|
||||
{
|
||||
d = s0 * s1 + s2;
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_add(half& d, half2 s0, half2 s1, half s2)
|
||||
__device__ void fused_multiply_accumulate(half& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d = s0.x * s1.x + s0.y * s1.y + s2;
|
||||
}
|
||||
d += s0.x * s1.x;
|
||||
d += s0.y * s1.y;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(float& d, const half2& s0, const half2& s1)
|
||||
{
|
||||
d += s0.x * s1.x + s0.y * s1.y;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(char& d, const char& s0, const char& s1)
|
||||
{
|
||||
d += s0 * s1;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void fused_multiply_accumulate(int32_t& d, const char4& s0, const char4& s1)
|
||||
{
|
||||
#if DEVICE_BACKEND_CUDA
|
||||
d = __dp4a(s0, s1, d);
|
||||
#else
|
||||
d += s0.x * s1.x + s0.y * s1.y + s0.z * s1.z + s0.w * s1.w;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -7,7 +7,9 @@
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "threadwise_direct_convolution.hip.hpp"
|
||||
|
||||
template <class Float,
|
||||
template <class TInWei,
|
||||
class TOut,
|
||||
class TAccum,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
@@ -27,14 +29,16 @@ template <class Float,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const typename vector_type<Float,
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_in_vec_global,
|
||||
const typename vector_type<Float,
|
||||
const typename vector_type<TInWei,
|
||||
ScalarPerVector>::MemoryType* const __restrict__ p_wei_vec_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
TOut* const __restrict__ p_out_global)
|
||||
{
|
||||
using scalar_t = Float;
|
||||
using vector_mem_t = typename vector_type<scalar_t, ScalarPerVector>::MemoryType;
|
||||
using in_scalar_t = TInWei;
|
||||
using in_vector_mem_t = typename vector_type<in_scalar_t, ScalarPerVector>::MemoryType;
|
||||
using out_scalar_t = TOut;
|
||||
using accum_t = TAccum;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -79,9 +83,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
? InBlockCopyDataPerRead
|
||||
: WeiBlockCopyDataPerRead;
|
||||
|
||||
__shared__ vector_mem_t
|
||||
__shared__ in_vector_mem_t
|
||||
p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)];
|
||||
__shared__ vector_mem_t
|
||||
__shared__ in_vector_mem_t
|
||||
p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)];
|
||||
|
||||
// threadwise tensors
|
||||
@@ -99,7 +103,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc);
|
||||
|
||||
// register
|
||||
scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
out_scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()];
|
||||
|
||||
// divide block work
|
||||
constexpr unsigned NBlockWork =
|
||||
@@ -155,7 +159,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
vector_mem_t,
|
||||
in_vector_mem_t,
|
||||
decltype(in_nchw_vec_global_desc),
|
||||
decltype(in_nchw_vec_block_desc),
|
||||
decltype(in_nchw_vec_block_desc.GetLengths()),
|
||||
@@ -164,7 +168,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
vector_mem_t,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_kcyx_vec_global_desc),
|
||||
decltype(wei_kcyx_vec_block_desc),
|
||||
decltype(wei_kcyx_vec_block_desc.GetLengths()),
|
||||
@@ -172,15 +176,17 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
vector_mem_t,
|
||||
in_vector_mem_t,
|
||||
decltype(wei_ke_vec_global_desc),
|
||||
decltype(wei_ke_vec_block_desc),
|
||||
decltype(wei_ke_vec_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
#if 1 // debug
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_nkhw_thread_desc, p_out_thread);
|
||||
#endif
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
|
||||
@@ -37,7 +37,8 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re
|
||||
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class Float,
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
@@ -45,9 +46,9 @@ template <class Float,
|
||||
class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
F f)
|
||||
@@ -88,33 +89,38 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float, class Desc>
|
||||
__device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
|
||||
template <class Data, class Desc>
|
||||
__device__ void threadwise_4d_tensor_set_zero(Desc, Data* __restrict__ p)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
auto f_set_zero = [](Data& v) { v = Data(0); };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_unary<Float, Desc, decltype(f_set_zero)>(
|
||||
threadwise_4d_tensor_pointwise_operation_unary<Data, Desc, decltype(f_set_zero)>(
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
|
||||
@@ -51,10 +51,8 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
|
||||
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
|
||||
|
||||
fused_multiply_add(p_out[out_index],
|
||||
p_wei[wei_index],
|
||||
p_in[in_index],
|
||||
p_out[out_index]);
|
||||
fused_multiply_accumulate(
|
||||
p_out[out_index], p_wei[wei_index], p_in[in_index]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user