mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
tuning on vega 20
This commit is contained in:
@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r3, Pascal
|
||||
// for 3x3, 28x28, v1r3, Pascal
|
||||
// for 3x3, 14x14, v1r3, Pascal
|
||||
@@ -206,6 +206,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 1;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r1, Vega 20
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
@@ -227,16 +229,43 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 2;
|
||||
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 8>;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
|
||||
|
||||
#elif 1
|
||||
// for 3x3, 34x34, v1r3, Vega 20
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 4, 4>;
|
||||
constexpr index_t InBlockCopyDataPerRead_N = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead_K = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 4;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
@@ -448,7 +477,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
#endif
|
||||
<GridSize,
|
||||
|
||||
@@ -182,7 +182,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
constexpr auto gridwise_conv =
|
||||
#if 0
|
||||
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
#elif 0
|
||||
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
|
||||
@@ -52,7 +52,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
in_nchw_vec(n, c, h, w) =
|
||||
vector_t::Pack(in_nchw(n, 2 * c, h, w), in_nchw(n, 2 * c + 1, h, w));
|
||||
#elif 1
|
||||
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
|
||||
in_nchw_vec(n, c, h, w) = vector_t::Pack(in_nchw(n, 4 * c, h, w),
|
||||
in_nchw(n, 4 * c + 1, h, w),
|
||||
in_nchw(n, 4 * c + 2, h, w),
|
||||
in_nchw(n, 4 * c + 3, h, w));
|
||||
@@ -114,37 +114,37 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, fp32, vector = 2
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 2;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 4;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t NPerThread = 2;
|
||||
constexpr index_t KPerThread = 4;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// 3x3, 34x34, 128 thread, int8, vector = 4
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t NPerBlock = 2;
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
|
||||
constexpr index_t NPerThread = 1;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t CPerThread = 2;
|
||||
constexpr index_t NPerThread = 1;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t CPerThread = 2;
|
||||
constexpr index_t HoPerThread = 4;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
@@ -371,7 +371,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
|
||||
std::size_t ho = HoPerTile * htile + j;
|
||||
for(int i = 0; i < WoPerTile; ++i)
|
||||
{
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
std::size_t wo = WoPerTile * wtile + i;
|
||||
out_nkhw(n, k, ho, wo) = out_hold(n, k, htile, wtile, j, i);
|
||||
}
|
||||
}
|
||||
@@ -425,13 +425,13 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 56x56
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 56;
|
||||
constexpr index_t WI = 56;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
@@ -603,9 +603,9 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
|
||||
|
||||
@@ -116,11 +116,7 @@ struct ConstantTensorDescriptor
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr index_t idim = IDim.Get();
|
||||
#if DEVICE_BACKEND_HIP
|
||||
id += __mul24(multi_id[idim], GetStride(IDim));
|
||||
#else
|
||||
id += multi_id[idim] * GetStride(IDim);
|
||||
#endif
|
||||
});
|
||||
|
||||
return id;
|
||||
|
||||
@@ -213,7 +213,6 @@ struct Blockwise3dTensorCopy3
|
||||
#pragma unroll
|
||||
for(index_t iloop_d2 = 0; iloop_d2 < nloop_d2; ++iloop_d2)
|
||||
{
|
||||
#pragma unroll
|
||||
const index_t src_offset =
|
||||
SrcDesc{}.Get1dIndex(iloop_d0 * thread_per_d0,
|
||||
iloop_d1 * thread_per_d1,
|
||||
|
||||
@@ -341,10 +341,11 @@ struct BlockwiseChwnTensorCopyPadded
|
||||
constexpr index_t NLoop = ref_desc.GetElementSize() / BlockSize;
|
||||
|
||||
const Float* p_src_tmp =
|
||||
p_src + src_desc.Get1dIndex(c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
p_src +
|
||||
src_desc.Get1dIndex(c_block_data_begin,
|
||||
(ho_block_data_begin + h_block_pad_low) - h_global_pad_low,
|
||||
(wo_block_data_begin + w_block_pad_low) - w_global_pad_low,
|
||||
n_block_data_begin);
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
|
||||
@@ -404,8 +404,9 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
c_thread_sub_mtx,
|
||||
p_c_thread + c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
p_c_thread +
|
||||
c_thread_sub_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
n_repeat * NPerLevel1Cluster),
|
||||
c_block_mtx,
|
||||
p_c_block +
|
||||
c_block_mtx.Get1dIndex(m_repeat * MPerLevel1Cluster,
|
||||
|
||||
@@ -93,10 +93,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
Float p_out_thread[out_thread_desc.GetElementSpace()];
|
||||
|
||||
threadwise_4d_tensor_copy(out_block_desc,
|
||||
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
p_out_block +
|
||||
out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
out_thread_desc,
|
||||
p_out_thread,
|
||||
out_thread_desc.GetLengths());
|
||||
@@ -107,10 +108,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution_2(
|
||||
in_thread_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data_begin,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block +
|
||||
in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data_begin,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data_begin, 0, 0),
|
||||
@@ -122,10 +124,11 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
|
||||
threadwise_4d_tensor_copy(out_thread_desc,
|
||||
p_out_thread,
|
||||
out_block_desc,
|
||||
p_out_block + out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
p_out_block +
|
||||
out_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
k_thread_data_begin,
|
||||
ho_thread_data_begin,
|
||||
wo_thread_data_begin),
|
||||
out_thread_desc.GetLengths());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +56,7 @@ struct BlockwiseNdTensorCopyReorder_v3
|
||||
"wrong! BlockSize is not big enough for ThreadPerDims!");
|
||||
|
||||
// sanity check: work division
|
||||
static_for<0, nDim, 1>{}([](auto IDim) {
|
||||
static_for<0, nDim, 1>{}([&](auto IDim) {
|
||||
constexpr auto I = decltype(IDim){};
|
||||
constexpr index_t src_len = src_lengths.Get(I);
|
||||
constexpr index_t src_sub_len = src_sub_lengths.Get(I);
|
||||
@@ -220,7 +220,7 @@ struct BlockwiseNdTensorCopyReorder_v3
|
||||
|
||||
constexpr index_t dst_offset = DstDesc{}.Get1dIndex(dst_data_multi_id);
|
||||
|
||||
// write in the order of dst
|
||||
// write in the order of dst
|
||||
#if 1
|
||||
threadwise_nd_tensor_copy_reorder_given_dst2src_v2(thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
|
||||
@@ -43,11 +43,12 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -219,8 +220,9 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -275,39 +277,40 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -322,47 +325,37 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -382,17 +375,17 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -44,11 +44,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -125,8 +126,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
#if 1
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
@@ -228,8 +229,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
|
||||
#if 1
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -273,12 +275,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
for(index_t
|
||||
c_block_data_begin = 0;
|
||||
c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
@@ -308,39 +310,40 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -355,47 +358,37 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -415,17 +408,17 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -49,8 +49,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
|
||||
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -262,12 +265,12 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);
|
||||
|
||||
for(index_t c_block_data_begin = 0; c_block_data_begin < C;
|
||||
for(index_t
|
||||
c_block_data_begin = 0;
|
||||
c_block_data_begin < C;
|
||||
c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset +=
|
||||
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset +=
|
||||
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
|
||||
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
|
||||
{
|
||||
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
|
||||
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
@@ -333,15 +336,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_10d_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
threadwise_10d_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -43,11 +43,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -212,10 +213,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
p_in_global +
|
||||
in_c_h_w_n_global_desc.Get1dIndex(
|
||||
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -226,6 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
#pragma unroll
|
||||
for(index_t x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_in_copy.Run(p_in_global_block_offset +
|
||||
@@ -287,39 +290,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -334,47 +338,37 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -394,17 +388,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -43,11 +43,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -127,8 +128,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
@@ -349,39 +350,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -396,47 +398,37 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -456,17 +448,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,11 +47,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -349,39 +350,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -396,47 +398,37 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -456,17 +448,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,11 +47,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -223,8 +224,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
|
||||
#if 1
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -329,39 +331,40 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
|
||||
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;
|
||||
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
|
||||
[&](auto f_dummy) { // f_dummy do nothing but perfect forwarding. Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) { // f_dummy do nothing but
|
||||
// perfect forwarding.
|
||||
// Using this trick to
|
||||
// make this lambda a generic lambda, so it won't be compiled until
|
||||
// instantiated
|
||||
static_assert(
|
||||
(f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
constexpr index_t N1 = NPerBlock / N2;
|
||||
|
||||
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
|
||||
f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
constexpr index_t W2 =
|
||||
(GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
|
||||
constexpr index_t W1 = WoPerBlock / W2;
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2),
|
||||
W1,
|
||||
W2,
|
||||
N / f_dummy(N1 * N2),
|
||||
N1,
|
||||
N2>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -376,47 +379,37 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}).else_([&](auto f_dummy) {
|
||||
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0,
|
||||
"wrong!");
|
||||
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
// output is a 10d tensor
|
||||
constexpr index_t N1 = NPerBlock;
|
||||
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
|
||||
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
|
||||
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
|
||||
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = KPerBlock / KPerThread;
|
||||
|
||||
constexpr auto out_10d_global_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
|
||||
K1,
|
||||
K2,
|
||||
Ho,
|
||||
Wo / (W1 * W2 * W3),
|
||||
W1,
|
||||
W2,
|
||||
W3,
|
||||
N / N1,
|
||||
N1>{});
|
||||
constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3, N / N1, N1>{});
|
||||
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
|
||||
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
@@ -436,17 +429,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
|
||||
}
|
||||
#endif
|
||||
|
||||
threadwise_nd_tensor_copy(out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
threadwise_nd_tensor_copy(
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global +
|
||||
out_k_h_w_n_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -47,11 +47,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
|
||||
NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0),
|
||||
"wrong!");
|
||||
static_assert(
|
||||
NPerBlock % NPerThread == 0 &&
|
||||
((GemmNPerThreadSubC <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0) ||
|
||||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
|
||||
GemmNPerThreadSubC % NPerThread == 0)),
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -223,8 +224,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
|
||||
#if 1
|
||||
const Float* p_in_global_block_offset =
|
||||
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
p_in_global +
|
||||
in_n_c_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
|
||||
|
||||
const Float* p_wei_global_block_offset =
|
||||
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
|
||||
@@ -409,14 +411,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#endif
|
||||
})
|
||||
.else_([&](auto f_dummy) {
|
||||
@@ -500,14 +503,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_nkhw
|
||||
out_10d_thread_desc,
|
||||
p_out_thread,
|
||||
out_10d_global_desc,
|
||||
p_out_global + out_n_k_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global +
|
||||
out_n_k_h_w_global_desc.Get1dIndex(
|
||||
n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
map_out_global2thread);
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
// Number<OutThreadCopyDataPerWrite_W>{});
|
||||
#endif
|
||||
});
|
||||
}
|
||||
|
||||
@@ -365,13 +365,14 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
|
||||
|
||||
constexpr auto out_kb_global_desc = make_ConstantTensorDescriptor(Sequence<K, B>{});
|
||||
|
||||
threadwise_6d_tensor_copy(out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global + out_kb_global_desc.Get1dIndex(
|
||||
k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
threadwise_6d_tensor_copy(
|
||||
out_6d_thread_desc,
|
||||
p_out_thread,
|
||||
out_6d_global_desc,
|
||||
p_out_global +
|
||||
out_kb_global_desc.Get1dIndex(k_thread_data_begin, b_thread_data_begin),
|
||||
out_6d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@@ -113,10 +113,11 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
|
||||
c_block_work_begin += CPerBlock)
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
blockwise_in_copy.Run(p_in_global +
|
||||
in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
@@ -143,9 +144,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
|
||||
}
|
||||
|
||||
// copy output tensor from LDS to device mem
|
||||
blockwise_out_copy.Run(p_out_block,
|
||||
p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin));
|
||||
blockwise_out_copy.Run(
|
||||
p_out_block,
|
||||
p_out_global +
|
||||
out_global_desc.Get1dIndex(
|
||||
n_block_work_begin, k_block_work_begin, ho_block_work_begin, wo_block_work_begin));
|
||||
}
|
||||
|
||||
@@ -175,16 +175,18 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.Run(p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
blockwise_in_copy.Run(p_in_global +
|
||||
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_global + wei_kcyx_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global +
|
||||
wei_kcyx_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -194,10 +196,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block +
|
||||
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -206,10 +209,11 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_thread_block_desc,
|
||||
p_in_block + in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_block +
|
||||
in_nchw_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_thread_block_desc,
|
||||
p_wei_block +
|
||||
wei_kcyx_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -224,9 +228,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
@@ -198,9 +198,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
p_in_vec_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.Run(p_wei_vec_global + wei_kcyx_vec_global_desc.Get1dIndex(
|
||||
k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_vec_global +
|
||||
wei_kcyx_vec_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_vec_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -210,10 +211,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#if 1
|
||||
threadwise_direct_convolution_2(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -222,10 +224,11 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
#elif 0
|
||||
threadwise_direct_convolution_3(
|
||||
in_nchw_vec_thread_block_desc,
|
||||
p_in_vec_block + in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
p_in_vec_block +
|
||||
in_nchw_vec_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
c_thread_data,
|
||||
hi_thread_data_begin,
|
||||
wi_thread_data_begin),
|
||||
wei_kcyx_vec_thread_block_desc,
|
||||
p_wei_vec_block +
|
||||
wei_kcyx_vec_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0),
|
||||
@@ -240,9 +243,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
|
||||
out_nkhw_thread_desc,
|
||||
p_out_thread,
|
||||
out_nkhw_global_desc,
|
||||
p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
p_out_global +
|
||||
out_nkhw_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin,
|
||||
k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin),
|
||||
out_nkhw_thread_desc.GetLengths());
|
||||
}
|
||||
|
||||
@@ -283,10 +283,11 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(
|
||||
out_hkwn_thread_desc,
|
||||
p_out_thread,
|
||||
out_khwn_global_desc,
|
||||
p_out_global + out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
p_out_global +
|
||||
out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread_data_begin,
|
||||
ho_block_data_begin + ho_thread_data_begin,
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_hkwn_thread_desc.GetLengths(),
|
||||
reorder_khwn_from_hkwn);
|
||||
}
|
||||
|
||||
@@ -22,8 +22,7 @@ std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
|
||||
return os;
|
||||
}
|
||||
|
||||
typedef enum
|
||||
{
|
||||
typedef enum {
|
||||
Half = 0,
|
||||
Float = 1,
|
||||
} DataType_t;
|
||||
|
||||
Reference in New Issue
Block a user