From fc7756d43780294ebc6d1b4324bc33c282dbc5d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 26 Jul 2023 16:19:55 +0200 Subject: [PATCH] Disable XDL kernels on unsupported HW Add ck::is_xdl_supported (#768) * Disable XDL kernels on unsupported HW; Add ck::is_xdl_supported function (#765) * Do not throw an error when GEMM problem is not supported. --------- Co-authored-by: Bartlomiej Wroblewski Co-authored-by: Adam Osewski Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: ac6d68b3536fe9f7e563e4903733b377aaa9013f] --- example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp | 6 +++--- include/ck/host_utility/device_prop.hpp | 7 +++++++ .../device_batched_contraction_multiple_d_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp | 5 +++++ .../device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp | 4 +--- ...atched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp | 4 +--- .../impl/device_batched_gemm_reduce_xdl_cshuffle.hpp | 5 +++++ ...vice_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 4 +--- .../impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_batched_gemm_xdl.hpp | 5 +++++ .../gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp | 5 +++++ .../impl/device_contraction_multiple_d_xdl_cshuffle.hpp | 4 +--- ...conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 5 +++++ .../impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp | 5 +++++ ...wd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp | 5 +++++ ...2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp | 5 +++++ .../device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp | 5 +++++ .../device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp | 5 +++++ .../impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp | 5 +++++ .../device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp | 5 +++++ .../impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp | 5 +++++ .../impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp | 4 +--- .../device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp | 4 +--- .../device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp | 5 +++++ .../gpu/device/impl/device_gemm_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp | 5 +++++ .../gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp | 5 +++++ .../device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp | 4 +--- .../device_grouped_contraction_multiple_d_xdl_cshuffle.hpp | 4 +--- ...ce_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 5 +++++ ...vice_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp | 4 +--- .../gpu/device/impl/device_grouped_gemm_xdl.hpp | 5 +++++ .../impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp | 5 +++++ .../device_splitk_contraction_multiple_d_xdl_cshuffle.hpp | 4 +--- 37 files changed, 121 insertions(+), 51 deletions(-) diff --git a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp index 3afd0ebdb9..4a0c23cf44 100644 --- a/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp +++ b/example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp @@ -204,9 +204,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index bd02d5d88a..be1dbc1657 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -51,4 +51,11 @@ inline std::string get_device_name() return name; } +inline bool is_xdl_supported() +{ + return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || + ck::get_device_name() == "gfx942"; +} + } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index 46e71240c1..3270325069 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -840,9 +840,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942")) + if(!ck::is_xdl_supported()) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index fc080df5f5..8f46e0c498 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -571,6 +571,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute{}, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp index 7012584aab..3dbe8c6722 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -589,9 +589,7 @@ struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm(arg.gemm_desc_kernel_arg_.size()) + arg.skipped_group_count_) != arg.group_count_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp index 74f38b9db2..4f4413b78f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -502,6 +502,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK(arg.gemm_kernel_args_.size()) + arg.skipped_group_count_) != arg.group_count_) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp index f849ac799d..7ba2d96a9c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -939,9 +939,7 @@ struct DeviceSplitKContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || - ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" || - ck::get_device_name() == "gfx942")) + if(!ck::is_xdl_supported()) { return false; }