diff --git a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp index e86b7556fe..b11cbfb879 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp +++ b/example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_xdl_fp16_comp_bf8_fp8.cpp @@ -35,4 +35,12 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat #include "run_grouped_conv_bwd_data_example.inc" -int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); } +int main(int argc, char* argv[]) +{ + // temp disable on gfx11 + if(ck::is_gfx11_supported()) + { + return 0; + } + return run_grouped_conv_bwd_data_example(argc, argv); +} diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc index 9ddc541463..37609f2492 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_bias_relu_example.inc @@ -108,7 +108,7 @@ bool run_conv_bwd_data_bias_relu(const ExecutionConfig& config, "not support this Conv problem - skipping." << std::endl; - return true; + return false; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc index 0e9da3e5e3..9d10f11ca8 100644 --- a/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc +++ b/example/38_grouped_conv_bwd_data_multiple_d/run_grouped_conv_bwd_data_example.inc @@ -96,7 +96,7 @@ bool run_conv_bwd_data(const ExecutionConfig& config, "not support this Conv problem - skipping." << std::endl; - return true; + return false; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 3cce0f7f09..fc0f23fc16 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -18,7 +18,6 @@ #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" #include "ck/host_utility/device_prop.hpp" @@ -139,30 +138,22 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - index_t group_id; - if(gemms_count == 1) + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) { - group_id = 0; - } - else - { - index_t left = 0; - index_t right = gemms_count; - group_id = index_t((left + right) / 2); - while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && - block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && - left <= right) + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) { - if(block_args_id < gemm_kernel_args[group_id].BlockStart_) - { - right = group_id; - } - else - { - left = group_id; - } - group_id = index_t((left + right) / 2); + right = group_id; } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); } if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) @@ -182,8 +173,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, KBatch, - k_idx, - gemm_kernel_args[group_id].e_grid_desc_m_n_); + k_idx); } else { @@ -204,8 +194,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, KBatch, - k_idx, - gemm_kernel_args[group_id].e_grid_desc_m_n_); + k_idx); } else { @@ -224,8 +213,7 @@ __launch_bounds__(GridwiseGemm::MaxBlockSize, CK_MIN_BLOCK_PER_CU) gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, gemm_kernel_args[group_id].block_2_ctile_map_, KBatch, - k_idx, - gemm_kernel_args[group_id].e_grid_desc_m_n_); + k_idx); } } } @@ -522,97 +510,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 GridwiseGemmCTransposeBase, GridwiseGemm32>; - // Non-grouped GridwiseGemm for single-group specialization. - // Uses simpler epilogue and address computation, matching the non-grouped kernel. - // Requires AK1 == BK1 (true for all backward data convolution instances). - template - using NonGroupedGridwiseGemmBase = - GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3, - 7, - 1>; - using NonGroupedGridwiseGemm64 = NonGroupedGridwiseGemmBase; - - // Flat descriptor type aliases for group_count=1 fast path. - // Derived from the _Packed() methods in ConvToGemmBwdDataTransform. - // Uses a canonical NHWGK/NHWGC layout for type derivation (the _Packed() methods - // have enable_if on layout, but the resulting descriptor types are layout-independent). - template - struct FlatDescTypes - { - using ADesc = int; - using BDesc = int; - using CDesc = int; - }; - - template - struct FlatDescTypes - { - using CanonicalTransform = TransformConvBwdDataToGemm_v1<2, - ConvBackwardDataSpecialization, - AK1, - BK1, - MPerBlock, - NPerBlock, - KPerBlock, - DoPadGemmM, - DoPadGemmN, - tensor_layout::convolution::NHWGK, - tensor_layout::convolution::GKYXC, - tensor_layout::convolution::NHWGC, - true, - ABDataType, - EDataType, - 1, - index_t, - false>; - - using ADesc = remove_cvref_t< - decltype(std::declval().MakeADescriptor_AK0_M_AK1_Packed())>; - using BDesc = remove_cvref_t< - decltype(std::declval().MakeBDescriptor_BK0_N_BK1_Packed())>; - using CDesc = remove_cvref_t< - decltype(std::declval().MakeCDescriptor_M_N_Packed())>; - }; - - using FlatAGridDesc_K0_M_K1 = - typename FlatDescTypes::ADesc; - using FlatBGridDesc_K0_N_K1 = - typename FlatDescTypes::BDesc; - using FlatCGridDesc_M_N = - typename FlatDescTypes::CDesc; - template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) @@ -668,7 +565,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ds_grid_desc_mblock_mperblock_nblock_nperblock, EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, - EGridDesc_M_N e_grid_desc_m_n, GroupedGemmBlock2ETileMap block_2_ctile_map, index_t BlockStart, index_t BlockEnd, @@ -682,8 +578,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 e_grid_desc_mblock_mperblock_nblock_nperblock_( e_grid_desc_mblock_mperblock_nblock_nperblock), - e_grid_desc_m_n_(e_grid_desc_m_n), - // block-to-e-tile map block_2_ctile_map_(block_2_ctile_map), BlockStart_(BlockStart), @@ -698,7 +592,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_; EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; - EGridDesc_M_N e_grid_desc_m_n_; // block-to-e-tile map GroupedGemmBlock2ETileMap block_2_ctile_map_; @@ -831,7 +724,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, conv_filter_strides_{conv_filter_strides}, - conv_filter_dilations_{conv_filter_dilations}, input_left_pads_{input_left_pads}, input_right_pads_{input_right_pads}, k_batch_{split_k} @@ -1078,7 +970,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 ds_grid_desc_m_n), MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n), - e_grid_desc_m_n, block_2_etile_map, BlockStart, BlockEnd, @@ -1089,22 +980,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 gemms_grid_size_.push_back(grid_size); grid_size = 0; } - - // Build packed descriptors for group_count=1 fast path - if constexpr(NDimSpatial == 2 && !CTranspose) - { - // K must be >= AK1 to ensure K0 = K/AK1 >= 1; otherwise - // the flat descriptor would have K0=0 which is invalid. - if(num_group_ == 1 && a_g_n_k_wos_lengths[2] >= AK1) - { - flat_a_container_.push_back( - conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1_Packed()); - flat_b_container_.push_back( - conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1_Packed()); - flat_c_container_.push_back( - conv_to_gemm_transform_.MakeCDescriptor_M_N_Packed()); - } - } } } } @@ -1270,15 +1145,9 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 std::array b_g_k_c_xs_lengths_; std::array e_g_n_c_wis_lengths_; std::array conv_filter_strides_; - std::array conv_filter_dilations_; std::array input_left_pads_; std::array input_right_pads_; - // Flat descriptors for group_count=1 fast path - std::vector flat_a_container_; - std::vector flat_b_container_; - std::vector flat_c_container_; - const index_t k_batch_; index_t num_workgroups_per_Conv_N_; std::vector gemms_grid_size_; @@ -1353,7 +1222,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; } - // Original multi-group kernel launch auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) { constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool no_main_loop = no_main_k_block_loop.value; @@ -1484,95 +1352,20 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 } } }; - - // Dispatch: use flat-descriptor path (non-grouped GEMM) for G=1, - // NDim=2. The non-grouped kernel uses packed (flat) descriptors - // with fewer transform layers, avoiding the grouped GEMM overhead. - // The NoShuffle epilogue used by this path supports only unary - // element-wise ops (e.g. PassThrough). When D tensors are present - // the cde_element_op is multi-input (e.g. AddRelu(y, x0, x1)), - // which is incompatible with the unary VGPR->global writer; fall - // back to the regular CShuffle path in that case. - bool used_flat_desc = false; - if constexpr(NDimSpatial == 2 && !CTranspose && NumDTensor == 0) + if(has_loop_in_all_gemm) { - if(arg.num_group_ == 1 && arg.k_batch_ == 1 && !arg.flat_a_container_.empty()) - { - used_flat_desc = true; - const index_t flat_idx = gemm_set_id; - const auto& flat_a = arg.flat_a_container_[flat_idx]; - const auto& flat_b = arg.flat_b_container_[flat_idx]; - const auto& flat_c = arg.flat_c_container_[flat_idx]; - const index_t padded_K0 = flat_a.GetLength(I0); - const bool flat_desc_has_main_loop = - NonGroupedGridwiseGemm64::CalculateHasMainKBlockLoop(padded_K0 * AK1); - const index_t flat_grid_size = - NonGroupedGridwiseGemm64::Block2CTileMap::CalculateGridSize( - flat_c.GetLength(I0), flat_c.GetLength(I1)); - if(flat_desc_has_main_loop) - { - const auto kernel = kernel_gemm_xdlops_v2r3; - ave_time += launch_and_time_kernel_with_preprocess(stream_config, - clear_workspace, - kernel, - dim3(flat_grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_e_grid, - flat_a, - flat_b, - flat_c); - } - else - { - const auto kernel = kernel_gemm_xdlops_v2r3; - ave_time += launch_and_time_kernel_with_preprocess(stream_config, - clear_workspace, - kernel, - dim3(flat_grid_size), - dim3(BlockSize), - 0, - p_a_grid, - p_b_grid, - p_e_grid, - flat_a, - flat_b, - flat_c); - } - } + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); } - - if(!used_flat_desc) + else if(no_loop_in_all_gemm) { - if(has_loop_in_all_gemm) - { - ave_time += launch_kernel(integral_constant{}, - integral_constant{}); - } - else if(no_loop_in_all_gemm) - { - ave_time += launch_kernel(integral_constant{}, - integral_constant{}); - } - else - { - ave_time += launch_kernel(integral_constant{}, - integral_constant{}); - } + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); } } @@ -1788,17 +1581,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { return false; } - // This entire device template instantiates XDL (MFMA) kernels, which are - // CDNA-only. The shared is_xdl_wmma_supported() helper above can return - // true for FP16/BF16 with 16x16 on RDNA (gfx11/gfx12) because it is - // also used by WMMA device templates. Reject all instances of this XDL - // template on RDNA to avoid launching MFMA kernels on hardware that - // does not implement those intrinsics. The corresponding WMMA path - // lives in device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp. - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - return false; - } if(!is_bf16_atomic_supported() && std::is_same_v && arg.k_batch_ > 1) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 4510ac82da..16e5feb0ea 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -704,8 +704,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle typename BGridDesc_BK0_N_BK1, typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - typename Block2ETileMap, - typename EGridDesc_M_N_Direct = Tuple<>> + typename Block2ETileMap> __device__ static void Run(const ADataType* __restrict__ p_a_grid, const BDataType* __restrict__ p_b_grid, DsGridPointer p_ds_grid, @@ -721,9 +720,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap& block_2_etile_map, - const index_t k_batch = 1, - const index_t k_idx = 0, - const EGridDesc_M_N_Direct& e_grid_desc_m_n_direct = {}) + const index_t k_batch = 1, + const index_t k_idx = 0) { const auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -912,61 +910,18 @@ struct GridwiseGemmMultipleD_xdl_cshuffle c_thread_buf, num_k_block_main_loop); - // Use RunEpilogueNoShuffle from gridwise_common (the base class) for - // direct VGPR-to-global output when conditions are met: - // (1) no D tensors, (2) scalar-per-vector is 1, (3) PassThrough element-wise op. - if constexpr(NumDTensor == 0 && CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1 && - is_same_v) - { - const auto e_grid_desc_m_n = [&]() { - if constexpr(!is_same_v>) - { - return e_grid_desc_m_n_direct; - } - else - { - return transform_tensor_descriptor( - e_grid_desc_mblock_mperblock_nblock_nperblock, - make_tuple(make_merge_transform(make_tuple( - e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - Number{})), - make_merge_transform(make_tuple( - e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2), - Number{}))), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - }(); - - const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm.MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(e_grid_desc_m_n); - - Base::template RunEpilogueNoShuffle, - 7>(blockwise_gemm, - c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, - c_thread_buf, - block_work_idx[I0], - block_work_idx[I1], - p_e_grid, - cde_element_op); - } - else - { - Base::template RunMultiDEpilogue( - blockwise_gemm, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - c_thread_buf, - block_work_idx[I0], - block_work_idx[I1], - p_shared, - p_ds_grid, - p_e_grid, - cde_element_op); - } + // Shuffle C and write out. + Base::template RunMultiDEpilogue( + blockwise_gemm, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_thread_buf, + block_work_idx[I0], + block_work_idx[I1], + p_shared, + p_ds_grid, + p_e_grid, + cde_element_op); } template || - is_same_v), - bool>::type = false> - __host__ __device__ auto MakeADescriptor_AK0_M_AK1_Packed() const - { - const auto K0 = K_ / AK1; - - const auto out_n_ho_wo_k_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_)); - - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) - { - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_)), - make_tuple(make_pass_through_transform(N_ * Ho_ * Wo_), - make_unmerge_transform(make_tuple(K0, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - out_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(Number{}, Number{}, Number{}), - Sequence{}); - } - else - { - const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); - const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); - - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); - const auto IHTildeSliceEnd = math::min( - HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); - const auto IWTildeSliceEnd = math::min( - WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_n_ho_wo_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple( - make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_unmerge_transform(make_tuple(K0, Number{}))), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5, 6>{})); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(Number{})), - make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - out_gemmk0_gemmm_gemmk1_grid_desc, - make_tuple(Number{}, Number{}, Number{}), - Sequence{}); - } - } - - template ), - bool>::type = false> - __host__ __device__ auto MakeBDescriptor_BK0_N_BK1_Packed() const - { - const auto K0 = K_ / BK1; - - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) - { - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), - make_tuple(make_unmerge_transform(make_tuple(K0, Number{})), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(Number{}, Number{}, Number{}), - Sequence{}); - } - else - { - const auto wei_k_y_x_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); - - const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); - const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); - - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_k_y_x_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( - wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, Number{})), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<3>{}, - Sequence<2>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0, 1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<>{}, - Sequence<>{}, - Sequence<4>{})); - - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( - wei_k0_k1_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), - make_pass_through_transform(C_), - make_pass_through_transform(Number{})), - make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - wei_gemmk0_gemmn_gemmk1_grid_desc, - make_tuple(Number{}, Number{}, Number{}), - Sequence{}); - } - } - - template || - is_same_v || - is_same_v), - bool>::type = false> - __host__ __device__ auto MakeCDescriptor_M_N_Packed() const - { - const auto in_n_hi_wi_c_grid_desc = - make_naive_tensor_descriptor_packed(make_tuple(N_, Hi_, Wi_, C_)); - - if constexpr(ConvBwdDataSpecialization == - ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: - Filter1x1Stride1Pad0) - { - const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), - make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_y_ho_x_wo_c_grid_desc, - make_tuple(make_freeze_transform(I0), - make_freeze_transform(I0), - make_merge_transform(make_tuple(N_, Ho_, Wo_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), - make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - in_gemmm_gemmn_grid_desc, - make_tuple(Number{}, Number{}), - Sequence{}); - } - else - { - const auto IHTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); - const auto IWTildeSliceBegin = math::integer_divide_floor( - math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); - const auto IHTildeSliceEnd = math::min( - HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); - const auto IWTildeSliceEnd = math::min( - WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); - const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; - const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; - - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_n_hi_wi_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<0>{}, - Sequence<>{}, - Sequence<1>{}, - Sequence<>{}, - Sequence<2>{}, - Sequence<3>{})); - - const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - return ck::tensor_operation::device::PadTensorDescriptor( - in_gemmm_gemmn_grid_desc, - make_tuple(Number{}, Number{}), - Sequence{}); - } - } - IndexType N_; IndexType Di_, Hi_, Wi_; IndexType Do_, Ho_, Wo_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp index 057ea19cbd..970bcb0439 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp @@ -154,39 +154,6 @@ using device_grouped_conv_bwd_data_xdl_f16_instances = // clang-format on >; -template -using device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances = - std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // f16_f16_f32_f16 - noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - template ; -// bf16_bf16_f32_bf16 - noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) -// Same tile shapes as bf16_instances but with ScalarPerVector=1, enabling the no-shuffle fast path -// (VGPR -> Global direct write, 0 LDS barriers) instead of CShuffle (VGPR -> LDS -> Global, 8 -// barriers). -template -using device_grouped_conv_bwd_data_xdl_bf16_noshuffle_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - // f32_f32_f32_f32 template ; -template -using device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances = - std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // f32_f32_f32_f32 - noshuffle epilogue (CDEBlockTransferScalarPerVector_NPerBlock = 1) - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - -// bf16 - BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl -// instances. The key difference from bf16_instances: BBlockTransfer uses S<4, BlockSize/4, 1> -// thread cluster and S<2, 0, 1> arrange order, which gives full thread utilization for B-matrix -// loads. These are optimal when opt3 flat descriptor path is active (G=1, 2D convolutions). -template -using device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances = std::tuple< - // clang-format off - // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer| - // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| - // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - -// f16 - BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl -// instances. -template -using device_grouped_conv_bwd_data_xdl_f16_nongrouped_match_instances = std::tuple< - // clang-format off - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - -// f32 - BBlockTransfer parameters matching the non-grouped DeviceConvNdBwdDataNwcKxcNwk_Xdl -// instances. F32 uses K1=4, KPerBlock=16, and smaller scalar-per-vector values. -template -using device_grouped_conv_bwd_data_xdl_f32_nongrouped_match_instances = std::tuple< - // clang-format off - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 32, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 1>, - DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<2, 0, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 1> - // clang-format on - >; - // f32_f32_f32_f32 tf32 template {}); - // 3. Default - noshuffle epilogue - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_bf16_noshuffle_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 - noshuffle epilogue - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_bf16_noshuffle_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default - nongrouped_match instances - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 6. Filter1x1Stride1Pad0 - nongrouped_match instances - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_bf16_nongrouped_match_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp index dfcfac2bfa..085b6aaaf5 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -41,42 +41,6 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 3. Default - noshuffle epilogue - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 - noshuffle epilogue - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f16_noshuffle_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default - nongrouped_match instances - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f16_nongrouped_match_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 6. Filter1x1Stride1Pad0 - nongrouped_match instances - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f16_nongrouped_match_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp index 011a9d5cd1..02a0eeb517 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp @@ -41,42 +41,6 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances( Empty_Tuple, NHWGC, ConvBwdDataFilter1x1Stride1Pad0>{}); - // 3. Default - noshuffle epilogue - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 4. Filter1x1Stride1Pad0 - noshuffle epilogue - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f32_noshuffle_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); - // 5. Default - nongrouped_match instances - add_device_operation_instances( - instances, - device_grouped_conv_bwd_data_xdl_f32_nongrouped_match_instances<2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataDefault>{}); - // 6. Filter1x1Stride1Pad0 - nongrouped_match instances - add_device_operation_instances(instances, - device_grouped_conv_bwd_data_xdl_f32_nongrouped_match_instances< - 2, - NHWGK, - GKYXC, - Empty_Tuple, - NHWGC, - ConvBwdDataFilter1x1Stride1Pad0>{}); } } // namespace instance diff --git a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp index dd097c8e97..ab89d9d0f0 100644 --- a/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp +++ b/test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface.cpp @@ -256,11 +256,6 @@ TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1Wmma, KernelTypes); TYPED_TEST(TestGroupedConvndBwdDataFilter1x1Xdl, SpecializationCheckXdl) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - GTEST_SKIP() << "XDL (MFMA) bwd_data instances are not supported on RDNA (gfx11/gfx12)"; - } - // Check filter 3,3 instead of 1,1 this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; bool is_supported = this->template Run<2>(); @@ -331,10 +326,6 @@ TYPED_TEST(TestGroupedConvndBwdDataDefaultWmma, VectorLoadCheckWmma) TYPED_TEST(TestGroupedConvndBwdDataDefaultXdl, SplitK) { - if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) - { - GTEST_SKIP() << "XDL (MFMA) bwd_data instances are not supported on RDNA (gfx11/gfx12)"; - } if(ck::is_xdl_supported()) { // SplitK = 1