mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
implicit gemm v1r2: adding support for nchw
This commit is contained in:
@@ -243,7 +243,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// for 3x3, 28x28, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
@@ -386,7 +386,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
#if 1
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
433
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
Normal file
433
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
Normal file
@@ -0,0 +1,433 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "gridwise_convolution_wrapper.hip.hpp"
|
||||
#include "gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
index_t nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcyx_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_nchw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = wei_kcyx_desc.GetLength(I0);
|
||||
constexpr index_t C = wei_kcyx_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_kcyx_desc.GetLength(I3);
|
||||
|
||||
// reorder weight
|
||||
auto wei_cyxk_desc = make_ConstantTensorDescriptor(Sequence<C, Y, X, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_cyxk_desc, std::cout << "wei_cyxk_desc: ");
|
||||
|
||||
Tensor<T> wei_cyxk(make_TensorDescriptor(wei_cyxk_desc));
|
||||
|
||||
auto f_reorder_kcyx2cyxk = [&](auto k, auto c, auto y, auto x) {
|
||||
wei_cyxk(c, y, x, k) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcyx2cyxk, K, C, Y, X)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// output
|
||||
auto out_khwn_desc = make_ConstantTensorDescriptor(Sequence<K, Ho, Wo, N>{});
|
||||
ostream_ConstantTensorDescriptor(out_khwn_desc, std::cout << "out_khwn_desc: ");
|
||||
|
||||
Tensor<T> out_khwn(make_TensorDescriptor(out_khwn_desc));
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_cyxk_device_buf(data_sz * wei_cyxk.mDesc.GetElementSpace());
|
||||
DeviceMem out_khwn_device_buf(data_sz * out_khwn.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 0
|
||||
// for 3x3, 34x34, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
|
||||
constexpr index_t NPerBlock = 4;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 8;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r1, Vega 20
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 256;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1, Pascal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1r2, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 1;
|
||||
constexpr index_t GemmDataPerReadB = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r1, Pacal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 3x3, 28x28, v1r2, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 1x1, 28x28
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 1x1, 14x14, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 8;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((N + NPerBlock - 1) / NPerBlock) * ((K + KPerBlock - 1) / KPerBlock) *
|
||||
((Ho + HoPerBlock - 1) / HoPerBlock) * ((Wo + WoPerBlock - 1) / WoPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
constexpr auto gridwise_conv =
|
||||
#if 1
|
||||
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_cyxk_desc),
|
||||
decltype(out_khwn_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
Sequence<InBlockCopy_ThreadPerDimN,
|
||||
InBlockCopy_ThreadPerDimC,
|
||||
InBlockCopy_ThreadPerDimH,
|
||||
InBlockCopy_ThreadPerDimW>,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead,
|
||||
OutThreadCopyDataPerWrite>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_cyxk_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_khwn_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
out_khwn_device_buf.FromDevice(out_khwn.mData.data());
|
||||
|
||||
// reorder output
|
||||
auto f_reorder_khwn2nkhw = [&](auto k, auto ho, auto wo, auto n) {
|
||||
out_nkhw(n, k, ho, wo) = out_khwn(k, ho, wo, n);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_khwn2nkhw, K, Ho, Wo, N)(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -38,9 +38,6 @@ struct GeneratorTensor_2
|
||||
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 9;
|
||||
|
||||
template <class... Is>
|
||||
double operator()(Is... is)
|
||||
{
|
||||
@@ -420,11 +417,10 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
#if 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t HI = 28;
|
||||
constexpr index_t WI = 28;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
@@ -432,15 +428,27 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
@@ -642,6 +650,9 @@ int main(int argc, char* argv[])
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_3{}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
@@ -664,7 +675,7 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#pragma once
|
||||
#include "Sequence.hip.hpp"
|
||||
#include "functional.hip.hpp"
|
||||
|
||||
template <class TData, index_t NSize>
|
||||
struct Array
|
||||
@@ -18,3 +20,35 @@ struct Array
|
||||
|
||||
__host__ __device__ TData& operator[](index_t i) { return mData[i]; }
|
||||
};
|
||||
|
||||
// Builds a reordered copy of `old_array` using a new-to-old index map:
// element i of the result is old_array[new2old[i]].
// The map must have exactly NSize entries (checked at compile time).
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_new2old(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> new2old)
{
    static_assert(sizeof...(IRs) == NSize, "NSize not consistent");

    Array<TData, NSize> reordered;

    static_for<0, NSize, 1>{}([&](auto INew) {
        // position being written in the result
        constexpr index_t i_new = INew.Get();
        // position read from the source, looked up through the map
        const index_t i_old = new2old.Get(INew);

        reordered[i_new] = old_array[i_old];
    });

    return reordered;
}
|
||||
|
||||
// Builds a reordered copy of `old_array` using an old-to-new index map:
// element i of `old_array` is stored at position old2new[i] of the result.
// The map must have exactly NSize entries (checked at compile time).
// NOTE(review): nothing here verifies that old2new is a permutation; a slot
// that no entry maps to is never assigned by this function.
template <class TData, index_t NSize, index_t... IRs>
__host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>& old_array,
                                                     Sequence<IRs...> old2new)
{
    static_assert(sizeof...(IRs) == NSize, "NSize not consistent");

    Array<TData, NSize> reordered;

    static_for<0, NSize, 1>{}([&](auto IOld) {
        // position being read from the source
        constexpr index_t i_old = IOld.Get();
        // destination position, looked up through the map
        const index_t i_new = old2new.Get(IOld);

        reordered[i_new] = old_array[i_old];
    });

    return reordered;
}
|
||||
@@ -108,11 +108,11 @@ template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
using Type = ConstantTensorDescriptor<Lengths, Strides>;
|
||||
static constexpr index_t nDim = Lengths::nDim;
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
{
|
||||
static_assert(Lengths::nDim == Strides::nDim, "nDim not consistent");
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr index_t GetDimension() { return nDim; }
|
||||
@@ -157,12 +157,10 @@ struct ConstantTensorDescriptor
|
||||
return align.Get() * ((element_space_unaligned + align.Get() - 1) / align.Get());
|
||||
}
|
||||
|
||||
template <class... Is>
|
||||
__host__ __device__ static index_t Get1dIndex(Is... is)
|
||||
template <index_t NSize>
|
||||
__host__ __device__ static index_t Get1dIndex(Array<index_t, NSize> multi_id)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong");
|
||||
|
||||
const auto multi_id = Array<index_t, nDim>(is...);
|
||||
static_assert(NSize == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
index_t id = 0;
|
||||
|
||||
@@ -178,6 +176,16 @@ struct ConstantTensorDescriptor
|
||||
return id;
|
||||
}
|
||||
|
||||
// Variadic convenience overload: takes one index per dimension, packs them
// into an Array<index_t, nDim> and forwards to the Get1dIndex overload that
// accepts an Array. The number of indices must equal nDim.
template <class... Is>
__host__ __device__ static index_t Get1dIndex(Is... is)
{
    static_assert(nDim == sizeof...(Is), "number of multi-index is wrong");

    return Get1dIndex(Array<index_t, nDim>(is...));
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
@@ -7,9 +7,11 @@ struct Sequence
|
||||
{
|
||||
using Type = Sequence<Is...>;
|
||||
|
||||
static constexpr index_t nDim = sizeof...(Is);
|
||||
static constexpr index_t mSize = sizeof...(Is);
|
||||
|
||||
const index_t mData[nDim] = {Is...};
|
||||
const index_t mData[mSize] = {Is...};
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return mSize; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr index_t Get(Number<I>) const
|
||||
@@ -19,36 +21,38 @@ struct Sequence
|
||||
|
||||
__host__ __device__ index_t operator[](index_t i) const { return mData[i]; }
|
||||
|
||||
// this is ugly, only for nDim = 4
|
||||
template <index_t I0, index_t I1, index_t I2, index_t I3>
|
||||
__host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence<I0, I1, I2, I3>) const
|
||||
template <index_t... IRs>
|
||||
__host__ __device__ constexpr auto ReorderGivenNew2Old(Sequence<IRs...> /*new2old*/) const
|
||||
{
|
||||
static_assert(nDim == 4, "nDim != 4");
|
||||
static_assert(mSize == sizeof...(IRs), "mSize not consistent");
|
||||
|
||||
constexpr auto old_sequence = Type{};
|
||||
constexpr auto old = Type{};
|
||||
|
||||
constexpr index_t NR0 = old_sequence.mData[I0];
|
||||
constexpr index_t NR1 = old_sequence.mData[I1];
|
||||
constexpr index_t NR2 = old_sequence.mData[I2];
|
||||
constexpr index_t NR3 = old_sequence.mData[I3];
|
||||
|
||||
return Sequence<NR0, NR1, NR2, NR3>{};
|
||||
return Sequence<old.Get(Number<IRs>{})...>{};
|
||||
}
|
||||
|
||||
template <index_t I0, index_t I1, index_t I2, index_t I3>
|
||||
__host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence<I0, I1, I2, I3>) const
|
||||
template <index_t... IRs>
|
||||
__host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
|
||||
{
|
||||
// don't know how to implement this
|
||||
printf("Sequence::ReorderByPutOldToNew not implemented");
|
||||
printf("Sequence::ReorderGivenOld2New not implemented");
|
||||
assert(false);
|
||||
}
|
||||
|
||||
// Returns a Sequence whose first element is I, followed by this sequence's
// elements in their original order.
template <index_t I>
__host__ __device__ constexpr auto PushFront(Number<I>) const
{
    using Prepended = Sequence<I, Is...>;
    return Prepended{};
}
|
||||
|
||||
// Returns a Sequence holding this sequence's elements in their original order,
// followed by I as the last element.
template <index_t I>
__host__ __device__ constexpr auto PushBack(Number<I>) const
{
    using Appended = Sequence<Is..., I>;
    return Appended{};
}
|
||||
|
||||
__host__ __device__ constexpr auto PopFront() const;
|
||||
|
||||
__host__ __device__ constexpr auto PopBack() const;
|
||||
|
||||
template <class F>
|
||||
@@ -58,33 +62,84 @@ struct Sequence
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t... Is, index_t I>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
|
||||
template <index_t I, index_t... Is>
|
||||
__host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
|
||||
{
|
||||
static_assert(sizeof...(Is) >= 1, "empty Sequence!");
|
||||
static_assert(sizeof...(Is) > 0, "empty Sequence!");
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
template <class F, index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto sequence_sequence_op(Sequence<Xs...>, Sequence<Ys...>, F f)
|
||||
template <index_t... Is, index_t I>
|
||||
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
|
||||
{
|
||||
static_assert(Sequence<Xs...>::nDim == Sequence<Ys...>::nDim, "Dim not the same");
|
||||
static_assert(sizeof...(Is) > 0, "empty Sequence!");
|
||||
return Sequence<Is...>{};
|
||||
}
|
||||
|
||||
#if 1
|
||||
// Element-wise transform of two Sequences of equal length: the result is
// Sequence<f(X0, Y0), f(X1, Y1), ...>. `f` must be callable in a constant
// expression. This overload is hand-written for exactly 2 input sequences.
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{
    static_assert(sizeof...(Xs) == sizeof...(Ys), "Dim not the same");

    using Result = Sequence<f(Xs, Ys)...>;
    return Result{};
}
|
||||
|
||||
template <index_t... Xs, index_t... Ys>
|
||||
__host__ __device__ constexpr auto sequence_sequence_add(Sequence<Xs...>, Sequence<Ys...>)
|
||||
// this is ugly, only for 3 sequences
|
||||
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
|
||||
__host__ __device__ constexpr auto
|
||||
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
|
||||
{
|
||||
struct add
|
||||
{
|
||||
__host__ __device__ constexpr index_t operator()(index_t x, index_t y) const
|
||||
{
|
||||
return x + y;
|
||||
}
|
||||
};
|
||||
static_assert(Sequence<Xs...>::mSize == Sequence<Ys...>::mSize &&
|
||||
Sequence<Xs...>::mSize == Sequence<Zs...>::mSize,
|
||||
"Dim not the same");
|
||||
|
||||
return sequence_sequence_op(Sequence<Xs...>{}, Sequence<Ys...>{}, add{});
|
||||
return Sequence<f(Xs, Ys, Zs)...>{};
|
||||
}
|
||||
#else
|
||||
// Recursive step of the variadic transform_sequences.
//   NRemain  number of recursion steps left, including this one (must be > 1
//            here; the <1> specialization ends the recursion)
//   Y        Sequence of results accumulated so far
//   Xs...    input Sequences, already stripped of the elements handled so far
// Each step applies f to the front element of every input, appends the result
// to Y, drops the front elements and recurses with NRemain - 1.
template <index_t NRemain>
struct transform_sequences_impl
{
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const
    {
        static_assert(NRemain > 1, "wrong! should have NRemain > 1");

        constexpr auto front          = Number<0>{};
        constexpr index_t transformed = f(Xs{}.Get(front)...);

        using Next = transform_sequences_impl<NRemain - 1>;

        return Next{}(f, Y{}.PushBack(Number<transformed>{}), Xs{}.PopFront()...);
    }
};
|
||||
|
||||
// Final recursion step: applies f to the front element of every input Sequence
// and appends that single result to the accumulated Sequence Y. Nothing is
// popped here and no further recursion takes place.
template <>
struct transform_sequences_impl<1>
{
    template <class F, class Y, class... Xs>
    __host__ __device__ constexpr auto operator()(F f, Y, Xs...) const
    {
        constexpr auto front   = Number<0>{};
        constexpr index_t last = f(Xs{}.Get(front)...);

        using Tail = Number<last>;
        return Y{}.PushBack(Tail{});
    }
};
|
||||
|
||||
// Variadic element-wise transform over any number of Sequences: applies f to
// the elements found at the same position in every input and collects the
// results in a new Sequence. The length is taken from the first input, X.
//
// The first result seeds the output; the remaining nSize - 1 elements are
// handled by transform_sequences_impl. Note that for a one-element input this
// instantiates transform_sequences_impl<0>, whose static_assert(NRemain > 1)
// fires, so inputs need at least two elements.
template <class F, class X, class... Xs>
__host__ __device__ constexpr auto transform_sequences(F f, X x, Xs... xs)
{
    constexpr auto front = Number<0>{};

    // result for position 0
    using Seed = Sequence<f(X{}.Get(front), Xs{}.Get(front)...)>;

    // recursion over the positions that are left
    using Rest = transform_sequences_impl<X::GetSize() - 1>;

    return Rest{}(f, Seed{}, x.PopFront(), xs.PopFront()...);
}
|
||||
#endif
|
||||
|
||||
// Out-of-class definition of Sequence::PopFront: delegates to the free
// function sequence_pop_front, passing a value of this Sequence type.
template <index_t... Is>
__host__ __device__ constexpr auto Sequence<Is...>::PopFront() const
{
    return sequence_pop_front(Sequence<Is...>{});
}
|
||||
|
||||
template <index_t... Is>
|
||||
@@ -107,6 +162,6 @@ template <class Seq, class Reduce, index_t I>
|
||||
__host__ __device__ constexpr index_t accumulate_on_sequence(Seq, Reduce, Number<I>)
|
||||
{
|
||||
constexpr index_t a =
|
||||
static_const_reduce_n<Seq::nDim>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
|
||||
static_const_reduce_n<Seq::mSize>{}(accumulate_on_sequence_f<Seq>{}, Reduce{});
|
||||
return Reduce{}(a, I);
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
@@ -75,14 +75,14 @@ __device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_ds
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
@@ -147,19 +147,19 @@ template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
@@ -192,7 +192,7 @@ struct Blockwise2dTensorCopy1
|
||||
// but we need to make sure dst stride0 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
|
||||
constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
static_assert(read_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
@@ -209,7 +209,7 @@ struct Blockwise2dTensorCopy1
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr index_t read_per_d1 = integer_divide_ceil(L1, DataPerRead);
|
||||
constexpr index_t read_per_d1 = mod_conv::integer_divide_ceil(L1, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, read_per_d1>{});
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ struct Blockwise3dTensorCopy1
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
|
||||
constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
static_assert(read_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
@@ -52,7 +52,7 @@ struct Blockwise3dTensorCopy1
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t read_per_d2 = integer_divide_ceil(L2, DataPerRead);
|
||||
constexpr index_t read_per_d2 = mod_conv::integer_divide_ceil(L2, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(Sequence<L0, L1, read_per_d2>{});
|
||||
|
||||
@@ -98,3 +98,271 @@ struct Blockwise3dTensorCopy1
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// starting point need to be aligned to float4 or float2 or float
|
||||
// stride3 need to be 1 for both source and destination
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class CopyLengths,
|
||||
class ThreadPerDims,
|
||||
index_t DataPerRead>
|
||||
struct Blockwise3dTensorCopy3
|
||||
{
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise3dTensorCopy3()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
static_assert(DataPerRead == 1 ||
|
||||
(SrcDesc{}.GetStride(I2) == 1 && DstDesc{}.GetStride(I2) == 1),
|
||||
"wrong! only support stride3 == 1 if DataPerRead > 1!\n");
|
||||
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(I1) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I1) % DataPerRead == 0,
|
||||
"wrong! src and dst stride1 should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
// we allow out-of-bound read from src in D2 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d2 * thread_per_d2 * DataPerRead <= DstDesc{}.GetStride(I1),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
|
||||
static_assert(L0 % thread_per_d0 == 0 && L1 % thread_per_d1 == 0,
|
||||
"wrong! L0, L1, L2 should be divided evenly!\n");
|
||||
|
||||
static_assert(BlockSize >= thread_per_d0 * thread_per_d1 * thread_per_d2,
|
||||
"wrrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
constexpr index_t num_active_thread =
|
||||
accumulate_on_sequence(ThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(ThreadPerDims{});
|
||||
const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
|
||||
mDstMyThreadOffset = DstDesc{}.Get1dIndex(
|
||||
thread_multi_id[0], thread_multi_id[1], thread_multi_id[2] * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterClipboardSize() const
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2;
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
const index_t clipboard_offset =
|
||||
clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_clipboard[clipboard_offset])) = *(
|
||||
reinterpret_cast<const vector_t*>(&p_src[src_offset + mSrcMyThreadOffset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
|
||||
Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
constexpr index_t L0 = CopyLengths{}.Get(I0);
|
||||
constexpr index_t L1 = CopyLengths{}.Get(I1);
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
|
||||
constexpr index_t thread_per_d0 = ThreadPerDims{}.Get(I0);
|
||||
constexpr index_t thread_per_d1 = ThreadPerDims{}.Get(I1);
|
||||
constexpr index_t thread_per_d2 = ThreadPerDims{}.Get(I2);
|
||||
|
||||
constexpr index_t num_active_thread = thread_per_d0 * thread_per_d1 * thread_per_d2;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = mod_conv::integer_divide_ceil(L2, thread_per_d2 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<nloop_d0, nloop_d1, nloop_d2 * DataPerRead>{});
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d1 = 0; iloop_d1 < nloop_d1; ++iloop_d1)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
const index_t clipboard_offset =
|
||||
clipboard_desc.Get1dIndex(iloop_d0, iloop_d1, iloop_d2 * DataPerRead);
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
iloop_d2 * thread_per_d2 * DataPerRead);
|
||||
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_offset + mDstMyThreadOffset])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_clipboard[clipboard_offset]));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -84,7 +84,7 @@ template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
@@ -92,7 +92,7 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -100,10 +100,10 @@ __device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_ds
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
|
||||
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
|
||||
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
@@ -184,19 +184,19 @@ template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const Float* __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <index_t BlockSize,
|
||||
@@ -231,7 +231,7 @@ struct Blockwise4dTensorCopy1
|
||||
// but we need to make sure dst stride2 is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
|
||||
constexpr index_t read_per_d3 = mod_conv::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
static_assert(read_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
@@ -252,7 +252,7 @@ struct Blockwise4dTensorCopy1
|
||||
constexpr index_t L2 = CopyLengths{}.Get(I2);
|
||||
constexpr index_t L3 = CopyLengths{}.Get(I3);
|
||||
|
||||
constexpr index_t read_per_d3 = integer_divide_ceil(L3, DataPerRead);
|
||||
constexpr index_t read_per_d3 = mod_conv::integer_divide_ceil(L3, DataPerRead);
|
||||
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<L0, L1, L2, read_per_d3>{});
|
||||
@@ -481,7 +481,7 @@ struct Blockwise4dTensorCopy3
|
||||
// we allow out-of-bound read from src in D3 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't contaminate next line in dst
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
static_assert(nloop_d3 * thread_per_d3 * DataPerRead <= DstDesc{}.GetStride(I2),
|
||||
"wrong! out-of-bound write will contaminate next line!\n");
|
||||
@@ -548,7 +548,7 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t iloop_d0 = 0; iloop_d0 < nloop_d0; ++iloop_d0)
|
||||
@@ -605,7 +605,7 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
return DataPerRead * nloop_d0 * nloop_d1 * nloop_d2 * nloop_d3;
|
||||
}
|
||||
@@ -642,7 +642,7 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
@@ -709,7 +709,7 @@ struct Blockwise4dTensorCopy3
|
||||
constexpr index_t nloop_d0 = L0 / thread_per_d0;
|
||||
constexpr index_t nloop_d1 = L1 / thread_per_d1;
|
||||
constexpr index_t nloop_d2 = L2 / thread_per_d2;
|
||||
constexpr index_t nloop_d3 = integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
constexpr index_t nloop_d3 = mod_conv::integer_divide_ceil(L3, thread_per_d3 * DataPerRead);
|
||||
|
||||
constexpr auto clipboard_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<nloop_d0, nloop_d1, nloop_d2, nloop_d3 * DataPerRead>{});
|
||||
@@ -749,7 +749,7 @@ template <index_t BlockSize,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
class MapDst2Src>
|
||||
struct Blockwise4dTensorCopyReorder1
|
||||
{
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
@@ -757,60 +757,104 @@ struct Blockwise4dTensorCopyReorder1
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcLengths,
|
||||
class SrcSubLengths,
|
||||
class SrcThreadPerDims,
|
||||
class DstFromSrcReorder,
|
||||
index_t DataPerRead,
|
||||
index_t DataPerWrite>
|
||||
class SrcClusterLengths,
|
||||
class MapDst2Src,
|
||||
class MapThreadCluster2SrcCluster,
|
||||
index_t SrcDataPerRead,
|
||||
index_t DstDataPerWrite>
|
||||
struct Blockwise4dTensorCopyReorder3
|
||||
{
|
||||
static constexpr index_t nDim = SrcLengths::GetSize();
|
||||
|
||||
index_t mSrcMyThreadOffset;
|
||||
index_t mDstMyThreadOffset;
|
||||
|
||||
__device__ Blockwise4dTensorCopyReorder3()
|
||||
{
|
||||
constexpr index_t nDim = SrcDesc{}.GetDimension();
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(DstDesc{}.GetDimension() == nDim && SrcOpLengths::nDim == nDim &&
|
||||
SrcOpThreadPerDims::nDim == nDim && DstFromSrcReorder::nDim == nDim,
|
||||
"wrong! nDim is not consistent\n");
|
||||
constexpr auto src_lengths = SrcLengths{};
|
||||
|
||||
// Src
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
constexpr auto map_dst2src = MapDst2Src{};
|
||||
|
||||
static_assert(DataPerRead == 1 || SrcDesc{}.GetStride(Number<nDim-1>{}) == 1,
|
||||
"wrong! only support src.stride(nDim-1) == 1 if DataPerRead > 1!\n");
|
||||
constexpr auto src_sub_lengths = SrcSubLengths{};
|
||||
constexpr auto dst_sub_lengths = src_sub_lengths.ReorderGivenNew2Old(map_dst2src);
|
||||
|
||||
static_assert(
|
||||
SrcDesc{}.GetStride(Number<nDim-2>{}) % DataPerRead == 0,
|
||||
"wrong! src.stride(nDim-2) should be multiple of DataPerRead to keep alignment");
|
||||
constexpr auto map_thread_cluster_2_src_cluster = MapThreadCluster2SrcCluster{};
|
||||
|
||||
static_assert(SrcSubLengths{}.Get(Number<nDim-1>{}) % DataPerRead == 0, "wrong! SrcSubLengths[nDim-1] % DataPerRead != 0\n");
|
||||
constexpr auto src_cluster_lengths = SrcClusterLengths{};
|
||||
constexpr auto thread_cluster_lengths =
|
||||
src_cluster_lengths.ReorderGivenNew2Old(map_thread_cluster_2_src_cluster);
|
||||
|
||||
static_loop<nDim-1>([](auto I){
|
||||
constexpr index_t src_len = SrcLengths{}.Get(I);
|
||||
constexpr index_t src_sub_len = SrcSubLengths{}.Get(I);
|
||||
constexpr index_t thread_per_dim = SrcThreadPerDims{}.Get(I);
|
||||
static_assert(src_len % (src_sub_len * thread_per_dim) == 0,
|
||||
"wrong! cannot evenly divide tensor lengths");
|
||||
});
|
||||
constexpr auto thread_cluster_desc = make_ConstantTensorDescriptor(thread_cluster_lengths);
|
||||
|
||||
constexpr index_t num_active_thread = accumulate_on_sequence(SrcOpThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
|
||||
// sanity check: data type
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float for now!\n");
|
||||
|
||||
// sanity check: nDim
|
||||
static_assert(SrcDesc::GetDimension() == nDim && DstDesc::GetDimension() == nDim &&
|
||||
SrcLengths::GetSize() == nDim && SrcSubLengths::GetSize() == nDim &&
|
||||
SrcClusterLengths::GetSize() == nDim && MapDst2Src::GetSize() == nDim &&
|
||||
MapThreadCluster2SrcCluster::GetSize() == nDim,
|
||||
"wrong! nDim is not consistent\n");
|
||||
|
||||
// sanity check: BlockSize
|
||||
constexpr index_t num_active_thread = thread_cluster_desc.GetElementSize();
|
||||
|
||||
static_assert(BlockSize >= num_active_thread,
|
||||
"wrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
// sanity check: work division
|
||||
static_for<0, nDim, 1>{}([](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t src_len = src_lengths.Get(I);
|
||||
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
|
||||
constexpr index_t src_cluster_len = src_cluster_lengths.Get(I);
|
||||
static_assert(src_len % (src_sub_len * src_cluster_len) == 0,
|
||||
"wrong! cannot evenly divide Src tensor lengths");
|
||||
});
|
||||
|
||||
// sanity check: src read
|
||||
static_assert(SrcDataPerRead == 1 || SrcDataPerRead == 2 || SrcDataPerRead == 4,
|
||||
"wrong! only support SrcDataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(SrcDataPerRead == 1 || src_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support src.stride(nDim-1) == 1 if SrcDataPerRead > 1!\n");
|
||||
|
||||
static_assert(src_sub_lengths.Get(Number<nDim - 1>{}) % SrcDataPerRead == 0,
|
||||
"wrong! src_sub_lengths[nDim-1] % SrcDataPerRead != 0\n");
|
||||
|
||||
static_assert(src_desc.GetStride(Number<nDim - 2>{}) % SrcDataPerRead == 0,
|
||||
"wrong! should satisfy src_desc.stride(nDim-2) % SrcDataPerRead == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// sanity check: dst write
|
||||
static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4,
|
||||
"wrong! only support DstDataPerWrite == 1, 2 or 4!\n");
|
||||
|
||||
static_assert(DstDataPerWrite == 1 || dst_desc.GetStride(Number<nDim - 1>{}) == 1,
|
||||
"wrong! only support dst.stride(nDim-1) == 1 if DstDataPerWrite > 1!\n");
|
||||
|
||||
static_assert(dst_sub_lengths.Get(Number<nDim - 1>{}) % DstDataPerWrite == 0,
|
||||
"wrong! dst_sub_lengths[nDim-1] % DstDataPerWrite != 0\n");
|
||||
|
||||
static_assert(dst_desc.GetStride(Number<nDim - 2>{}) % DstDataPerWrite == 0,
|
||||
"wrong! should satisfy dst_desc.stride(nDim-2) % DstDataPerWrite == 0, to "
|
||||
"keep alignment");
|
||||
|
||||
// start dividing work
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
@@ -819,37 +863,251 @@ struct Blockwise4dTensorCopyReorder3
|
||||
}
|
||||
}
|
||||
|
||||
const auto thread_multi_id = SrcOpThreadPerDims::GetMultiIndex(get_thread_local_1d_id());
|
||||
const auto thread_multi_id = thread_cluster_desc.GetMultiIndex(get_thread_local_1d_id());
|
||||
|
||||
// compiler: thread_multi_id, src_data_multi_id, dst_data_multi_id, will use separate
|
||||
// registers, or only one copy???
|
||||
auto src_data_multi_id =
|
||||
reorder_array_given_old2new(thread_multi_id, map_thread_cluster_2_src_cluster);
|
||||
|
||||
const index_t thread_id_d0 =
|
||||
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
index_t itmp = get_thread_local_1d_id() -
|
||||
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
|
||||
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
|
||||
const index_t thread_id_d2 = itmp / thread_per_d3;
|
||||
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t i = I.Get();
|
||||
// compiler: will it really compute index here, or be associated with Get1dIndex and
|
||||
// optimized away???
|
||||
src_data_multi_id[i] *= src_sub_lengths.Get(I);
|
||||
});
|
||||
|
||||
// compiler: will it really compute index here, or be associated with Get1dIndex and
|
||||
// optimized away???
|
||||
const auto dst_data_multi_id = reorder_array_given_new2old(src_data_multi_id, map_dst2src);
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
|
||||
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
|
||||
mSrcMyThreadOffset = src_desc.Get1dIndex(src_data_multi_id);
|
||||
mDstMyThreadOffset = dst_desc.Get1dIndex(dst_data_multi_id);
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid %5u, "
|
||||
"thread_multi_id %5u %5u %5u %5u, "
|
||||
"src_data_multi_id %5u %5u %5u %5u, "
|
||||
"dst_data_multi_id %5u %5u %5u %5u, "
|
||||
"mSrcMyThreadOffset %u, mDstMyThreadOffset %u\n",
|
||||
get_thread_local_1d_id(),
|
||||
thread_multi_id[0],
|
||||
thread_multi_id[1],
|
||||
thread_multi_id[2],
|
||||
thread_multi_id[3],
|
||||
src_data_multi_id[0],
|
||||
src_data_multi_id[1],
|
||||
src_data_multi_id[2],
|
||||
src_data_multi_id[3],
|
||||
dst_data_multi_id[0],
|
||||
dst_data_multi_id[1],
|
||||
dst_data_multi_id[2],
|
||||
dst_data_multi_id[3],
|
||||
mSrcMyThreadOffset,
|
||||
mDstMyThreadOffset);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
|
||||
constexpr auto cluster_per_dims =
|
||||
transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
|
||||
SrcLengths{},
|
||||
src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
|
||||
|
||||
return thread_tensor_desc.GetElementSpace();
|
||||
}
|
||||
|
||||
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
|
||||
Float* __restrict__ p_clipboard) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};
|
||||
|
||||
constexpr auto src_data_per_cluster_per_dims = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});
|
||||
|
||||
constexpr auto cluster_per_dims =
|
||||
transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
|
||||
SrcLengths{},
|
||||
src_data_per_cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_lengths = transform_sequences(
|
||||
mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);
|
||||
|
||||
constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);
|
||||
|
||||
constexpr auto thread_sub_tensor_desc =
|
||||
make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides());
|
||||
|
||||
for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
|
||||
{
|
||||
for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)
|
||||
{
|
||||
for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2)
|
||||
{
|
||||
for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3);
|
||||
++icluster_d3)
|
||||
{
|
||||
const index_t src_offset = SrcDesc{}.Get1dIndex(
|
||||
icluster_d0 * src_data_per_cluster_per_dims.Get(I0),
|
||||
icluster_d1 * src_data_per_cluster_per_dims.Get(I1),
|
||||
icluster_d2 * src_data_per_cluster_per_dims.Get(I2),
|
||||
icluster_d3 * src_data_per_cluster_per_dims.Get(I3));
|
||||
|
||||
const index_t clipboard_offset = thread_tensor_desc.Get1dIndex(
|
||||
icluster_d0 * thread_sub_tensor_lengths.Get(I0),
|
||||
icluster_d1 * thread_sub_tensor_lengths.Get(I1),
|
||||
icluster_d2 * thread_sub_tensor_lengths.Get(I2),
|
||||
icluster_d3 * thread_sub_tensor_lengths.Get(I3));
|
||||
|
||||
threadwise_4d_tensor_copy_v2(SrcDesc{},
|
||||
p_src + src_offset + mSrcMyThreadOffset,
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
thread_sub_tensor_lengths,
|
||||
Number<SrcDataPerRead>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid %5u, "
|
||||
"data: %f %f %f %f %f %f %f %f\n",
|
||||
get_thread_local_1d_id(),
|
||||
p_clipboard[0],
|
||||
p_clipboard[1],
|
||||
p_clipboard[2],
|
||||
p_clipboard[3],
|
||||
p_clipboard[4],
|
||||
p_clipboard[5],
|
||||
p_clipboard[6],
|
||||
p_clipboard[7]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Write this thread's clipboard contents out to the destination tensor,
// applying the MapDst2Src dimension reorder on the way.
//
//   p_clipboard : thread-local staging buffer holding the data to store (Run()
//                 fills it with RunLoadRegisterClipboard before calling this).
//                 It is indexed with a packed descriptor of lengths
//                 SrcSubLengths * cluster_per_dims.
//   p_dst       : base pointer of the destination tensor described by DstDesc;
//                 this thread's own offset (mDstMyThreadOffset) is added here.
//
// The work is split into "repeats": along each source dimension the thread
// cluster covers SrcSubLengths * SrcClusterLengths elements per pass, and
// ceil(SrcLengths / that) passes are made. Each pass copies one sub-tensor of
// SrcSubLengths elements per thread.
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
                                          Float* __restrict__ p_dst) const
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto thread_sub_tensor_lengths = SrcSubLengths{};

    // elements covered by one pass of the whole thread cluster, per source dimension
    constexpr auto src_data_per_cluster_per_dims = transform_sequences(
        mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, SrcClusterLengths{});

    // number of passes needed per source dimension (rounded up)
    constexpr auto cluster_per_dims =
        transform_sequences(mod_conv::integer_divide_ceiler<index_t>{},
                            SrcLengths{},
                            src_data_per_cluster_per_dims);

    // per-thread clipboard extent: one sub-tensor for every pass
    constexpr auto thread_tensor_lengths = transform_sequences(
        mod_conv::multiplies<index_t>{}, thread_sub_tensor_lengths, cluster_per_dims);

    constexpr auto thread_tensor_desc = make_ConstantTensorDescriptor(thread_tensor_lengths);

    for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
    {
        for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)
        {
            for(index_t icluster_d2 = 0; icluster_d2 < cluster_per_dims.Get(I2); ++icluster_d2)
            {
                for(index_t icluster_d3 = 0; icluster_d3 < cluster_per_dims.Get(I3);
                    ++icluster_d3)
                {
                    // where this pass's sub-tensor starts inside the clipboard
                    const index_t clipboard_offset = thread_tensor_desc.Get1dIndex(
                        icluster_d0 * thread_sub_tensor_lengths.Get(I0),
                        icluster_d1 * thread_sub_tensor_lengths.Get(I1),
                        icluster_d2 * thread_sub_tensor_lengths.Get(I2),
                        icluster_d3 * thread_sub_tensor_lengths.Get(I3));

                    // origin of this pass in source coordinates, reordered into
                    // destination dimension order before computing the offset
                    const auto dst_multi_id = reorder_array_given_new2old(
                        Array<index_t, nDim>{
                            icluster_d0 * src_data_per_cluster_per_dims.Get(I0),
                            icluster_d1 * src_data_per_cluster_per_dims.Get(I1),
                            icluster_d2 * src_data_per_cluster_per_dims.Get(I2),
                            icluster_d3 * src_data_per_cluster_per_dims.Get(I3)},
                        MapDst2Src{});

                    const index_t dst_offset = DstDesc{}.Get1dIndex(dst_multi_id);

#if 0
                    if(get_block_1d_id() == 0)
                    {
                        printf("tid %5u, "
                               "clipboard_offset %5u, dst_offset %5u\n",
                               get_thread_local_1d_id(),
                               clipboard_offset,
                               dst_offset);
                    }
#endif

                    threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
                        thread_tensor_desc,
                        p_clipboard + clipboard_offset,
                        DstDesc{},
                        p_dst + dst_offset + mDstMyThreadOffset,
                        thread_sub_tensor_lengths,
                        MapDst2Src{});
                }
            }
        }
    }

