This commit is contained in:
Chao Liu
2019-09-18 16:08:24 -05:00
parent 86cc678f18
commit 94bb1b4835
2 changed files with 51 additions and 1 deletions

View File

@@ -281,6 +281,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
// copy output: register to global memory
{
#if 0
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
@@ -337,6 +338,55 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
1,
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
.Run(p_out_thread, p_out_thread_on_global);
#else
// Active (#else) path of the output copy: writes the per-thread output buffer
// (p_out_thread) to the output tensor (p_out_global) with a single threadwise copy
// whose destination is a merged view of the global output descriptor.
//
// K1 is the product of GemmMPerThreadSubC, GemmMLevel0Cluster and GemmMLevel1Cluster.
// It is used twice below: as the fold length for global dimension I1, and as the
// divisor/modulus that splits k_thread_data_on_global into the (k0, k1) origin.
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
// define tensor descriptors for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
// (packed 5-d descriptor with lengths [GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2])
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
// output memory layout descriptor in device memory
// Dimension I1 of out_n_k_h_w_global_desc is folded by K1 and dimension I0 by (N1, N2);
// going by the variable name, the resulting 7 dimensions are [n0, n1, n2, k0, k1, h, w].
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
// output merged global tensor descriptor, dst of threadwise copy
// Going by the descriptor names, the five index lists pick, in order:
// k0 (3), k1 (4), n1 (1), b = {n0, h, w} (0, 5, 6), and n2 (2), so the merged
// descriptor has the same dimension order as the thread descriptor above.
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
// k origin of this thread: block k origin plus the thread's row in the block C matrix
const index_t k_thread_data_on_global =
k_block_data_on_global + c_thread_mtx_on_block.row;
// b origin of this thread: block b origin plus the thread's column divided by N2
// NOTE(review): the division presumably drops the n2 component of the column index,
// which the copy then covers through its last dimension -- confirm against the
// blockwise GEMM column layout.
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
// Copy the whole thread tensor (slice lengths = thread descriptor lengths).
// Source origin is all zeros; destination origin is
// {k / K1, k % K1, 0, b, 0} in the merged [k0, k1, n1, b, n2] coordinates.
// NOTE(review): the meaning of the two arithmetic_sequence_gen<0, 5, 1> arguments and
// of the trailing 3, 3, 1, 1 template arguments is defined by
// ThreadwiseGenericTensorSliceCopy_v2r1, and the <Float, 0, 2> arguments by
// Run_amd_experiment; neither is visible here -- verify before changing them.
ThreadwiseGenericTensorSliceCopy_v2r1<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
#endif
}
}
};