mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
device_implicit_gemm_convolution_1_chwn_csrk_khwn: use tensor copy (instead of pointwise) for writing output, 3x3 increased from 78% to 84%, 5x5 from 80% to 84%
This commit is contained in:
@@ -75,7 +75,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 1
|
||||
// for 3x3, 34x34, try
|
||||
// for 3x3, 34x34
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -106,9 +106,46 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned GemmNLevel1Cluster = 4;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
|
||||
// for 5x5, 36x36
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 4;
|
||||
|
||||
constexpr unsigned NPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimC = 2;
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr unsigned InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr unsigned GemmMPerThreadSubC = 4;
|
||||
constexpr unsigned GemmNPerThreadSubC = 4;
|
||||
constexpr unsigned GemmMLevel0Cluster = 4;
|
||||
constexpr unsigned GemmNLevel0Cluster = 2;
|
||||
constexpr unsigned GemmMLevel1Cluster = 2;
|
||||
constexpr unsigned GemmNLevel1Cluster = 4;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3 58x58, NKC = 64, 64, 256
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -142,27 +179,6 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// for 5x5, 36x36
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 4;
|
||||
|
||||
constexpr unsigned NPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 32;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4; // not used, yet
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// for 7x7, 38x38
|
||||
@@ -200,7 +216,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// for 1x1, 28x28
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 128;
|
||||
@@ -210,7 +226,7 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
|
||||
constexpr unsigned NPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
@@ -225,6 +241,16 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned GemmMPerThreadSubC = 4;
|
||||
constexpr unsigned GemmNPerThreadSubC = 4;
|
||||
constexpr unsigned GemmMLevel0Cluster = 4;
|
||||
constexpr unsigned GemmNLevel0Cluster = 2;
|
||||
constexpr unsigned GemmMLevel1Cluster = 2;
|
||||
constexpr unsigned GemmNLevel1Cluster = 4;
|
||||
constexpr unsigned GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr unsigned OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -266,7 +292,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop>,
|
||||
GemmKPerThreadLoop,
|
||||
OutThreadCopyDataPerWrite>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
static_cast<T*>(in_chwn_device_buf.GetDeviceBuffer()),
|
||||
|
||||
@@ -571,16 +571,21 @@ int main()
|
||||
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
|
||||
bool do_verification = true;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-2, 2}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#endif
|
||||
}
|
||||
|
||||
unsigned nrepeat = 200;
|
||||
|
||||
@@ -614,22 +619,23 @@ int main()
|
||||
nrepeat);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
if(S == 3 && R == 3)
|
||||
if(do_verification)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
else
|
||||
{
|
||||
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#endif
|
||||
if(S == 3 && R == 3)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
else
|
||||
{
|
||||
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
#if 0
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,6 +15,35 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned L4, unsigned L5>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
template <unsigned L0,
|
||||
unsigned L1,
|
||||
unsigned L2,
|
||||
unsigned L3,
|
||||
unsigned L4,
|
||||
unsigned L5,
|
||||
unsigned L6,
|
||||
unsigned L7>
|
||||
__host__ __device__ constexpr auto
|
||||
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
|
||||
L2 * L3 * L4 * L5 * L6 * L7,
|
||||
L3 * L4 * L5 * L6 * L7,
|
||||
L4 * L5 * L6 * L7,
|
||||
L5 * L6 * L7,
|
||||
L6 * L7,
|
||||
L7,
|
||||
1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <unsigned L0, unsigned L1, unsigned Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
|
||||
@@ -64,7 +93,7 @@ struct ConstantTensorDescriptor
|
||||
|
||||
__host__ __device__ constexpr unsigned GetElementSize() const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 4, "nDim");
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
if(nDim == 2)
|
||||
{
|
||||
@@ -90,12 +119,65 @@ struct ConstantTensorDescriptor
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4);
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5);
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6);
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) *
|
||||
GetLength(I5) * GetLength(I6) * GetLength(I7);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
template <class Align = Number<1>>
|
||||
__host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const
|
||||
{
|
||||
static_assert(nDim >= 2 && nDim <= 4, "nDim");
|
||||
static_assert(nDim >= 2 && nDim <= 8, "nDim");
|
||||
|
||||
constexpr unsigned align_size = align.Get();
|
||||
|
||||
@@ -127,6 +209,64 @@ struct ConstantTensorDescriptor
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + align_size;
|
||||
}
|
||||
else if(nDim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
align_size;
|
||||
}
|
||||
else if(nDim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + align_size;
|
||||
}
|
||||
else if(nDim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) +
|
||||
(GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) +
|
||||
(GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) +
|
||||
align_size;
|
||||
}
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
@@ -163,6 +303,83 @@ struct ConstantTensorDescriptor
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
|
||||
}
|
||||
|
||||
// this is ugly, only for 5d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
static_assert(nDim == 5, "nDim is not 5");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4);
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
__host__ __device__ unsigned
|
||||
Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
static_assert(nDim == 6, "nDim is not 6");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5);
|
||||
}
|
||||
|
||||
// this is ugly, only for 7d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
static_assert(nDim == 7, "nDim is not 7");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6);
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0,
|
||||
unsigned i1,
|
||||
unsigned i2,
|
||||
unsigned i3,
|
||||
unsigned i4,
|
||||
unsigned i5,
|
||||
unsigned i6,
|
||||
unsigned i7) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
static_assert(nDim == 8, "nDim is not 8");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) +
|
||||
i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
{
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
@@ -196,7 +413,7 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr unsigned ndim = desc.GetDimension();
|
||||
|
||||
static_assert(ndim >= 2 && ndim <= 4, "wrong!");
|
||||
static_assert(ndim >= 2 && ndim <= 8, "wrong!");
|
||||
|
||||
if(ndim == 2)
|
||||
{
|
||||
@@ -230,4 +447,110 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
else if(ndim == 5)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4));
|
||||
}
|
||||
else if(ndim == 6)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetLength(I5),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5));
|
||||
}
|
||||
else if(ndim == 7)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetLength(I5),
|
||||
desc.GetLength(I6),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6));
|
||||
}
|
||||
else if(ndim == 8)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetLength(I4),
|
||||
desc.GetLength(I5),
|
||||
desc.GetLength(I6),
|
||||
desc.GetLength(I7),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3),
|
||||
desc.GetStride(I4),
|
||||
desc.GetStride(I5),
|
||||
desc.GetStride(I6),
|
||||
desc.GetStride(I7));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,7 +525,51 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
|
||||
template <class BlockMatrixC, unsigned BlockMatrixStrideC, class FloatC>
|
||||
__device__ void CopyThreadMatrixCToBlockMatrixC(const FloatC* __restrict__ p_c_thread,
|
||||
FloatC* __restrict__ p_c_block) const
|
||||
{
|
||||
constexpr auto c_block_mtx = BlockMatrixC{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr unsigned MPerThread = c_thread_mtx.NRow();
|
||||
constexpr unsigned NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr auto c_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
constexpr unsigned MPerLevel1Cluster = MPerThreadSubC * MLevel0Cluster * MLevel1Cluster;
|
||||
constexpr unsigned NPerLevel1Cluster = NPerThreadSubC * NLevel0Cluster * NLevel1Cluster;
|
||||
|
||||
constexpr unsigned MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr unsigned NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
const auto c_thread_mtx_begin = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned c_thread_offset =
|
||||
c_thread_mtx_begin.batch * BlockMatrixStrideC +
|
||||
c_block_mtx.Get1dIndex(c_thread_mtx_begin.row, c_thread_mtx_begin.col);
|
||||
|
||||
for(unsigned m_repeat = 0; m_repeat, MRepeat; ++m_repeat)
|
||||
{
|
||||
for(unsigned n_repeat = 0; n_repeat, NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
c_thread_sub_mtx,
|
||||
p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
c_block_mtx,
|
||||
p_c_block +
|
||||
c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster) +
|
||||
c_thread_offset,
|
||||
c_thread_sub_mtx.GetLengths());
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "ConstantMatrixDescriptor.hip.hpp"
|
||||
#include "blockwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_2d_tensor_op.hip.hpp"
|
||||
#include "threadwise_nd_tensor_op.hip.hpp"
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_gemm.hip.hpp"
|
||||
|
||||
@@ -33,7 +34,8 @@ template <unsigned GridSize,
|
||||
unsigned GemmNLevel0Cluster,
|
||||
unsigned GemmMLevel1Cluster,
|
||||
unsigned GemmNLevel1Cluster,
|
||||
unsigned GemmKPerThreadLoop>
|
||||
unsigned GemmKPerThreadLoop,
|
||||
unsigned OutThreadCopyDataPerWrite>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
@@ -270,19 +272,18 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
#if 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[K,Ho,Wo,N]
|
||||
#if 0
|
||||
// for v1 batch-gemm
|
||||
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = c_thread_mtx_begin.col - wo_thread_data_begin * NPerBlock;
|
||||
const unsigned n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
threadwise_4d_tensor_copy(
|
||||
threadwise_4d_tensor_copy_v2(
|
||||
out_khwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
@@ -290,8 +291,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_khwn_thread_desc.GetLengths());
|
||||
#else
|
||||
out_khwn_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#elif 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
for(unsigned k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(unsigned ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
@@ -322,5 +327,58 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const unsigned ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const unsigned wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const unsigned n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
// this is for v2 GEMM
|
||||
// output is a 8d tensor
|
||||
if(NPerThread < NPerBlock && WoPerThread == 1)
|
||||
{
|
||||
constexpr unsigned N1_ = GemmNPerThreadSubC;
|
||||
constexpr unsigned W1_ = WoPerBlock / ((WoPerThread * NPerThread) / GemmNPerThreadSubC);
|
||||
constexpr unsigned K2_ = GemmMPerThreadSubC;
|
||||
constexpr unsigned K1_ = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_8d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1_ * K2_), K1_, K2_, Ho, Wo / W1_, W1_, N / N1_, N1_>{});
|
||||
|
||||
constexpr auto out_8d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerBlock / (K1_ * K2_), 1, K2_, HoPerThread, WoPerBlock / W1_, 1, 1, N1_>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_8d_thread_desc, "out_8d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
|
||||
print_ConstantTensorDescriptor(out_8d_global_desc, "out_8d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_8d_tensor_copy(out_8d_thread_desc,
|
||||
p_out_thread,
|
||||
out_8d_global_desc,
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_8d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else if(NPerThread == NPerBlock)
|
||||
{
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ template <class Float,
|
||||
class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
@@ -100,7 +100,7 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p)
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
@@ -114,7 +114,7 @@ threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
SrcDesc, const Float* __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
@@ -122,6 +122,85 @@ __device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
// need to assume src and dst is aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
|
||||
__device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using Float2 = float2;
|
||||
using Float4 = float4;
|
||||
|
||||
static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
|
||||
SrcOpLengths::nDim == 4,
|
||||
"wrong! should be 4 dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I3) == 1 && DstDesc{}.GetStride(I3) == 1,
|
||||
"wrong! only support stride3 == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I2) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I2) % DataPerRead == 0,
|
||||
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr unsigned L3 = SrcOpLengths{}.Get(I3);
|
||||
|
||||
static_assert(L3 % DataPerRead == 0, "wrong! L3 should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr unsigned nloop_d3 = L3 / DataPerRead;
|
||||
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned iloop_d3 = 0; iloop_d3 < nloop_d3; ++iloop_d3)
|
||||
{
|
||||
const unsigned src_index =
|
||||
src_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
|
||||
|
||||
const unsigned dst_index =
|
||||
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
|
||||
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + src_index));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + src_index));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
|
||||
198
src/include/threadwise_nd_tensor_op.hip.hpp
Normal file
198
src/include/threadwise_nd_tensor_op.hip.hpp
Normal file
@@ -0,0 +1,198 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
|
||||
// need to assume src and dst is aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
|
||||
__device__ void threadwise_6d_tensor_copy(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using Float2 = float2;
|
||||
using Float4 = float4;
|
||||
|
||||
static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
|
||||
SrcOpLengths::nDim == 6,
|
||||
"wrong! should be 6 dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1,
|
||||
"wrong! only support stride5 == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I4) % DataPerRead == 0,
|
||||
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr unsigned L5 = SrcOpLengths{}.Get(I5);
|
||||
|
||||
static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr unsigned nloop_d5 = L5 / DataPerRead;
|
||||
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
|
||||
{
|
||||
for(unsigned iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
|
||||
{
|
||||
const unsigned src_index = src_desc.Get1dIndex(
|
||||
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
|
||||
|
||||
const unsigned dst_index = dst_desc.Get1dIndex(
|
||||
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
|
||||
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + src_index));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + src_index));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// need to assume src and dst is aligned
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, unsigned DataPerRead>
|
||||
__device__ void threadwise_8d_tensor_copy(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using Float2 = float2;
|
||||
using Float4 = float4;
|
||||
|
||||
static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
|
||||
SrcOpLengths::nDim == 8,
|
||||
"wrong! should be 8 dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
|
||||
"wrong! only support stride7 == 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I6) % DataPerRead == 0,
|
||||
"wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr unsigned L7 = SrcOpLengths{}.Get(I7);
|
||||
|
||||
static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");
|
||||
|
||||
constexpr unsigned nloop_d7 = L7 / DataPerRead;
|
||||
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
for(unsigned did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
|
||||
{
|
||||
for(unsigned did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
|
||||
{
|
||||
for(unsigned did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
|
||||
{
|
||||
for(unsigned iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
|
||||
{
|
||||
const unsigned src_index =
|
||||
src_desc.Get1dIndex(did0,
|
||||
did1,
|
||||
did2,
|
||||
did3,
|
||||
did4,
|
||||
did5,
|
||||
did6,
|
||||
iloop_d7 * DataPerRead);
|
||||
|
||||
const unsigned dst_index =
|
||||
dst_desc.Get1dIndex(did0,
|
||||
did1,
|
||||
did2,
|
||||
did3,
|
||||
did4,
|
||||
did5,
|
||||
did6,
|
||||
iloop_d7 * DataPerRead);
|
||||
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + src_index));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + src_index));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user