mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
adding implicit gemm
This commit is contained in:
154
driver/conv.cu
154
driver/conv.cu
@@ -85,19 +85,19 @@ auto make_TensorDescriptor(TConstTensorDesc)
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void host_direct_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& out)
|
||||
void host_direct_convolution(const Tensor<T>& in_nchw, const Tensor<T>& wei_kcsr, Tensor<T>& out)
|
||||
{
|
||||
auto f = [&](auto n, auto k, auto ho, auto wo) {
|
||||
double v = 0;
|
||||
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
||||
for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c)
|
||||
{
|
||||
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
||||
for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho + y;
|
||||
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
||||
for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo + x;
|
||||
v += in(n, c, hi, wi) * wei(k, c, y, x);
|
||||
v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -114,19 +114,21 @@ void host_direct_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void host_winograd_3x3_convolution(const Tensor<T>& in, const Tensor<T>& wei, Tensor<T>& out)
|
||||
void host_winograd_3x3_convolution(const Tensor<T>& in_nchw,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
Tensor<T>& out)
|
||||
{
|
||||
constexpr std::size_t OutTileSizeH = 2;
|
||||
constexpr std::size_t OutTileSizeW = 2;
|
||||
|
||||
std::size_t N = in.mDesc.GetLengths()[0];
|
||||
std::size_t C = in.mDesc.GetLengths()[1];
|
||||
std::size_t HI = in.mDesc.GetLengths()[2];
|
||||
std::size_t WI = in.mDesc.GetLengths()[3];
|
||||
std::size_t N = in_nchw.mDesc.GetLengths()[0];
|
||||
std::size_t C = in_nchw.mDesc.GetLengths()[1];
|
||||
std::size_t HI = in_nchw.mDesc.GetLengths()[2];
|
||||
std::size_t WI = in_nchw.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t K = wei.mDesc.GetLengths()[0];
|
||||
std::size_t S = wei.mDesc.GetLengths()[2];
|
||||
std::size_t R = wei.mDesc.GetLengths()[3];
|
||||
std::size_t K = wei_kcsr.mDesc.GetLengths()[0];
|
||||
std::size_t S = wei_kcsr.mDesc.GetLengths()[2];
|
||||
std::size_t R = wei_kcsr.mDesc.GetLengths()[3];
|
||||
|
||||
std::size_t HO = out.mDesc.GetLengths()[2];
|
||||
std::size_t WO = out.mDesc.GetLengths()[3];
|
||||
@@ -150,7 +152,7 @@ void host_winograd_3x3_convolution(const Tensor<T>& in, const Tensor<T>& wei, Te
|
||||
for(int i = 0; i < InTileSizeW; ++i)
|
||||
{
|
||||
std::size_t wi = OutTileSizeW * x + i;
|
||||
in_hold(n, c, y, x, j, i) = in(n, c, hi, wi);
|
||||
in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -194,45 +196,49 @@ void host_winograd_3x3_convolution(const Tensor<T>& in, const Tensor<T>& wei, Te
|
||||
};
|
||||
|
||||
auto f_wei_transform = [&](auto k, auto c) {
|
||||
wei_transform(k, c, 0, 0) = wei(k, c, 0, 0);
|
||||
wei_transform(k, c, 0, 0) = wei_kcsr(k, c, 0, 0);
|
||||
wei_transform(k, c, 0, 1) =
|
||||
0.5 * wei(k, c, 0, 0) + 0.5 * wei(k, c, 0, 1) + 0.5 * wei(k, c, 0, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2);
|
||||
wei_transform(k, c, 0, 2) =
|
||||
0.5 * wei(k, c, 0, 0) - 0.5 * wei(k, c, 0, 1) + 0.5 * wei(k, c, 0, 2);
|
||||
wei_transform(k, c, 0, 3) = wei(k, c, 0, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2);
|
||||
wei_transform(k, c, 0, 3) = wei_kcsr(k, c, 0, 2);
|
||||
|
||||
wei_transform(k, c, 1, 0) =
|
||||
0.5 * wei(k, c, 0, 0) + 0.5 * wei(k, c, 1, 0) + 0.5 * wei(k, c, 2, 0);
|
||||
wei_transform(k, c, 1, 1) =
|
||||
0.25 * wei(k, c, 0, 0) + 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) +
|
||||
0.25 * wei(k, c, 1, 0) + 0.25 * wei(k, c, 1, 1) + 0.25 * wei(k, c, 1, 2) +
|
||||
0.25 * wei(k, c, 2, 0) + 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2);
|
||||
wei_transform(k, c, 1, 2) =
|
||||
0.25 * wei(k, c, 0, 0) - 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) +
|
||||
0.25 * wei(k, c, 1, 0) - 0.25 * wei(k, c, 1, 1) + 0.25 * wei(k, c, 1, 2) +
|
||||
0.25 * wei(k, c, 2, 0) - 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0);
|
||||
wei_transform(k, c, 1, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) +
|
||||
0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) +
|
||||
0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) +
|
||||
0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) +
|
||||
0.25 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 1, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) +
|
||||
0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) -
|
||||
0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) +
|
||||
0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) +
|
||||
0.25 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 1, 3) =
|
||||
0.5 * wei(k, c, 0, 2) + 0.5 * wei(k, c, 1, 2) + 0.5 * wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 2) + 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2);
|
||||
|
||||
wei_transform(k, c, 2, 0) =
|
||||
0.5 * wei(k, c, 0, 0) - 0.5 * wei(k, c, 1, 0) + 0.5 * wei(k, c, 2, 0);
|
||||
wei_transform(k, c, 2, 1) =
|
||||
0.25 * wei(k, c, 0, 0) + 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) -
|
||||
0.25 * wei(k, c, 1, 0) - 0.25 * wei(k, c, 1, 1) - 0.25 * wei(k, c, 1, 2) +
|
||||
0.25 * wei(k, c, 2, 0) + 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2);
|
||||
wei_transform(k, c, 2, 2) =
|
||||
0.25 * wei(k, c, 0, 0) - 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) -
|
||||
0.25 * wei(k, c, 1, 0) + 0.25 * wei(k, c, 1, 1) - 0.25 * wei(k, c, 1, 2) +
|
||||
0.25 * wei(k, c, 2, 0) - 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0);
|
||||
wei_transform(k, c, 2, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) +
|
||||
0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) -
|
||||
0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) +
|
||||
0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) +
|
||||
0.25 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 2, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) +
|
||||
0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) +
|
||||
0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) +
|
||||
0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) +
|
||||
0.25 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 2, 3) =
|
||||
0.5 * wei(k, c, 0, 2) - 0.5 * wei(k, c, 1, 2) + 0.5 * wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 0, 2) - 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2);
|
||||
|
||||
wei_transform(k, c, 3, 0) = wei(k, c, 2, 0);
|
||||
wei_transform(k, c, 3, 0) = wei_kcsr(k, c, 2, 0);
|
||||
wei_transform(k, c, 3, 1) =
|
||||
0.5 * wei(k, c, 2, 0) + 0.5 * wei(k, c, 2, 1) + 0.5 * wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 2, 0) + 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 3, 2) =
|
||||
0.5 * wei(k, c, 2, 0) - 0.5 * wei(k, c, 2, 1) + 0.5 * wei(k, c, 2, 2);
|
||||
wei_transform(k, c, 3, 3) = wei(k, c, 2, 2);
|
||||
0.5 * wei_kcsr(k, c, 2, 0) - 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2);
|
||||
wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2);
|
||||
};
|
||||
|
||||
auto f_out_transform = [&](auto n, auto k, auto y, auto x) {
|
||||
@@ -366,54 +372,66 @@ int main()
|
||||
constexpr unsigned R = 3;
|
||||
#endif
|
||||
|
||||
auto in_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
|
||||
auto wei_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
|
||||
auto out_desc = get_convolution_output_default_4d_tensor_descriptor(in_desc, wei_desc);
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
|
||||
auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence<K, C, S, R>{});
|
||||
auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence<S, R, C, K>{});
|
||||
auto out_nkhw_desc =
|
||||
get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcsr_desc);
|
||||
|
||||
ostream_ConstantTensorDescriptor(in_desc, std::cout << "in_desc: ");
|
||||
ostream_ConstantTensorDescriptor(wei_desc, std::cout << "wei_desc: ");
|
||||
ostream_ConstantTensorDescriptor(out_desc, std::cout << "out_desc: ");
|
||||
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: ");
|
||||
ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: ");
|
||||
ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");
|
||||
|
||||
Tensor<float> in(make_TensorDescriptor(in_desc));
|
||||
Tensor<float> wei(make_TensorDescriptor(wei_desc));
|
||||
Tensor<float> out_host(make_TensorDescriptor(out_desc));
|
||||
Tensor<float> out_device(make_TensorDescriptor(out_desc));
|
||||
Tensor<float> in_nchw(make_TensorDescriptor(in_nchw_desc));
|
||||
Tensor<float> wei_kcsr(make_TensorDescriptor(wei_kcsr_desc));
|
||||
Tensor<float> wei_srck(make_TensorDescriptor(wei_srck_desc));
|
||||
Tensor<float> out_nkhw_host(make_TensorDescriptor(out_nkhw_desc));
|
||||
Tensor<float> out_nkhw_device(make_TensorDescriptor(out_nkhw_desc));
|
||||
|
||||
#if 0
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
#elif 1
|
||||
std::size_t num_thread = std::thread::hardware_concurrency();
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei_srck.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
#endif
|
||||
|
||||
for(int i = 0; i < 40; ++i)
|
||||
{
|
||||
#if 0
|
||||
device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device);
|
||||
device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
device_direct_convolution_2(in_desc, in, wei_desc, wei, out_desc, out_device);
|
||||
device_direct_convolution_2(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution(in_desc, in, wei_desc, wei, out_desc, out_device);
|
||||
device_implicit_gemm_convolution(
|
||||
in_nchw_desc, in_nchw, wei_srck_desc, wei_srck, out_nkhw_desc, out_nkhw_device);
|
||||
#elif 0
|
||||
device_winograd_convolution(in_desc, in, wei_desc, wei, out_desc, out_device);
|
||||
device_winograd_convolution(
|
||||
in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if 1
|
||||
host_winograd_3x3_convolution(in, wei, out_host);
|
||||
check_error(out_host, out_device);
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host);
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#elif 0
|
||||
host_direct_convolution(in, wei, out_host);
|
||||
check_error(out_host, out_device);
|
||||
host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host);
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
LogRange(std::cout << "in : ", in.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl;
|
||||
LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh"
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution(
|
||||
@@ -25,7 +26,7 @@ void device_implicit_gemm_convolution(
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -39,6 +40,20 @@ void device_implicit_gemm_convolution(
|
||||
constexpr unsigned WoPerThread = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 256;
|
||||
#elif 1
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 2;
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
constexpr unsigned GridSize =
|
||||
@@ -56,27 +71,31 @@ void device_implicit_gemm_convolution(
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_nchw_kcsr<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(InDesc{},
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
WeiDesc{},
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
OutDesc{},
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
#if 0
|
||||
gridwise_implicit_gemm_convolution_nchw_kcsr
|
||||
#elif 1
|
||||
gridwise_implicit_gemm_convolution_nchw_srck
|
||||
#endif
|
||||
<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
InDesc,
|
||||
WeiDesc,
|
||||
OutDesc,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread><<<grid_dim, block_dim>>>(InDesc{},
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
WeiDesc{},
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
OutDesc{},
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
|
||||
cudaEventCreate(&stop);
|
||||
cudaEventRecord(stop, 0);
|
||||
|
||||
Reference in New Issue
Block a user