mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
#include "device_direct_convolution_1.cuh"
|
||||
#include "device_direct_convolution_2.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_kcsr.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_srck.cuh"
|
||||
#include "device_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
|
||||
#include "device_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
//#include "device_winograd_convolution.cuh"
|
||||
|
||||
@@ -418,8 +418,8 @@ int main()
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 0
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_1_nchw_srck.cuh"
|
||||
#include "gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw.cuh"
|
||||
#include <unistd.h>
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
|
||||
void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw)
|
||||
Tensor<T>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -101,6 +103,19 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned KPerThread = 4;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -113,40 +128,46 @@ void device_implicit_gemm_convolution_1_nchw_srck(InDesc,
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
cudaEvent_t start, stop;
|
||||
float elapsedTime;
|
||||
for(unsigned i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
cudaEvent_t start, stop;
|
||||
float elapsedTime;
|
||||
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
cudaEventCreate(&start);
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(in_nchw_desc,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
|
||||
out_nkhw_desc,
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(in_nchw_desc,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
static_cast<T*>(wei_srck_device_buf.GetDeviceBuffer()),
|
||||
out_nkhw_desc,
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
|
||||
cudaEventCreate(&stop);
|
||||
cudaEventRecord(stop, 0);
|
||||
cudaEventSynchronize(stop);
|
||||
cudaEventCreate(&stop);
|
||||
cudaEventRecord(stop, 0);
|
||||
cudaEventSynchronize(stop);
|
||||
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
|
||||
usleep(10);
|
||||
}
|
||||
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
|
||||
@@ -90,6 +90,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerBatch = 32;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
@@ -134,7 +136,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
CPerBlock,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread>
|
||||
CPerThread,
|
||||
BPerBatch>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
|
||||
@@ -22,7 +22,7 @@ template <unsigned GridSize,
|
||||
unsigned HoPerThread,
|
||||
unsigned WoPerThread>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck(InGlobalDesc,
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
WeiGlobalDesc,
|
||||
Float* const __restrict__ p_wei_global,
|
||||
@@ -19,7 +19,8 @@ template <unsigned GridSize,
|
||||
unsigned CPerBlock,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread>
|
||||
unsigned CPerThread,
|
||||
unsigned BPerBatch>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -111,15 +112,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
|
||||
|
||||
static_assert(BPerBlock % BPerBatch == 0 && BPerBatch % BPerThread == 0, "B cannot be evenly divided\n");
|
||||
|
||||
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
Number<BPerBlock>{},
|
||||
Number<BPerBatch>{},
|
||||
Number<in_cb_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
const auto blockwise_batched_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
@@ -128,9 +131,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
BPerBatch,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
BPerBlock/BPerBatch,
|
||||
1,
|
||||
CPerThread,
|
||||
true>{};
|
||||
@@ -179,7 +182,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_batched_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
@@ -189,10 +192,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
|
||||
// output: register to global mem,
|
||||
const auto matrix_c_index =
|
||||
blockwise_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
blockwise_batched_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id());
|
||||
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.col_begin;
|
||||
const unsigned b_thread_data_begin = matrix_c_index.batch_begin * BPerBatch + matrix_c_index.col_begin;
|
||||
|
||||
const unsigned k_data_begin = k_block_data_begin + k_thread_data_begin;
|
||||
const unsigned b_data_begin = b_block_data_begin + b_thread_data_begin;
|
||||
|
||||
Reference in New Issue
Block a user