mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
No raw index calculation (#31)
* Replace most raw index calculation to coordinate transformation * Overhaul blockwise and threadwise GEMM * Overhaul driver for gridwies GEMM kernel Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -46,6 +46,7 @@ void launch_kernel(F kernel,
|
||||
|
||||
template <typename... Args, typename F>
|
||||
float launch_and_time_kernel(F kernel,
|
||||
int nrepeat,
|
||||
dim3 grid_dim,
|
||||
dim3 block_dim,
|
||||
std::size_t lds_byte,
|
||||
@@ -54,15 +55,32 @@ float launch_and_time_kernel(F kernel,
|
||||
{
|
||||
KernelTimer timer;
|
||||
|
||||
printf("%s: block_dim {%d, %d, %d}, grid_dim {%d, %d, %d} \n",
|
||||
__func__,
|
||||
grid_dim.x,
|
||||
grid_dim.y,
|
||||
grid_dim.z,
|
||||
block_dim.x,
|
||||
block_dim.y,
|
||||
block_dim.z);
|
||||
|
||||
printf("Warm up\n");
|
||||
|
||||
// warm up
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
|
||||
printf("Start running %d times...\n", nrepeat);
|
||||
|
||||
timer.Start();
|
||||
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
for(int i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
hipGetLastError();
|
||||
|
||||
return timer.GetElapsedTime();
|
||||
return timer.GetElapsedTime() / nrepeat;
|
||||
}
|
||||
|
||||
#elif CK_DEVICE_BACKEND_NVIDIA
|
||||
|
||||
@@ -29,8 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw"
|
||||
<< std::endl;
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
|
||||
DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
|
||||
@@ -459,50 +468,94 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_pad
|
||||
#elif 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
|
||||
#elif 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_no_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw_1x1
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1>{};
|
||||
<GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
conv_driver.Run(wei_k_c_y_x_desc,
|
||||
in_n_c_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = launch_kernel_dynamic_gemm_v1<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmN,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
3,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmN1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_c_y_x_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)calculate_convolution_flops(
|
||||
in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
}
|
||||
|
||||
@@ -29,13 +29,17 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
std::cout << "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk"
|
||||
<< std::endl;
|
||||
std::cout << __func__ << std::endl;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto I6 = Number<6>{};
|
||||
constexpr auto I7 = Number<7>{};
|
||||
constexpr auto I8 = Number<8>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
@@ -53,7 +57,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// run-time variables
|
||||
constexpr auto in_n_hi_wi_c0_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
|
||||
@@ -112,7 +116,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
|
||||
out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
@@ -372,51 +376,92 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmM1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
constexpr index_t GemmM1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t GemmN1 = GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
|
||||
const auto descs =
|
||||
#if 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
#elif 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_no_pad
|
||||
#elif 1
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_pad
|
||||
#else
|
||||
transform_forward_convolution_into_gemm_v4r4_nhwc_kyxc_nhwk_1x1
|
||||
#endif
|
||||
<BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmM1>{};
|
||||
<GemmMPerBlock, GemmNPerBlock, GemmM1, GemmN1>(wei_k_y_x_c0_desc,
|
||||
in_n_hi_wi_c0_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads);
|
||||
|
||||
conv_driver.Run(wei_k_y_x_c0_desc,
|
||||
in_n_hi_wi_c0_desc,
|
||||
out_n_ho_wo_k_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
in_right_pads,
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()));
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
float ave_time = launch_kernel_dynamic_gemm_v1<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(descs[I0]),
|
||||
decltype(descs[I1]),
|
||||
decltype(descs[I2]),
|
||||
decltype(descs[I3]),
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockTransferThreadClusterLengths_GemmK_GemmM,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmABlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmABlockTransferDstScalarPerVector_GemmM,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
GemmBBlockTransferThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockTransferThreadClusterLengths_GemmK_GemmN,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
GemmBBlockTransferSrcScalarPerVector_GemmK,
|
||||
GemmBBlockTransferDstScalarPerVector_GemmN,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<2, 3, 0, 1>,
|
||||
1,
|
||||
GemmCThreadTransferDstScalarPerVector_GemmM1,
|
||||
decltype(descs[I4]),
|
||||
decltype(descs[I5]),
|
||||
decltype(descs[I6]),
|
||||
decltype(descs[I7]),
|
||||
decltype(descs[I8])>(static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
wei_k_y_x_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<typename vector_type<TInWei, InWeiVectorSize>::type*>(
|
||||
in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
|
||||
descs[I0],
|
||||
descs[I1],
|
||||
descs[I2],
|
||||
descs[I3],
|
||||
descs[I4],
|
||||
descs[I5],
|
||||
descs[I6],
|
||||
descs[I7],
|
||||
descs[I8],
|
||||
nrepeat);
|
||||
|
||||
float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
// copy result back to host
|
||||
out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
|
||||
Reference in New Issue
Block a user