mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Added bwd data v3r1 v4r1, tweaking v1 (#10)
* Added bwd data v3r1: breaks the computation down into a series of load-balanced GEMMs and launches them all in a single kernel
* Added bwd data v4r1: like v3r1, but launches the GEMMs in multiple kernels
* Tweaked v1r1 and v1r2 (atomic) on AMD GPU
[ROCm/composable_kernel commit: c5da0377fb]
This commit is contained in:
@@ -30,33 +30,81 @@ struct KernelTimer
|
||||
std::unique_ptr<KernelTimerImpl> impl;
|
||||
};
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
using device_stream_t = hipStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
hipStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
#if CK_DEVICE_BACKEND_AMD
|
||||
timer.Start();
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, 0, args...);
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
|
||||
timer.End();
|
||||
|
||||
hipGetErrorString(hipGetLastError());
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
using device_stream_t = cudaStream_t;
|
||||
|
||||
template <typename... Args, typename F>
|
||||
void launch_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
}
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
cudaStream_t stream_id,
|
||||
Args... args)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
const void* f = reinterpret_cast<const void*>(kernel);
|
||||
void* p_args[] = {&args...};
|
||||
|
||||
timer.Start();
|
||||
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, 0);
|
||||
cudaError_t error = cudaLaunchKernel(f, grid_dim, block_dim, p_args, lds_byte, stream_id);
|
||||
|
||||
timer.End();
|
||||
|
||||
checkCudaErrors(error);
|
||||
#endif
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
@@ -88,7 +88,8 @@ void device_col2im_eb_nchw(ColDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_col2im),
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_col2im),
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -121,20 +125,18 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -145,3 +147,5 @@ void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -129,20 +133,18 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -153,3 +155,5 @@ void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -5,6 +5,10 @@
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
@@ -27,12 +31,16 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
@@ -81,6 +89,67 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
// for 1x1 weight, 8x8 input
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
@@ -92,14 +161,24 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
|
||||
constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t Htilda = Ho + right_pad_ho;
|
||||
constexpr index_t Wtilda = Wo + right_pad_wo;
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C * Ytilda * Xtilda;
|
||||
constexpr index_t GemmN = N * Htilda * Wtilda;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
@@ -142,20 +221,18 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer())),
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer())));
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
@@ -166,3 +243,5 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v3r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
constexpr auto gridwise_conv = GridwiseConvolutionBackwardDataImplicitGemm_v3r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -0,0 +1,232 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<8, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
|
||||
constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
|
||||
|
||||
constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
|
||||
constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
|
||||
|
||||
constexpr index_t Htilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t Wtilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (Ytilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WtildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (Xtilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HtildaRight = math::min(
|
||||
Htilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WtildaRight = math::min(
|
||||
Wtilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HtildaTrim = HtildaRight - HtildaLeft;
|
||||
constexpr index_t WtildaTrim = WtildaRight - WtildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HtildaTrim * WtildaTrim;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
timer.Start();
|
||||
|
||||
static_for<0, Ytilda, 1>{}([&](auto ytilda_) {
|
||||
static_for<0, Xtilda, 1>{}([&](auto xtilda_) {
|
||||
constexpr index_t ytilda = decltype(ytilda_){};
|
||||
constexpr index_t xtilda = decltype(xtilda_){};
|
||||
|
||||
constexpr auto gridwise_conv =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ytilda,
|
||||
xtilda,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmKPerThreadLoop,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>{};
|
||||
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
});
|
||||
});
|
||||
|
||||
timer.End();
|
||||
|
||||
float time = timer.GetElapsedTime();
|
||||
|
||||
printf("Elapsed time : %f ms, %f TFlop/s\n",
|
||||
time,
|
||||
(float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -82,13 +82,13 @@ void device_convolution_direct_v2_nchw_kcyx_nkhw(InDesc,
|
||||
WoPerThread,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>;
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
float time = launch_and_time_kernel(run_gridwise_convolution_kernel<gridwise_conv, T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
static_cast<T*>(in_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_device_buf.GetDeviceBuffer()));
|
||||
|
||||
printf("Elapsed time : %f ms\n", time);
|
||||
usleep(std::min(time * 1000, float(10000)));
|
||||
|
||||
@@ -458,7 +458,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -161,7 +161,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -354,7 +354,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
|
||||
WeiBlockCopyDataPerRead_K,
|
||||
OutThreadCopyDataPerWrite_W>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -306,7 +306,8 @@ void device_convolution_implicit_gemm_v2_chwn_cyxk_khwn(InDesc,
|
||||
WeiBlockCopyDataPerRead,
|
||||
OutThreadCopyDataPerWrite>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -135,7 +135,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
WeiBlockCopyDataPerAccess_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -54,7 +54,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// BlockSize = 256, EPerBlock = 8, each thread holds 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -128,7 +128,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// BlockSize = 64, each thread holds 64 data
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -258,13 +258,15 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_operation<decltype(gridwise_conv),
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
gridwise_conv,
|
||||
const_cast<const T* const __restrict__>(
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer())),
|
||||
|
||||
@@ -81,6 +81,43 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, EPerBlock = 16, each thread holds 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
@@ -247,10 +284,12 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -200,7 +200,8 @@ void device_convolution_implicit_gemm_v4r2_nchw_kcyx_nkhw(InDesc,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -158,7 +158,8 @@ void device_convolution_implicit_gemm_v4r3_nchw_kcyx_nkhw(InDesc,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -53,7 +53,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -83,6 +83,37 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t ThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<4, 2>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 2>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 0
|
||||
// BlockSize = 256, GemmKPerBlock = 8
|
||||
@@ -116,7 +147,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// BlockSize = 256, GemmKPerBlock = 16
|
||||
// 1x1 filter, 8x8 image
|
||||
constexpr index_t BlockSize = 256;
|
||||
@@ -225,10 +256,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
@@ -205,7 +205,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
float time =
|
||||
launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
|
||||
@@ -178,7 +178,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
float time = launch_kernel(
|
||||
float time = launch_and_time_kernel(
|
||||
gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw<TInWei,
|
||||
TOut,
|
||||
accum_t,
|
||||
|
||||
Reference in New Issue
Block a user