From 9d3f634a3cc37599db9f3ebfeb1f9a3c45d9a673 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Mon, 23 Aug 2021 11:22:10 -0500 Subject: [PATCH] Xdlops refactor fix (#22) * added constexpr ahead of adptor; clean unused driver; rename M/NPerWave to M/NPerXDL * fixed bwd * fixed comment --- .../blockwise_gemm_xdlops.hpp | 3 +- .../gridwise_gemm_xdlops_v2r3.hpp | 20 +- .../include/tensor_operation/xdlops_gemm.hpp | 5 +- ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 74 ++--- ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 40 +-- ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 280 ------------------ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 6 +- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 44 +-- .../include/driver_gemm_xdlops_v2r3.hpp | 8 +- .../src/conv_fwd_driver_offline.cpp | 6 +- ...tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 8 +- 11 files changed, 111 insertions(+), 383 deletions(-) delete mode 100644 host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index 0c381de0d8..a8236737df 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -4,6 +4,7 @@ #include "common_header.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "xdlops_gemm.hpp" +#include "tensor_adaptor.hpp" namespace ck { @@ -40,7 +41,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { const index_t thread_id = get_thread_local_1d_id(); - const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 51af11fefc..3e4d74e9d8 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -103,8 +103,8 @@ template ; @@ -366,8 +368,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerWave, - NPerWave, + MPerXDL, + NPerXDL, MRepeat, NRepeat, K1>{}; diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 56581c024b..f945b0fdf5 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -749,8 +749,9 @@ struct XdlopsGemm __device__ static auto GetBlkIdx() { - const auto laneId = GetLaneId(); - const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform( make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), make_tuple(Sequence<0, 1, 2>{}), diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7bd82bf6d5..7196f3c179 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,9 +56,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -84,9 +84,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -140,9 +140,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -168,9 +168,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -223,25 +223,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; @@ -263,8 +265,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 0ebf8571f4..2cbae2daf3 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk GemmCThreadTransferDstScalarPerVector, decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), true // CAccessOrderMRepeatNRepeat @@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk in_gemmm_gemmn_grid_desc, out_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 4a9d01081c..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#endif - - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad -#else - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 -#endif - ( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - for(index_t i = 0; i < 5; ++i) - { -#if 0 - float ave_time = launch_kernel_gemm_xdlops_v1 -#else - float ave_time = launch_kernel_gemm_xdlops_v2 -#endif - , - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 34de1ac0ed..dc4f5eafb6 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_step_hacks = + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -169,7 +169,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), @@ -180,7 +180,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( out_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index a392df6aa8..5ff8dfb665 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; @@ -84,9 +84,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 4; @@ -140,9 +140,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -168,9 +168,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -196,9 +196,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -249,7 +249,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves @@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( out_gemmm_gemmn_grid_desc, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index a33e5bf4d0..2bf8adba84 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -17,8 +17,8 @@ template