mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
refactor
This commit is contained in:
@@ -281,6 +281,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
#if 0
|
||||
constexpr index_t K2 = GemmMPerThreadSubC;
|
||||
constexpr index_t K1 = GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
@@ -337,6 +338,55 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0, 0, 0, 0}, {0, 0, 0, 0, 0, 0, 0, 0})
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#else
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
|
||||
// output merged global tensor descriptor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{});
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
const auto c_thread_mtx_on_block =
|
||||
blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
|
||||
|
||||
const index_t k_thread_data_on_global =
|
||||
k_block_data_on_global + c_thread_mtx_on_block.row;
|
||||
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.template Run_amd_experiment<Float, 0, 2>(p_out_thread, p_out_global);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -295,7 +295,7 @@ int main(int argc, char* argv[])
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
|
||||
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
|
||||
constexpr index_t N = 128;
|
||||
|
||||
Reference in New Issue
Block a user