mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
added GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
This commit is contained in:
@@ -111,7 +111,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
|
||||
constexpr index_t NPerBlock = 4;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
|
||||
@@ -62,208 +62,10 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
|
||||
out_khwn_device_buf.ToDevice(out_khwn.mData.data());
|
||||
|
||||
#if 0
|
||||
// for 3x3, 34x34, v1r1, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
|
||||
constexpr index_t NPerBlock = 4;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 4;
|
||||
constexpr index_t WoPerBlock = 8;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 34x34, v1r1, Vega 20
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 2;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 256;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1, Pascal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 56x56, v1r2, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 1;
|
||||
constexpr index_t GemmDataPerReadB = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 3x3, 28x28, v1r1, Pacal
|
||||
constexpr index_t NPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t CPerBlock = 4;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 1;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 8;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
#if 1
|
||||
// for 3x3, 28x28, v1r2, Pascal
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
@@ -275,14 +77,6 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopyDataPerRead = 2;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
@@ -293,73 +87,16 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 2>;
|
||||
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
|
||||
using InBlockReorderMapThreadCluster2SrcCluster = Sequence<1, 2, 3, 0>;
|
||||
constexpr index_t InBlockReorderDataPerRead_W = 2;
|
||||
constexpr index_t InBlockReorderDataPerWrite_N = 4;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 0
|
||||
// for 1x1, 28x28
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
|
||||
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
|
||||
|
||||
constexpr index_t NPerThread = 4;
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t CPerThread = 1;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
#elif 1
|
||||
// for 1x1, 14x14, Pascal
|
||||
constexpr index_t NPerBlock = 16;
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t CPerBlock = 8;
|
||||
constexpr index_t HoPerBlock = 2;
|
||||
constexpr index_t WoPerBlock = 2;
|
||||
|
||||
constexpr index_t NPerThread = 8;
|
||||
constexpr index_t KPerThread = 8;
|
||||
constexpr index_t HoPerThread = 1;
|
||||
constexpr index_t WoPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
constexpr index_t GemmKPerThreadLoop = 1;
|
||||
|
||||
constexpr index_t InBlockCopy_ThreadPerDimC = 8;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimH = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
|
||||
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
|
||||
constexpr index_t InBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr index_t WeiBlockCopyDataPerRead = 4;
|
||||
constexpr index_t OutThreadCopyDataPerWrite = 2;
|
||||
|
||||
constexpr index_t BlockSize = 128;
|
||||
constexpr index_t OutThreadCopyDataPerWrite_N = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t GridSize =
|
||||
@@ -398,13 +135,14 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
|
||||
GemmKPerThreadLoop,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
Sequence<InBlockCopy_ThreadPerDimN,
|
||||
InBlockCopy_ThreadPerDimC,
|
||||
InBlockCopy_ThreadPerDimH,
|
||||
InBlockCopy_ThreadPerDimW>,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead,
|
||||
OutThreadCopyDataPerWrite>{};
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
InBlockReorderMapThreadCluster2SrcCluster,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyClusterLengths,
|
||||
WeiBlockCopyDataPerRead_C,
|
||||
OutThreadCopyDataPerWrite_N>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -673,7 +673,7 @@ int main(int argc, char* argv[])
|
||||
device_direct_convolution_2_nchw_kcyx_nkhw
|
||||
#elif 0
|
||||
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
|
||||
|
||||
@@ -231,7 +231,7 @@ struct Blockwise3dTensorCopy3
|
||||
}
|
||||
}
|
||||
|
||||
__device__ constexpr index_t GetRegisterClipboardSize() const
|
||||
__device__ static constexpr index_t GetRegisterClipboardSize()
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
|
||||
@@ -761,7 +761,6 @@ struct Blockwise4dTensorCopyReorder1
|
||||
}
|
||||
};
|
||||
|
||||
#if 1
|
||||
template <index_t BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
@@ -1070,36 +1069,17 @@ struct Blockwise4dTensorCopyReorder3
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
threadwise_4d_tensor_copy_reorder_given_dst2src_v2(
|
||||
thread_tensor_desc,
|
||||
p_clipboard + clipboard_offset,
|
||||
DstDesc{},
|
||||
p_dst + dst_offset + mDstMyThreadOffset,
|
||||
thread_sub_tensor_lengths,
|
||||
MapDst2Src{});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid %5u, "
|
||||
"data: %f %f %f %f %f %f %f %f\n",
|
||||
get_thread_local_1d_id(),
|
||||
p_clipboard[0],
|
||||
p_clipboard[1],
|
||||
p_clipboard[2],
|
||||
p_clipboard[3],
|
||||
p_clipboard[4],
|
||||
p_clipboard[5],
|
||||
p_clipboard[6],
|
||||
p_clipboard[7]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void Run(const Float* __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
@@ -1110,4 +1090,3 @@ struct Blockwise4dTensorCopyReorder3
|
||||
RunStoreRegisterClipboard(p_clipboard, p_dst);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -33,10 +33,14 @@ template <index_t GridSize,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
class InBlockReorderSrcSubLengths_NCHW,
|
||||
class InBlockReorderSrcClusterLengths_NCHW,
|
||||
class InBlockReorderMapThreadCluster2SrcCluster,
|
||||
index_t InBlockReorderDataPerRead_W,
|
||||
index_t InBlockReorderDataPerWrite_N,
|
||||
class WeiBlockCopyClusterLengths_KXC,
|
||||
index_t WeiBlockCopyDataPerRead_C,
|
||||
index_t OutThreadCopyDataPerWrite_N>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
@@ -101,8 +105,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
|
||||
// LDS tensor view
|
||||
// be careful of alignment
|
||||
constexpr index_t max_align = mod_conv::max(
|
||||
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
|
||||
constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
|
||||
WeiBlockCopyDataPerRead_C,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB);
|
||||
|
||||
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
|
||||
@@ -117,68 +123,38 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
// blockwise copy
|
||||
// input: format is [N, C, Hi, Wi] to [C, Hi, Wi, N]
|
||||
auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
|
||||
#if 0
|
||||
const auto blockwise_in_copy_reorder =
|
||||
Blockwise4dTensorCopyReorder1<BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
|
||||
decltype(map_chwn2nchw)>{};
|
||||
#else
|
||||
auto map_thread_cluster_2_src_cluster = Sequence<1, 2, 0, 3>{};
|
||||
|
||||
const auto blockwise_in_copy_reorder =
|
||||
Blockwise4dTensorCopyReorder3<BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WiPerBlock>,
|
||||
Sequence<4, 1, 1, 2>,
|
||||
Sequence<4, 8, 2, 2>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
decltype(map_thread_cluster_2_src_cluster),
|
||||
2,
|
||||
4>{};
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
printf("size %u\n", blockwise_in_copy_reorder.GetRegisterClipboardSize());
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
InBlockReorderMapThreadCluster2SrcCluster,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>{};
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
#if 0
|
||||
Blockwise3dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_x_k_global_desc),
|
||||
decltype(wei_c_x_k_block_desc),
|
||||
decltype(wei_c_x_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#else
|
||||
Blockwise3dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_x_k_global_desc),
|
||||
decltype(wei_c_x_k_block_desc),
|
||||
decltype(wei_c_x_k_block_desc.GetLengths()),
|
||||
Sequence<4, 1, 32>,
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
WeiBlockCopyDataPerRead_C>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
|
||||
constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_x_k_block_desc.GetStride(I0)>{});
|
||||
|
||||
constexpr auto b_c_wn_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
|
||||
@@ -252,6 +228,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
{
|
||||
for(index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
#if 1
|
||||
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
|
||||
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
|
||||
p_in_block);
|
||||
@@ -259,6 +236,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset +
|
||||
wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
|
||||
p_wei_block);
|
||||
#else
|
||||
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
|
||||
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];
|
||||
|
||||
blockwise_in_copy_reorder.RunLoadRegisterClipboard(
|
||||
p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
|
||||
p_in_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunLoadRegisterClipboard(
|
||||
p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
|
||||
p_wei_clipboard);
|
||||
|
||||
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
|
||||
|
||||
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
|
||||
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -274,42 +268,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
}
|
||||
}
|
||||
|
||||
// output: register to global mem,
|
||||
#if 0
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
|
||||
{
|
||||
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
|
||||
{
|
||||
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
|
||||
{
|
||||
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
|
||||
{
|
||||
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);
|
||||
|
||||
const auto c_thread_mtx_distance =
|
||||
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);
|
||||
|
||||
const index_t ho_thread =
|
||||
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
|
||||
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
|
||||
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;
|
||||
|
||||
const index_t wo_thread = b_thread / NPerBlock;
|
||||
const index_t n_thread = b_thread % NPerBlock;
|
||||
|
||||
p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
|
||||
ho_block_data_begin + ho_thread,
|
||||
wo_block_data_begin + wo_thread,
|
||||
n_block_data_begin + n_thread)] =
|
||||
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif 1
|
||||
// output: register to global mem,
|
||||
const auto c_thread_mtx_begin =
|
||||
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
@@ -356,7 +315,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
|
||||
wo_block_data_begin + wo_thread_data_begin,
|
||||
n_block_data_begin + n_thread_data_begin),
|
||||
out_10d_thread_desc.GetLengths(),
|
||||
Number<OutThreadCopyDataPerWrite>{});
|
||||
#endif
|
||||
Number<OutThreadCopyDataPerWrite_N>{});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -44,7 +44,7 @@ template <class SrcData,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src,
|
||||
class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
|
||||
SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
@@ -82,9 +82,9 @@ __device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_d
|
||||
const index_t bindex =
|
||||
dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]);
|
||||
|
||||
#if 1
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
#else
|
||||
|
||||
#if 0
|
||||
if(get_block_1d_id() == 0)
|
||||
{
|
||||
printf("tid %5u, "
|
||||
@@ -126,17 +126,16 @@ template <class SrcData,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
__device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
auto f_copy = [](const SrcData& src, DstData& dst) { dst = static_cast<DstData>(src); };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src(
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder_given_dst2src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
|
||||
}
|
||||
|
||||
@@ -146,7 +145,7 @@ __device__ void threadwise_4d_tensor_copy(
|
||||
{
|
||||
auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(
|
||||
threadwise_4d_tensor_copy_reorder_given_dst2src(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
|
||||
}
|
||||
|
||||
@@ -212,6 +211,61 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
|
||||
}
|
||||
}
|
||||
|
||||
template <class SrcData,
|
||||
class DstData,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class MapDst2Src>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_copy_reorder_given_dst2src_v2(SrcDesc,
|
||||
const SrcData* __restrict__ p_src,
|
||||
DstDesc,
|
||||
DstData* __restrict__ p_dst,
|
||||
SrcOpLengths,
|
||||
MapDst2Src)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t IR0 = MapDst2Src{}.Get(I0);
|
||||
constexpr index_t IR1 = MapDst2Src{}.Get(I1);
|
||||
constexpr index_t IR2 = MapDst2Src{}.Get(I2);
|
||||
constexpr index_t IR3 = MapDst2Src{}.Get(I3);
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
// ref_desc has dst_desc's ordering
|
||||
constexpr auto ref_desc =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}.ReorderGivenNew2Old(MapDst2Src{}));
|
||||
|
||||
for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const auto dst_multi_id = Array<index_t, 4>{did0, did1, did2, did3};
|
||||
|
||||
const auto src_multi_id =
|
||||
reorder_array_given_old2new(dst_multi_id, MapDst2Src{});
|
||||
|
||||
const index_t dst_index = dst_desc.Get1dIndex(dst_multi_id);
|
||||
|
||||
const index_t src_index = src_desc.Get1dIndex(src_multi_id);
|
||||
|
||||
p_dst[dst_index] = p_src[src_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user