mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
adding fp16 direct that reads pre-vectorized data
[ROCm/composable_kernel commit: 4f0fc72e91]
This commit is contained in:
@@ -373,7 +373,7 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::type;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
|
||||
unsigned mSrcMyThreadOffset;
|
||||
unsigned mDstMyThreadOffset;
|
||||
|
||||
@@ -207,9 +207,9 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise4dTensorCopy1
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::type;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
|
||||
__device__ void SanityCheck() const
|
||||
__device__ constexpr Blockwise4dTensorCopy1()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -239,8 +239,6 @@ struct Blockwise4dTensorCopy1
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
SanityCheck();
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -446,7 +444,7 @@ template <unsigned BlockSize,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise4dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::type;
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::VectorType;
|
||||
|
||||
unsigned mSrcMyThreadOffset;
|
||||
unsigned mDstMyThreadOffset;
|
||||
|
||||
@@ -28,44 +28,44 @@ struct vector_type
|
||||
template <>
|
||||
struct vector_type<float, 1>
|
||||
{
|
||||
using type = float;
|
||||
using VectorType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 2>
|
||||
{
|
||||
using type = float2;
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 4>
|
||||
{
|
||||
using type = float4;
|
||||
using VectorType = float4;
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct vector_type<half_float::half, 1>
|
||||
{
|
||||
using type = half_float::half;
|
||||
using VectorType = half_float::half;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 2>
|
||||
{
|
||||
using type = float;
|
||||
using VectorType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 4>
|
||||
{
|
||||
using type = float2;
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half_float::half, 8>
|
||||
{
|
||||
using type = float4;
|
||||
using VectorType = float4;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -73,25 +73,41 @@ struct vector_type<half_float::half, 8>
|
||||
template <>
|
||||
struct vector_type<half, 1>
|
||||
{
|
||||
using type = half;
|
||||
using VectorType = half;
|
||||
|
||||
__host__ __device__ static VectorType pack(half s) { return s; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 2>
|
||||
{
|
||||
using type = half2;
|
||||
using VectorType = half2;
|
||||
|
||||
union Data
|
||||
{
|
||||
VectorType vector;
|
||||
half scalar[2];
|
||||
};
|
||||
|
||||
__host__ __device__ static VectorType pack(half s0, half s1)
|
||||
{
|
||||
Data data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 4>
|
||||
{
|
||||
using type = float2;
|
||||
using VectorType = float2;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 8>
|
||||
{
|
||||
using type = float4;
|
||||
using VectorType = float4;
|
||||
};
|
||||
#endif
|
||||
|
||||
|
||||
@@ -25,10 +25,10 @@ template <class Float,
|
||||
unsigned WeiBlockCopyDataPerRead,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
__global__ void
|
||||
gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
@@ -11,6 +11,7 @@ template <class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned ScalarPerVector,
|
||||
unsigned NPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
@@ -26,47 +27,50 @@ template <class Float,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_in_global,
|
||||
const typename vector_type<Float, ScalarPerVector>::VectorType* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global)
|
||||
{
|
||||
using scalar_t = Float;
|
||||
using vector_t = typename vector_type<scalar_t, ScalarPerVector>::VectorType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcyx_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
constexpr auto in_nchw_vec_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_kcyx_vec_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_nkhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned N = in_nchw_global_desc.GetLength(I0);
|
||||
constexpr unsigned K = wei_kcyx_global_desc.GetLength(I0);
|
||||
constexpr unsigned C = wei_kcyx_global_desc.GetLength(I1);
|
||||
constexpr unsigned Y = wei_kcyx_global_desc.GetLength(I2);
|
||||
constexpr unsigned X = wei_kcyx_global_desc.GetLength(I3);
|
||||
constexpr unsigned N = in_nchw_vec_global_desc.GetLength(I0);
|
||||
constexpr unsigned K = wei_kcyx_vec_global_desc.GetLength(I0);
|
||||
constexpr unsigned C = wei_kcyx_vec_global_desc.GetLength(I1);
|
||||
constexpr unsigned Y = wei_kcyx_vec_global_desc.GetLength(I2);
|
||||
constexpr unsigned X = wei_kcyx_vec_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto wei_ke_global_desc = make_ConstantTensorDescriptor(
|
||||
constexpr auto wei_ke_vec_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K, C * Y * X>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr unsigned HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr unsigned WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
constexpr auto in_nchw_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto in_nchw_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<NPerBlock, CPerBlock, HiPerBlock, WiPerBlock>{}, Number<InBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_ke_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto wei_ke_vec_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<KPerBlock, CPerBlock * Y * X>{},
|
||||
Number<WeiBlockCopyDataPerRead>{}); // 2d view of wei for blockwise copy
|
||||
|
||||
constexpr auto wei_kcyx_block_desc =
|
||||
constexpr auto wei_kcyx_vec_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerBlock, CPerBlock, Y, X>{},
|
||||
Sequence<wei_ke_block_desc.GetStride(I0), Y * X, X, 1>{});
|
||||
Sequence<wei_ke_vec_block_desc.GetStride(I0), Y * X, X, 1>{});
|
||||
|
||||
// shared mem
|
||||
constexpr unsigned in_block_size =
|
||||
in_nchw_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
in_nchw_vec_block_desc.GetElementSpace(Number<InBlockCopyDataPerRead>{});
|
||||
constexpr unsigned wei_block_size =
|
||||
wei_kcyx_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
wei_kcyx_vec_block_desc.GetElementSpace(Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr unsigned max_align = InBlockCopyDataPerRead > WeiBlockCopyDataPerRead
|
||||
? InBlockCopyDataPerRead
|
||||
@@ -81,10 +85,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr auto in_nchw_thread_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<NPerThread, CPerThread, HiPerThread, WiPerThread>{},
|
||||
in_nchw_block_desc.GetStrides());
|
||||
in_nchw_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_block_desc.GetStrides());
|
||||
Sequence<KPerThread, CPerThread, Y, X>{}, wei_kcyx_vec_block_desc.GetStrides());
|
||||
|
||||
constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor(
|
||||
in_nchw_thread_block_desc, wei_kcyx_thread_block_desc);
|
||||
@@ -147,26 +151,27 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
constexpr auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_nchw_global_desc),
|
||||
decltype(in_nchw_block_desc),
|
||||
decltype(in_nchw_block_desc.GetLengths()),
|
||||
decltype(in_nchw_vec_global_desc),
|
||||
decltype(in_nchw_vec_block_desc),
|
||||
decltype(in_nchw_vec_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
|
||||
#if 0
|
||||
constexpr auto blockwise_wei_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_kcyx_global_desc),
|
||||
decltype(wei_kcyx_block_desc),
|
||||
decltype(wei_kcyx_block_desc.GetLengths()),
|
||||
decltype(wei_kcyx_vec_global_desc),
|
||||
decltype(wei_kcyx_vec_block_desc),
|
||||
decltype(wei_kcyx_vec_block_desc.GetLengths()),
|
||||
1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ke_global_desc),
|
||||
decltype(wei_ke_block_desc),
|
||||
decltype(wei_ke_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ke_vec_global_desc),
|
||||
decltype(wei_ke_vec_block_desc),
|
||||
decltype(wei_ke_vec_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
@@ -176,14 +181,14 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
blockwise_in_copy.Run(p_in_global + in_nchw_vec_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
|
||||
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_vec_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
@@ -195,25 +200,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread);
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user