#if 0
    if(get_block_1d_id() == 0)
    {
        printf("tid %5u, "
               "data: %f %f %f %f %f %f %f %f\n",
               get_thread_local_1d_id(),
               p_clipboard[0],
               p_clipboard[1],
               p_clipboard[2],
               p_clipboard[3],
               p_clipboard[4],
               p_clipboard[5],
               p_clipboard[6],
               p_clipboard[7]);
    }
#endif
}
|
||||
|
||||
// Copy with reorder in one call: stage the source data in a thread-local
// clipboard, then write the clipboard out to the destination.
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
{
    // size of the per-thread staging buffer, known at compile time
    constexpr auto clipboard_size = GetRegisterClipboardSize();

    Float staging[clipboard_size];

    RunLoadRegisterClipboard(p_src, staging);
    RunStoreRegisterClipboard(staging, p_dst);
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -25,12 +25,38 @@ struct is_same<T, T>
|
||||
static const bool value = true;
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr index_t integer_divide_ceil(index_t a, index_t b)
|
||||
namespace mod_conv { // namespace mod_conv
|
||||
template <class T>
|
||||
struct multiplies
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct integer_divide_ceiler
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T integer_divide_ceil(T a, T b)
|
||||
{
|
||||
static_assert(is_same<T, index_t>::value || is_same<T, int>::value, "wrong type");
|
||||
|
||||
return (a + b - 1) / b;
|
||||
}
|
||||
|
||||
namespace mod_conv { // namespace mod_conv
|
||||
template <class T>
|
||||
__host__ __device__ constexpr T max(T x, T y)
|
||||
{
|
||||
|
||||
@@ -70,18 +70,3 @@ __host__ __device__ constexpr auto unpacker(F f)
|
||||
return [=](auto xs_array){ f(xs...); };
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace mod_conv {
|
||||
template <class T>
|
||||
struct multiplies
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a * b; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct plus
|
||||
{
|
||||
__host__ __device__ constexpr T operator()(T a, T b) const { return a + b; }
|
||||
};
|
||||
|
||||
} // namespace mod_conv
|
||||
|
||||
@@ -248,42 +248,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
#if 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
{
|
||||
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
{
|
||||
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
|
||||
const index_t ho_thread =
|
||||
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
|
||||
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
|
||||
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
|
||||
const index_t wo_thread = b_thread / NPerBlock;
|
||||
const index_t n_thread = b_thread % NPerBlock;
|
||||
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
@@ -331,6 +296,5 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -0,0 +1,362 @@
|
||||
#pragma once
|
||||
#include "common.hip.hpp"
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "ConstantMatrixDescriptor.hip.hpp"
|
||||
#include "blockwise_2d_tensor_op.hip.hpp"
|
||||
#include "blockwise_3d_tensor_op.hip.hpp"
|
||||
#include "blockwise_4d_tensor_op.hip.hpp"
|
||||
#include "threadwise_nd_tensor_op.hip.hpp"
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_batched_gemm.hip.hpp"
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
|
||||
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
|
||||
|
||||
// divide block work: [K, Ho, Wo, N]
|
||||
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
|
||||
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
|
||||
"wrong! cannot evenly divide work for workgroup ");
|
||||
|
||||
constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
|
||||
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
|
||||
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
|
||||
|
||||
const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
|
||||
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
|
||||
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
|
||||
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
|
||||
const index_t w_block_work_id = itmp / NBlockWork;
|
||||
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
|
||||
|
||||
const index_t k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
|
||||
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
|
||||
const index_t n_block_data_begin = n_block_work_id * NPerBlock;
|
||||
|
||||
const index_t hi_block_data_begin = ho_block_data_begin;
|
||||
const index_t wi_block_data_begin = wo_block_data_begin;
|
||||
|
||||
// global tensor view
|
||||
constexpr auto wei_c_x_k_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<C, X, K>{}, Sequence<Y * X * K, K, 1>{});
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = mod_conv::max(
|
||||
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
|
||||
constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
|
||||
auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
|
||||
#if 0
|
||||
const auto blockwise_in_copy_reorder =
|
||||
Blockwise4dTensorCopyReorder1<BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
|
||||
decltype(map_chwn2nchw)>{};
|
||||
#else
|
||||
auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{};
|
||||
|
||||
const auto blockwise_in_copy_reorder =
|
||||
Blockwise4dTensorCopyReorder3<BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
|
||||
Sequence<4, 1, 1, 2>,
|
||||
Sequence<4, 8, 2, 2>,
|
||||
decltype(map_chwn2nchw),
|
||||
decltype(map_thread_cluster_2_src_cluster),
|
||||
2,
|
||||
4>{};
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize());
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
#if 0
|
||||
Blockwise3dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_x_k_global_desc),
|
||||
decltype(wei_c_x_k_block_desc),
|
||||
decltype(wei_c_x_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#else
|
||||
Blockwise3dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_x_k_global_desc),
|
||||
decltype(wei_c_x_k_block_desc),
|
||||
decltype(wei_c_x_k_block_desc.GetLengths()),
|
||||
Sequence<4, 1, 32>,
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<WoPerBlock * NPerBlock>{},
|
||||
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto c_k_wn_thread_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<KPerThread>{},
|
||||
Number<WoPerThread * NPerThread>{},
|
||||
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
|
||||
BlockSize,
|
||||
decltype(a_c_k_block_mtx_desc),
|
||||
decltype(b_c_wn_block_mtx_desc),
|
||||
decltype(c_k_wn_thread_mtx_desc),
|
||||
0,
|
||||
in_c_h_w_n_block_desc.GetStride(I1),
|
||||
out_k_h_w_n_thread_desc.GetStride(I1),
|
||||
HoPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
HoPerThread,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB>{};
|
||||
|
||||
// LDS: be careful of alignment
|
||||
constexpr index_t in_block_space =
|
||||
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
|
||||
constexpr index_t wei_block_space =
|
||||
wei_c_x_k_block_desc.GetElementSpace(Number<max_align>{});
|
||||
|
||||
__shared__ Float p_in_block[in_block_space];
|
||||
__shared__ Float p_wei_block[wei_block_space];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_c_x_k_block_desc, "wei_c_x_k_block_desc");
|
||||
|
||||
printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
|
||||
}
|
||||
#endif
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
|
||||
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
|
||||
p_in_block);
|
||||
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset +
|
||||
wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
|
||||
p_in_block +
|
||||
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
|
||||
p_out_thread);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
#if 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
{
|
||||
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
{
|
||||
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
|
||||
const index_t ho_thread =
|
||||
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
|
||||
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
|
||||
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
|
||||
const index_t wo_thread = b_thread / NPerBlock;
|
||||
const index_t n_thread = b_thread % NPerBlock;
|
||||
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
|
||||
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin =
|
||||
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
|
||||
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -29,26 +29,21 @@ __device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __re
|
||||
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class F>
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src, class F>
|
||||
__device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
@@ -78,19 +73,19 @@ __device__ void threadwise_2d_tensor_set_zero(Desc, Float* __restrict__ p)
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
|
||||
@@ -42,7 +42,7 @@ template <class SrcData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
@@ -50,7 +50,7 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
MapDst2Src,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
@@ -58,10 +58,10 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr index_t IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
constexpr index_t IR2 = DstFromSrcReorder{}.Get(I2);
|
||||
constexpr index_t IR3 = DstFromSrcReorder{}.Get(I3);
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
|
||||
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
@@ -82,7 +82,29 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
|
||||
const index_t bindex =
|
||||
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
#if 1
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
#else
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid %5u, "
|
||||
"src did %u %u %u %u, "
|
||||
"dst did %u %u %u %u, "
|
||||
"aindex %5u, "
|
||||
"bindex %5u\n",
|
||||
get_thread_local_1d_id(),
|
||||
did0,
|
||||
did1,
|
||||
did2,
|
||||
did3,
|
||||
did[IR0],
|
||||
did[IR1],
|
||||
did[IR2],
|
||||
did[IR3],
|
||||
aindex,
|
||||
bindex);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -103,19 +125,19 @@ template <class SrcData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
@@ -137,13 +159,12 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
|
||||
SrcOpLengths,
|
||||
Number<DataPerRead>)
|
||||
{
|
||||
using Float2 = float2;
|
||||
using Float4 = float4;
|
||||
|
||||
static_assert(SrcDesc{}.GetDimension() == 4 && DstDesc{}.GetDimension() == 4 &&
|
||||
SrcOpLengths::nDim == 4,
|
||||
SrcOpLengths::GetSize() == 4,
|
||||
"wrong! should be 4 dimension");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
@@ -183,24 +204,8 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
|
||||
const index_t dst_index =
|
||||
dst_desc.Get1dIndex(did0, did1, did2, iloop_d3 * DataPerRead);
|
||||
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float2*>(p_src + src_index));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
|
||||
*(reinterpret_cast<const Float4*>(p_src + src_index));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
*(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
|
||||
*(reinterpret_cast<const vector_t*>(&p_src[src_index]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ __device__ void threadwise_10d_tensor_copy(SrcDesc,
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
|
||||
SrcOpLengths::nDim == 10,
|
||||
SrcOpLengths::GetSize() == 10,
|
||||
"wrong! should be 10 dimension");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
|
||||
Reference in New Issue
Block a user