diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp index 8fca6a1e2f..6624570b27 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -185,7 +185,9 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 BElementwiseOperation, CElementwiseOperation> { - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using CDataType_ = CDataType; // GridwiseGemm using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index 1ea4854bd3..a819b91b05 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -11,6 +11,8 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include namespace ck { namespace tensor_operation { @@ -48,7 +50,48 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl; + static constexpr bool IsTwoStageNeeded = + sizeof(WeiDataType) % 4 != 0 && + DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0; + + using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl; + using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_; + + static constexpr index_t ElementwiseBlockSize = 256; + static constexpr index_t ElemsPerBlock = 256; + + static auto GetElementwiseCGridDesc(index_t merged_filter_dims) + { + const auto padd_size = merged_filter_dims % ElemsPerBlock == 0 + ? 0 + : ElemsPerBlock - merged_filter_dims % ElemsPerBlock; + const auto desc = make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims)); + return transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(I1), + make_right_pad_transform(merged_filter_dims, padd_size)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + using CElementwiseGridDesc = remove_cvref_t; + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>; + using GridwiseElementwiseCast = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation, + ElementwiseBlockSize, + I1, + ElemsPerBlock, + I1, + ElemsPerBlock / ElementwiseBlockSize, + Sequence<0, 1>, + Sequence<1>, + Sequence<1>, + I1, + I1>; struct Argument : public BaseArgument { @@ -58,11 +101,11 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl WeiDataType* p_wei_grid, const OutDataType* p_out_grid, const std::array&, // input - const std::array&, + const std::array& b_g_n_c_wis_strides, const std::array& e_g_k_c_xs_lengths, // weight - const std::array&, + const std::array& e_g_k_c_xs_strides, const std::array& a_g_n_k_wos_lengths, // output - const std::array&, + const std::array& a_g_n_k_wos_strides, const std::array& conv_filter_strides, const std::array&, const std::array& input_left_pads, @@ -74,42 +117,114 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl : filter_spatial_lengths_{}, conv_filter_strides_{conv_filter_strides}, input_left_pads_{input_left_pads}, - input_right_pads_{input_right_pads} + input_right_pads_{input_right_pads}, + p_wei_grid_{p_wei_grid} { constexpr index_t spatial_offset = 3; - const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, + const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, end(a_g_n_k_wos_lengths), index_t{1}, std::multiplies<>{}); - const index_t M = e_g_k_c_xs_lengths[I1]; - const index_t N = e_g_k_c_xs_lengths[I2]; - const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo; - const index_t BatchSize = a_g_n_k_wos_lengths[I0]; + const index_t M = e_g_k_c_xs_lengths[I1]; + const index_t N = e_g_k_c_xs_lengths[I2]; + const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo; - explicit_gemm_args = GemmArgument{p_out_grid, - p_in_grid, - {}, - p_wei_grid, - M, - N, - K, - BatchSize * M, - BatchSize * N, - {}, - N, - M, - N, - {}, - M * N, - BatchSize, - out_element_op, - in_element_op, - wei_element_op, - split_k}; + const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1]; + const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1]; + const index_t StrideWei = e_g_k_c_xs_strides[I1]; + const index_t StrideBatchOut = a_g_n_k_wos_strides[I0]; + const index_t StrideBatchIn = b_g_n_c_wis_strides[I0]; + const index_t StrideBatchWei = e_g_k_c_xs_strides[I0]; + + const index_t BatchSize = a_g_n_k_wos_lengths[I0]; std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); + + if constexpr(IsTwoStageNeeded) + { + const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths), + end(e_g_k_c_xs_lengths), + index_t{1}, + std::multiplies<>{}); + elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims); + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{1, merged_filter_dims}; + // Check if stride to last dimension is product of all other dimensions. Then it is + // packed. + is_filter_data_packed = + e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]); + + // Data type is modified during launch. It is checked in IsSupported if user + // allocated workspace + explicit_gemm_args = GemmArgument{p_out_grid, + p_in_grid, + {}, + static_cast(nullptr), + M, + N, + K, + StrideOut, + StrideIn, + {}, + StrideWei, + StrideBatchOut, + StrideBatchIn, + {}, + StrideBatchWei, + BatchSize, + out_element_op, + in_element_op, + wei_element_op, + split_k}; + } + else + { + explicit_gemm_args = GemmArgument{p_out_grid, + p_in_grid, + {}, + p_wei_grid, + M, + N, + K, + StrideOut, + StrideIn, + {}, + StrideWei, + StrideBatchOut, + StrideBatchIn, + {}, + StrideBatchWei, + BatchSize, + out_element_op, + in_element_op, + wei_element_op, + split_k}; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(IsTwoStageNeeded) + { + return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize(); + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + if constexpr(IsTwoStageNeeded) + { + return GetWorkspaceETensorSizeBytes(); + } + else + { + return 0; + } } GemmArgument explicit_gemm_args; @@ -117,16 +232,56 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl const std::array& conv_filter_strides_; const std::array& input_left_pads_; const std::array& input_right_pads_; + WeiDataType* p_wei_grid_; + bool is_filter_data_packed; + CElementwiseGridDesc elementwise_desc_; + Block2TileMapElementwise elementwise_block_2_ctile_map_; }; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; + using Argument = DeviceOp::Argument; + using GemmArgument = typename DeviceGemmV3Op::Argument; float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config); + if constexpr(IsTwoStageNeeded) + { + // Modify to use workspace as output + GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args; + explicit_gemm_args_with_workspace.p_c_grid = + static_cast(arg.p_workspace_); + float avg_time = + explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_); + const auto kernel = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation>; + + avg_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(ElementwiseBlockSize), + 0, + make_tuple(arg.elementwise_desc_), + make_tuple(arg.elementwise_desc_), + make_tuple(static_cast(arg.p_workspace_)), + make_tuple(arg.p_wei_grid_), + arg.elementwise_block_2_ctile_map_, + element_wise::PassThrough{}); + return avg_time; + } + else + { + return explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config); + } } float Run(const BaseArgument* p_arg, @@ -174,6 +329,26 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl return false; } } + if constexpr(IsTwoStageNeeded) + { + if(!arg.is_filter_data_packed) + { + return false; + } + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + } // Gridwise GEMM size return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args); } @@ -277,6 +452,33 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl return str.str(); } + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } }; } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp index 1d291cca39..0c44ca6613 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp @@ -88,6 +88,97 @@ using device_gemm_xdl_universal_km_kn_mn_mem_instances = std::tuple< DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> // clang-format on >; + +template +using device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 32, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, 1, 1, S<1, 16, 1, 8>, S<2>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 4>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, 1, 1, S<1, 16, 1, 16>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_gemm_xdl_universal_km_kn_mn_odd_n_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + +template +using device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple< + // clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Latency friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 2, 2, 16, 16, 1, 1, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // Memory friendly + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 8, 2, 16, 16, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 16, 64, 2, 2, 16, 16, 4, 1, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 32, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 16, 64, 8, 4, 16, 16, 4, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 64, 16, 64, 4, 4, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 16, 16, 64, 4, 4, 16, 16, 1, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 4>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 64, 4, 4, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 64, 64, 4, 4, 16, 16, 1, 2, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 128, 64, 4, 4, 16, 16, 1, 4, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 8>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 4, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2>, + DeviceBatchedGemmMultiD_Xdl_CShuffle_V3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 64, 2, 2, 16, 16, 1, 4, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 0, 1, 1, S<1, 16, 1, 16>, S<1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2> + // clang-format on + >; + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp index a53a92e795..3c0784eef3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight.hpp @@ -397,24 +397,19 @@ struct DeviceOperationInstanceFactory>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances( - std::vector>>& instances); - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>>& instances); -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( + std::vector>>& instances); + +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_comp_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp index e11c9c68ad..0cf0b7f9e3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>( + device_gemm_xdl_universal_km_kn_mn_mem_instances>( instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances( std::vector>( + device_gemm_xdl_universal_km_kn_mn_mem_instances>( instances); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp deleted file mode 100644 index 109f42703a..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instance.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances( - std::vector>>& instances) -{ - add_explicit_gemm_device_operation_instances< - 2, - NHWGC, - GKYXC, - NHWGK, - BF16, - BF16, - BF16, - PassThrough, - PassThrough, - PassThrough, - device_gemm_xdl_universal_km_kn_mn_mem_instances>(instances); -} - -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_kpadding_instances( - std::vector>>& instances) -{ - add_explicit_gemm_device_operation_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - BF16, - BF16, - BF16, - PassThrough, - PassThrough, - PassThrough, - device_gemm_xdl_universal_km_kn_mn_mem_instances>(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp similarity index 95% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp index e7350ee6d4..1e280ed2bf 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>( + device_gemm_xdl_universal_km_kn_mn_mem_instances>( instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances( std::vector>( + device_gemm_xdl_universal_km_kn_mn_mem_instances>( instances); } diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp similarity index 77% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp index 1bed4ac5c4..a86efe9aa0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp similarity index 77% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp index 3cf9e00440..239664d1da 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp similarity index 85% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp index 8947235617..fe79c5c5dd 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/bf16_bf16_bf16/device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_odd_n_instances>( + instances); } -void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_odd_n_instances>( + instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp similarity index 96% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp index 4a564da6c9..f1d1c5d228 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_comp_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_comp_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp index 5bf4b27771..3fd121dca6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_mem_instances>( + instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_mem_instances>( + instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp deleted file mode 100644 index 7e478364d3..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instance.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_exp_gemm_xdl_universal_km_kn_mn_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances( - std::vector>>& instances) -{ - add_explicit_gemm_device_operation_instances< - 2, - NHWGC, - GKYXC, - NHWGK, - F16, - F16, - F16, - PassThrough, - PassThrough, - PassThrough, - device_gemm_xdl_universal_km_kn_mn_mem_instances>(instances); -} - -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mkpadding_instances( - std::vector>>& instances) -{ - add_explicit_gemm_device_operation_instances< - 3, - NDHWGC, - GKZYXC, - NDHWGK, - F16, - F16, - F16, - PassThrough, - PassThrough, - PassThrough, - device_gemm_xdl_universal_km_kn_mn_mem_instances>(instances); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp similarity index 94% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp index 1b176d8d24..acc6c5e2df 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_mem_instances>( + instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_mem_instances>( + instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp similarity index 77% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp index 05636b2438..e9732bb675 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mkpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_m_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp similarity index 77% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp index 0d8755a31a..aaf1000249 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_irregular_odd_mn_instances>(instances); } } // namespace instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp similarity index 85% rename from library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp index 174970fa12..1f9c8f3ca4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_convnd_bwd_weight/explicit_xdl/fp16_fp16_fp16/device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instance.cpp @@ -9,7 +9,7 @@ namespace tensor_operation { namespace device { namespace instance { -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_odd_n_instances>( + instances); } -void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_kpadding_instances( +void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances( std::vector>(instances); + device_gemm_xdl_universal_km_kn_mn_odd_n_instances>( + instances); } } // namespace instance