mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
@@ -181,12 +181,6 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
|
||||
InBlockCopyDataPerAccess_W2>({0, 0, 0, 0, b_block_data_on_global, 0, 0, 0},
|
||||
{0, 0, 0, 0, 0, 0, 0, 0});
|
||||
|
||||
#if 0
|
||||
{
|
||||
printf("id (%d %d), in offset: %d %d\n", get_block_1d_id(), get_thread_local_1d_id(), blockwise_in_copy.mThreadSrcOffset, blockwise_in_copy.mThreadDstOffset);
|
||||
}
|
||||
#endif
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
|
||||
@@ -44,7 +44,8 @@ template <index_t GridSize,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_B>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
@@ -133,12 +134,16 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(in_e_b_global_desc),
|
||||
decltype(in_e_b_block_desc),
|
||||
MergedTensorCoordinate<decltype(in_e_b_global_desc)>,
|
||||
NormalTensorCoordinate<decltype(in_e_b_block_desc)>,
|
||||
decltype(in_e_b_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_B,
|
||||
InBlockCopyClusterLengths_E_B,
|
||||
InBlockCopyThreadClusterArrangeOrder>(
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
1,
|
||||
1,
|
||||
InBlockCopyDataPerAccess_B,
|
||||
InBlockCopyDataPerAccess_B>(
|
||||
{0, b_block_data_on_global}, {0, 0});
|
||||
|
||||
// weight tensor
|
||||
@@ -155,16 +160,21 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already have blockwise offset built-in
|
||||
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v2<
|
||||
BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
NormalTensorCoordinate<decltype(wei_e_k_global_desc)>,
|
||||
NormalTensorCoordinate<decltype(wei_e_k_block_desc)>,
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder>({0, k_block_data_on_global}, {0, 0});
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
0,
|
||||
1,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
@@ -283,15 +293,20 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
|
||||
using OutThreadCopySliceLengths =
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
|
||||
|
||||
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2<
|
||||
decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(out_k0_k1_b_global_desc),
|
||||
NormalTensorCoordinate<decltype(out_k0_k1_b_thread_desc)>,
|
||||
MergedTensorCoordinate<decltype(out_k0_k1_b_global_desc)>,
|
||||
OutThreadCopySliceLengths>({0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
auto threadwise_out_copy =
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
|
||||
decltype(out_k0_k1_b_global_desc),
|
||||
OutThreadCopySliceLengths,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
arithmetic_sequence_gen<0, 3, 1>::type,
|
||||
2,
|
||||
2,
|
||||
OutThreadCopyDataPerAccess_B,
|
||||
OutThreadCopyDataPerAccess_B>(
|
||||
{0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
b_thread_data_on_global});
|
||||
|
||||
for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user