mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
refactor
This commit is contained in:
@@ -354,10 +354,10 @@ int main()
|
||||
{
|
||||
#if 0
|
||||
constexpr unsigned N = 1;
|
||||
constexpr unsigned C = 2;
|
||||
constexpr unsigned C = 1;
|
||||
constexpr unsigned HI = 34;
|
||||
constexpr unsigned WI = 34;
|
||||
constexpr unsigned K = 2;
|
||||
constexpr unsigned K = 4;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 1
|
||||
@@ -418,7 +418,7 @@ int main()
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
|
||||
@@ -4,12 +4,12 @@
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcsr,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
unsigned nrepeat)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -104,7 +104,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
#elif 0
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
constexpr unsigned KPerBlock = 32;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -137,20 +137,20 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
cudaEventRecord(start, 0);
|
||||
|
||||
gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw<GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
BlockSize,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_srck_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread>
|
||||
<<<grid_dim, block_dim>>>(in_nchw_desc,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
@@ -165,10 +165,9 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
|
||||
usleep(10);
|
||||
usleep(10000);
|
||||
}
|
||||
|
||||
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#pragma once
|
||||
#include "gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw.cuh"
|
||||
#include <unistd.h>
|
||||
|
||||
template <class T, class InDesc, class WeiDesc, class OutDesc>
|
||||
void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
@@ -67,35 +68,29 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
|
||||
#if 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 1;
|
||||
constexpr unsigned KPerBlock = 4;
|
||||
constexpr unsigned CPerBlock = 1;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 2;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned ThreadPerClusterRow = 4;
|
||||
constexpr unsigned ThreadPerClusterColumn = 16;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerBatch = 32;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned ThreadPerClusterRow = 4;
|
||||
constexpr unsigned ThreadPerClusterColumn = 16;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#endif
|
||||
|
||||
@@ -137,7 +132,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
BPerBatch>
|
||||
ThreadPerClusterRow,
|
||||
ThreadPerClusterColumn>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_srck_desc,
|
||||
@@ -151,6 +147,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
|
||||
cudaEventElapsedTime(&elapsedTime, start, stop);
|
||||
printf("Elapsed time : %f ms\n", elapsedTime);
|
||||
|
||||
usleep(10000);
|
||||
}
|
||||
|
||||
checkCudaErrors(cudaGetLastError());
|
||||
|
||||
Reference in New Issue
Block a user