mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
add 2nd variation of implicit gemm
This commit is contained in:
@@ -8,8 +8,9 @@
|
||||
#include "conv_common.cuh"
|
||||
#include "device_direct_convolution_1.cuh"
|
||||
#include "device_direct_convolution_2.cuh"
|
||||
#include "device_implicit_gemm_convolution_nchw_kcsr.cuh"
|
||||
#include "device_implicit_gemm_convolution_nchw_srck.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_srck.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
//#include "device_winograd_convolution.cuh"
|
||||
|
||||
struct GeneratorTensor_1
|
||||
@@ -52,6 +53,21 @@ struct GeneratorTensor_3
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <class... Ts>
|
||||
double operator()(Ts... Xs) const
|
||||
{
|
||||
std::array<unsigned long, sizeof...(Ts)> dims = {{Xs...}};
|
||||
return std::accumulate(dims.begin(),
|
||||
dims.end(),
|
||||
true,
|
||||
[](bool init, unsigned long x) -> int { return init != (x % 2); })
|
||||
? 1
|
||||
: -1;
|
||||
}
|
||||
};
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class TConstTensorDesc>
|
||||
void ostream_ConstantTensorDescriptor(TConstTensorDesc, std::ostream& os = std::cout)
|
||||
@@ -337,13 +353,13 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
int main()
|
||||
{
|
||||
#if 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 2;
|
||||
constexpr unsigned HI = 34;
|
||||
constexpr unsigned WI = 34;
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
constexpr unsigned K = 2;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 1
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -389,31 +405,29 @@ int main()
|
||||
#if 0
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 0
|
||||
#elif 1
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#endif
|
||||
|
||||
for(int i = 0; i < 40; ++i)
|
||||
{
|
||||
#if 0
|
||||
device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
device_direct_convolution_2(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_nchw_kcsr(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_nchw_srck(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
device_winograd_convolution(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#endif
|
||||
}
|
||||
unsigned nrepeat = 100;
|
||||
|
||||
#if 0
|
||||
device_direct_convolution_1
|
||||
#elif 0
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 0
|
||||
device_winograd_convolution
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
|
||||
#if 1
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#elif 0
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_1_nchw_kcsr.cuh"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_nchw_kcsr(
|
||||
void device_implicit_gemm_convolution_1_nchw_kcsr(
|
||||
InDesc, const Tensor<T>& in, WeiDesc, const Tensor<T>& wei, OutDesc, Tensor<T>& out)
|
||||
{
|
||||
std::size_t data_sz = sizeof(T);
|
||||
@@ -81,21 +81,21 @@ void device_implicit_gemm_convolution_nchw_kcsr(
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_nchw_kcsr<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
gridwise_implicit_gemm_convolution_1_nchw_kcsr<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(InDesc{},
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
WeiDesc{},
|
||||
@@ -1,13 +1,13 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_nchw_srck(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw)
|
||||
void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -88,6 +88,19 @@ void device_implicit_gemm_convolution_nchw_srck(InDesc,
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -106,21 +119,21 @@ void device_implicit_gemm_convolution_nchw_srck(InDesc,
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_nchw_srck<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(in_nchw_desc,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
163
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
Normal file
163
driver/device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh
Normal file
@@ -0,0 +1,163 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc = InDesc{};
|
||||
constexpr auto wei_kcsr_desc = WeiDesc{};
|
||||
constexpr auto out_nkhw_desc = OutDesc{};
|
||||
|
||||
constexpr unsigned N = in_nchw_desc.GetLength(I0);
|
||||
constexpr unsigned Hi = in_nchw_desc.GetLength(I2);
|
||||
constexpr unsigned Wi = in_nchw_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr unsigned Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned K = wei_kcsr_desc.GetLength(I0);
|
||||
constexpr unsigned C = wei_kcsr_desc.GetLength(I1);
|
||||
constexpr unsigned S = wei_kcsr_desc.GetLength(I2);
|
||||
constexpr unsigned R = wei_kcsr_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
|
||||
|
||||
// convert in_nchw to in_cnhw
|
||||
auto in_cnhw_desc = make_ConstantTensorDescriptor(Sequence<C, N, Hi, Wi>{});
|
||||
ostream_ConstantTensorDescriptor(in_cnhw_desc, std::cout << "in_cnhw_desc: ");
|
||||
|
||||
Tensor<T> in_cnhw(make_TensorDescriptor(in_cnhw_desc));
|
||||
|
||||
auto f_reorder_nchw2cnhw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
in_cnhw(c, n, hi, wi) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_nchw2cnhw, N, C, Hi, Wi)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// convert wei_kcsr to wei_srck
|
||||
auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
|
||||
ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: ");
|
||||
|
||||
Tensor<T> wei_srck(make_TensorDescriptor(wei_srck_desc));
|
||||
|
||||
auto f_reorder_kcsr2srck = [&](auto k, auto c, auto s, auto r) {
|
||||
wei_srck(s, r, c, k) = wei_kcsr(k, c, s, r);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcsr2srck, K, C, S, R)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
// conver out_nkhw to out_knhw
|
||||
auto out_knhw_desc = make_ConstantTensorDescriptor(Sequence<K, N, Ho, Wo>{});
|
||||
ostream_ConstantTensorDescriptor(out_knhw_desc, std::cout << "out_knhw_desc: ");
|
||||
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 1;
|
||||
constexpr unsigned CPerBlock = 1;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 2;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 1
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
constexpr unsigned GridSize =
|
||||
((N * Hi * Wi + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
|
||||
dim3 block_dim(BlockSize);
|
||||
dim3 grid_dim(GridSize);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
// mem
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_cnhw_device_buf(data_sz * (in_cnhw.mDesc.GetElementSpace() + BGhostRead +
|
||||
BPerBlock)); // reserve extra space for BGhostRead
|
||||
DeviceMem wei_srck_device_buf(data_sz * wei_srck.mDesc.GetElementSpace());
|
||||
DeviceMem out_knhw_device_buf(data_sz * out_knhw.mDesc.GetElementSpace());
|
||||
|
||||
in_cnhw_device_buf.ToDevice(in_cnhw.mData.data());
|
||||
wei_srck_device_buf.ToDevice(wei_srck.mData.data());
|
||||
out_knhw_device_buf.ToDevice(out_knhw.mData.data());
|
||||
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
cudaEvent_t start, stop;
|
||||
float elapsedTime;
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_cnhw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_knhw_desc),
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
|
||||
out_knhw_desc,
|
||||
static_cast<T*>(out_knhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
cudaEventCreate(&stop);
|
||||
cudaEventRecord(stop, 0);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
out_knhw_device_buf.FromDevice(out_knhw.mData.data());
|
||||
|
||||
// convert out_knhw to out_nkhw
|
||||
auto f_reorder_knhw2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_nkhw(n, k, ho, wo) = out_knhw(k, n, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_knhw2nkhw, N, K, Ho, Wo)(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -46,3 +46,14 @@ __host__ __device__ constexpr auto
|
||||
{
|
||||
return ConstantMatrixDescriptor<NRow, NCol, RowStride>{};
|
||||
}
|
||||
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantMatrixDescriptor(TDesc, const char* s)
|
||||
{
|
||||
const auto desc = TDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
printf("%s NRow %u NCol %u RowStride %u\n", s, desc.NRow(), desc.NCol(), desc.RowStride());
|
||||
}
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <unsigned L0, unsigned L1>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
|
||||
{
|
||||
return Sequence<L1, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
@@ -49,28 +56,85 @@ struct ConstantTensorDescriptor
|
||||
// this is ugly, only for 4d
|
||||
__host__ __device__ constexpr unsigned GetElementSize() const
|
||||
{
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
static_assert(nDim >= 2 && nDim <= 4, "nDim");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
if(nDim == 2)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
|
||||
return GetLength(I0) * GetLength(I1);
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2);
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3);
|
||||
}
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
__host__ __device__ constexpr unsigned GetElementSpace() const
|
||||
{
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
static_assert(nDim >= 2 && nDim <= 4, "nDim");
|
||||
|
||||
if(nDim == 2)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + 1;
|
||||
}
|
||||
else if(nDim == 3)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + 1;
|
||||
}
|
||||
else if(nDim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(nDim == 2, "nDim is not 2");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1);
|
||||
}
|
||||
|
||||
// this is ugly, only for 3d
|
||||
__host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) +
|
||||
(GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + 1;
|
||||
static_assert(nDim == 3, "nDim is not 3");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2);
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
@@ -106,28 +170,44 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
{
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr auto desc = TDesc{};
|
||||
constexpr unsigned ndim = desc.GetDimension();
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
static_assert(ndim >= 2 && ndim <= 4, "wrong!");
|
||||
|
||||
static_assert(desc.GetDimension() == 4, "dim is not 4");
|
||||
if(ndim == 2)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1));
|
||||
}
|
||||
else if(ndim == 4)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
desc.GetDimension(),
|
||||
desc.GetLength(I0),
|
||||
desc.GetLength(I1),
|
||||
desc.GetLength(I2),
|
||||
desc.GetLength(I3),
|
||||
desc.GetStride(I0),
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
}
|
||||
|
||||
172
src/include/blockwise_2d_tensor_op.cuh
Normal file
172
src/include/blockwise_2d_tensor_op.cuh
Normal file
@@ -0,0 +1,172 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
|
||||
template <unsigned BlockSize, class Float, class DstDesc, class F>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_pointwise_operation_unary(DstDesc, Float* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
constexpr auto desc = make_ConstantTensorDescriptor(dst_desc.GetLengths());
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op_unary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc, "blockwise_4d_tensor_op_unary: desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr unsigned NLoop = desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
unsigned is = threadIdx.x + iloop * BlockSize;
|
||||
|
||||
const unsigned did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const unsigned did1 = is / desc.GetStride(I1);
|
||||
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
unsigned is = threadIdx.x + NLoop * BlockSize;
|
||||
|
||||
if(is < desc.GetElementSize())
|
||||
{
|
||||
const unsigned did0 = is / desc.GetStride(I0);
|
||||
|
||||
is -= did0 * desc.GetStride(I0);
|
||||
|
||||
const unsigned did1 = is / desc.GetStride(I1);
|
||||
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Function: p_dst[reorder[i0], reorder[i1], reorder[i2], reorder[i3]] = p_src[i0,i1,i2,i3]
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class F>
|
||||
__device__ void blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
for(unsigned iloop = 0; iloop < NLoop; ++iloop)
|
||||
{
|
||||
unsigned is = threadIdx.x + iloop * BlockSize;
|
||||
|
||||
unsigned did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
|
||||
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
|
||||
constexpr bool has_tail = (ref_desc.GetElementSize() > NLoop * BlockSize);
|
||||
|
||||
if(has_tail)
|
||||
{
|
||||
unsigned is = threadIdx.x + NLoop * BlockSize;
|
||||
|
||||
if(is < ref_desc.GetElementSize())
|
||||
{
|
||||
unsigned did[2];
|
||||
|
||||
did[0] = is / ref_desc.GetStride(I0);
|
||||
|
||||
is -= did[0] * ref_desc.GetStride(I0);
|
||||
|
||||
did[1] = is / ref_desc.GetStride(I1);
|
||||
|
||||
const unsigned aindex = src_desc.Get1dIndex(did[0], did[1]);
|
||||
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class DstDesc>
|
||||
__device__ void blockwise_2d_tensor_set_zero(DstDesc, Float* __restrict__ p_dst)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_unary<BlockSize>(DstDesc{}, p_dst, f_set_zero);
|
||||
}
|
||||
|
||||
template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder>
|
||||
__device__ void
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
blockwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void blockwise_2d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
|
||||
|
||||
blockwise_2d_tensor_copy_reorder_by_get_dst_from_src<BlockSize>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
#include "threadwise_4d_tensor_op.cuh"
|
||||
#include "threadwise_direct_convolution.cuh"
|
||||
|
||||
template <unsigned BlockSize,
|
||||
|
||||
@@ -114,14 +114,21 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
|
||||
: b_block_mtx.Get1dIndex(c_thread_mtx_index.col_begin, 0));
|
||||
|
||||
#if 0
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch_begin,
|
||||
c_thread_mtx_index.row_begin,
|
||||
c_thread_mtx_index.col_begin,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantMatrixDescriptor(BlockMatrixA{}, "a_block_mtx: ");
|
||||
print_ConstantMatrixDescriptor(BlockMatrixB{}, "b_block_mtx: ");
|
||||
print_ConstantMatrixDescriptor(ThreadMatrixC{}, "c_thread_mtx: ");
|
||||
|
||||
printf("%u %u, %u %u %u, %u %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
c_thread_mtx_index.batch_begin,
|
||||
c_thread_mtx_index.row_begin,
|
||||
c_thread_mtx_index.col_begin,
|
||||
mMyThreadOffsetA,
|
||||
mMyThreadOffsetB);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_direct_convolution.cuh"
|
||||
|
||||
template <class Float,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_direct_convolution.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
#include "threadwise_4d_tensor_op.cuh"
|
||||
#include "threadwise_direct_convolution.cuh"
|
||||
|
||||
template <class Float,
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
#include "common.cuh"
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "threadwise_4d_tensor_op.cuh"
|
||||
#include "gemm.cuh"
|
||||
|
||||
template <unsigned GridSize,
|
||||
@@ -21,12 +21,13 @@ template <unsigned GridSize,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
__global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
{
|
||||
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
|
||||
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
|
||||
@@ -2,8 +2,8 @@
|
||||
#include "common.cuh"
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "threadwise_4d_tensor_op.cuh"
|
||||
#include "gemm.cuh"
|
||||
|
||||
template <unsigned GridSize,
|
||||
@@ -21,12 +21,13 @@ template <unsigned GridSize,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
__global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
{
|
||||
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
|
||||
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
|
||||
@@ -0,0 +1,264 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
#include "ConstantMatrixDescriptor.cuh"
|
||||
#include "blockwise_4d_tensor_op.cuh"
|
||||
#include "blockwise_2d_tensor_op.cuh"
|
||||
#include "threadwise_2d_tensor_op.cuh"
|
||||
#include "gemm.cuh"
|
||||
|
||||
// define B = N*Hi*Wi
|
||||
template <unsigned GridSize,
|
||||
unsigned BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
unsigned BPerBlock,
|
||||
unsigned KPerBlock,
|
||||
unsigned CPerBlock,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
OutGlobalDesc,
|
||||
Float* __restrict__ p_out_global)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_cnhw_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_srck_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_knhw_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr unsigned C = in_cnhw_global_desc.GetLength(I0);
|
||||
constexpr unsigned N = in_cnhw_global_desc.GetLength(I1);
|
||||
constexpr unsigned Hi = in_cnhw_global_desc.GetLength(I2);
|
||||
constexpr unsigned Wi = in_cnhw_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned K = out_knhw_global_desc.GetLength(I0);
|
||||
constexpr unsigned Ho = out_knhw_global_desc.GetLength(I2);
|
||||
constexpr unsigned Wo = out_knhw_global_desc.GetLength(I3);
|
||||
|
||||
constexpr unsigned S = wei_srck_global_desc.GetLength(I0);
|
||||
constexpr unsigned R = wei_srck_global_desc.GetLength(I1);
|
||||
|
||||
constexpr unsigned B = N * Hi * Wi;
|
||||
constexpr unsigned BGhostRead = (S - 1) * Wi + (R - 1);
|
||||
|
||||
// divide block work by 2d: [K, B]
|
||||
constexpr unsigned KBlockWork = (K + KPerBlock - 1) / KPerBlock;
|
||||
constexpr unsigned BBlockWork = (B + BPerBlock - 1) / BPerBlock;
|
||||
|
||||
const unsigned k_block_work_id = get_block_1d_id() / BBlockWork;
|
||||
const unsigned b_block_work_id = get_block_1d_id() - k_block_work_id * BBlockWork;
|
||||
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead);
|
||||
|
||||
printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
KBlockWork,
|
||||
BBlockWork,
|
||||
k_block_data_begin,
|
||||
b_block_data_begin);
|
||||
}
|
||||
#endif
|
||||
|
||||
// flattend (2d) tensor view of gridwise input
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
constexpr auto in_cb_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, BPerBlock + BGhostRead>{});
|
||||
|
||||
constexpr auto wei_srck_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<S, R, CPerBlock, KPerBlock>{});
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<KPerThread, BPerThread>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc");
|
||||
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
|
||||
|
||||
printf("KPerBlock %u\n", KPerBlock);
|
||||
}
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<BPerBlock>{},
|
||||
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
|
||||
constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace();
|
||||
|
||||
__shared__ Float p_in_block[in_block_size];
|
||||
__shared__ Float p_wei_block[wei_block_size];
|
||||
|
||||
// register
|
||||
Float p_out_thread[out_kb_thread_desc.GetElementSpace()];
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_2d_tensor_set_zero(out_kb_thread_desc, p_out_thread);
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// input: global mem to LDS,
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
blockwise_2d_tensor_copy<BlockSize>(
|
||||
in_cb_global_desc,
|
||||
p_in_global + in_cb_global_desc.Get1dIndex(c_block_data_begin, b_block_data_begin),
|
||||
in_cb_block_desc,
|
||||
p_in_block,
|
||||
in_cb_block_desc.GetLengths());
|
||||
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,CPerBlock,KPerBlock]
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
wei_srck_global_desc,
|
||||
p_wei_global +
|
||||
wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin),
|
||||
wei_srck_block_desc,
|
||||
p_wei_block,
|
||||
wei_srck_block_desc.GetLengths());
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// a series of GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
{
|
||||
for(unsigned r = 0; r < R; ++r)
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
|
||||
#if 0
|
||||
//if(get_block_1d_id() == 10)
|
||||
{
|
||||
printf("%u %u, batch_begin %u row_begin %u col_begin %u, k_data_begin %u b_data_begin %u, %f %f %f %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
matrix_c_index.batch_begin,
|
||||
matrix_c_index.row_begin,
|
||||
matrix_c_index.col_begin,
|
||||
k_data_begin,
|
||||
b_data_begin,
|
||||
p_out_thread[0], p_out_thread[1], p_out_thread[2], p_out_thread[3]);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned k = 0; k < out_kb_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(unsigned b = 0; b < out_kb_thread_desc.GetLength(I1); ++b)
|
||||
{
|
||||
unsigned k_data = k_data_begin + k;
|
||||
unsigned b_data = b_data_begin + b;
|
||||
|
||||
unsigned n_data = b_data / (Hi * Wi);
|
||||
unsigned itmp = b_data - n_data * (Hi * Wi);
|
||||
unsigned h_data = itmp / Wi;
|
||||
unsigned w_data = itmp - h_data * Wi;
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 10)
|
||||
{
|
||||
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
k,
|
||||
b,
|
||||
k_data,
|
||||
n_data,
|
||||
h_data,
|
||||
w_data,
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
|
||||
}
|
||||
#endif
|
||||
if(k_data < K && n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("%u %u, k %u b %u, k_data %u n_data %u h_data %u w_data %u %f\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
k,
|
||||
b,
|
||||
k_data,
|
||||
n_data,
|
||||
h_data,
|
||||
w_data,
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
140
src/include/threadwise_2d_tensor_op.cuh
Normal file
140
src/include/threadwise_2d_tensor_op.cuh
Normal file
@@ -0,0 +1,140 @@
|
||||
#pragma once
|
||||
#include "ConstantTensorDescriptor.cuh"
|
||||
|
||||
template <class Float, class Desc, class F>
|
||||
__device__ void threadwise_2d_tensor_pointwise_operation_unary(Desc, Float* __restrict__ p, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
|
||||
{
|
||||
const unsigned dindex = desc.Get1dIndex(did0, did1);
|
||||
|
||||
f(p[dindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class DstFromSrcReorder,
|
||||
class F>
|
||||
__device__ void threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0);
|
||||
constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});
|
||||
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
const unsigned aindex = src_desc.Get1dIndex(did0, did1);
|
||||
|
||||
const unsigned did[2] = {did0, did1};
|
||||
|
||||
const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float, class Desc>
|
||||
__device__ void threadwise_2d_tensor_set_zero(Desc, Float* __restrict__ p)
|
||||
{
|
||||
auto f_set_zero = [](Float& v) { v = Float(0); };
|
||||
|
||||
threadwise_2d_tensor_pointwise_operation_unary<Float, Desc, decltype(f_set_zero)>(
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, class DstFromSrcReorder>
|
||||
__device__ void
|
||||
threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
Float* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
Float* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
DstFromSrcReorder)
|
||||
{
|
||||
auto f_copy = [](const Float& src, Float& dst) { dst = src; };
|
||||
|
||||
threadwise_2d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
__device__ void threadwise_2d_tensor_copy(
|
||||
SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
|
||||
{
|
||||
auto dst_from_src_reorder = Sequence<0, 1>{};
|
||||
|
||||
threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_shift_down: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr unsigned nshift = NShift::mValue;
|
||||
|
||||
constexpr unsigned did0_end =
|
||||
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
|
||||
|
||||
constexpr unsigned did1_end =
|
||||
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
|
||||
|
||||
for(unsigned did0 = 0; did0 < did0_end; ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < did1_end; ++did1)
|
||||
{
|
||||
const unsigned dindex = desc.Get1dIndex(did0, did1);
|
||||
|
||||
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
|
||||
|
||||
p[dindex] = p[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user