mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
improve implicit gemm NCHW, SRCK, NKHW, and tuned
This commit is contained in:
@@ -361,7 +361,7 @@ int main()
|
||||
constexpr unsigned K = 1;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 256;
|
||||
@@ -370,15 +370,6 @@ int main()
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
// 3x3, 54x54
|
||||
constexpr unsigned N = 64;
|
||||
constexpr unsigned C = 64;
|
||||
constexpr unsigned HI = 54;
|
||||
constexpr unsigned WI = 54;
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr unsigned N = 64;
|
||||
@@ -415,6 +406,15 @@ int main()
|
||||
constexpr unsigned K = 64;
|
||||
constexpr unsigned S = 7;
|
||||
constexpr unsigned R = 7;
|
||||
#elif 1
|
||||
// 3x3, 58x58
|
||||
constexpr unsigned N = 16;
|
||||
constexpr unsigned C = 128;
|
||||
constexpr unsigned HI = 58;
|
||||
constexpr unsigned WI = 58;
|
||||
constexpr unsigned K = 256;
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
#endif
|
||||
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence<N, C, HI, WI>{});
|
||||
@@ -449,7 +449,7 @@ int main()
|
||||
device_direct_convolution_2
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_kcsr
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_nchw_srck_nkhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_1_chwn_csrk_khwn
|
||||
|
||||
@@ -87,8 +87,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 8;
|
||||
#elif 1
|
||||
// for 3x3, 34x34 | 3x3 58x58
|
||||
#elif 0
|
||||
// for 3x3, 34x34 | 3x3 58x58, NKC = 64, 64, 256
|
||||
constexpr unsigned NPerBlock = 16;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
@@ -101,6 +101,21 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 3x3 58x58, NKC = 16,256,128
|
||||
constexpr unsigned NPerBlock = 8;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 4;
|
||||
constexpr unsigned WoPerBlock = 4;
|
||||
|
||||
constexpr unsigned NPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// for 5x5, 36x36
|
||||
|
||||
@@ -65,7 +65,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
constexpr unsigned WoPerThread = 2;
|
||||
|
||||
constexpr unsigned BlockSize = 16;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34
|
||||
constexpr unsigned NPerBlock = 1;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
@@ -73,6 +73,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
constexpr unsigned HoPerBlock = 4;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
|
||||
constexpr unsigned NPerThread = 1;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
@@ -80,16 +81,32 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34
|
||||
constexpr unsigned NPerBlock = 2;
|
||||
// for 3x3, 58x58
|
||||
constexpr unsigned NPerBlock = 4;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 32;
|
||||
constexpr unsigned HoPerBlock = 4;
|
||||
constexpr unsigned WoPerBlock = 8;
|
||||
|
||||
constexpr unsigned NPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 2;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// for 3x3, 56x56
|
||||
constexpr unsigned NPerBlock = 32;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
constexpr unsigned HoPerBlock = 2;
|
||||
constexpr unsigned WoPerBlock = 2;
|
||||
|
||||
constexpr unsigned NPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
constexpr unsigned HoPerThread = 1;
|
||||
constexpr unsigned WoPerThread = 1;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
@@ -123,6 +140,7 @@ void device_implicit_gemm_convolution_1_nchw_srck_nkhw(InDesc,
|
||||
CPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
NPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
HoPerThread,
|
||||
|
||||
@@ -17,6 +17,7 @@ template <unsigned GridSize,
|
||||
unsigned CPerBlock,
|
||||
unsigned HoPerBlock,
|
||||
unsigned WoPerBlock,
|
||||
unsigned NPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned HoPerThread,
|
||||
@@ -32,7 +33,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
// NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N]
|
||||
// for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N"
|
||||
// if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock
|
||||
constexpr unsigned NPerThread = NPerBlock;
|
||||
static_assert(NPerBlock % NPerThread == 0, "wrong! NPerBlock % NPerThread !=0");
|
||||
static_assert((NPerThread < NPerBlock && WoPerThread == 1) || NPerThread == NPerBlock,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -207,7 +210,9 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
|
||||
const unsigned ho_thread_data_begin = matrix_c_index.batch_begin;
|
||||
const unsigned k_thread_data_begin = matrix_c_index.row_begin;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerThread;
|
||||
const unsigned wo_thread_data_begin = matrix_c_index.col_begin / NPerBlock;
|
||||
const unsigned n_thread_data_begin =
|
||||
matrix_c_index.col_begin - wo_thread_data_begin * NPerBlock;
|
||||
|
||||
// output: register to global mem,
|
||||
// convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo]
|
||||
@@ -217,7 +222,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
|
||||
Reference in New Issue
Block a user