From d56cc51f7dfb2f5184b07b40545f4534c131208b Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 30 May 2022 19:57:49 -0500 Subject: [PATCH] Minor fix for recent PR (#260) * fix example * update IsSupportedArgument * fix * disable fp64 conv example as test [ROCm/composable_kernel commit: 85fc91c3218c1d85169ed1fe95eef7b07942e648] --- example/01_gemm/CMakeLists.txt | 3 ++- example/01_gemm/gemm_dl_fp16.cpp | 4 +-- example/01_gemm/gemm_dl_fp32.cpp | 4 +-- example/01_gemm/gemm_dl_int8.cpp | 4 +-- example/01_gemm/gemm_xdl_bf16.cpp | 6 ++--- example/01_gemm/gemm_xdl_fp16.cpp | 6 ++--- example/01_gemm/gemm_xdl_fp64.cpp | 10 +++---- example/01_gemm/gemm_xdl_int8.cpp | 6 ++--- example/09_convnd_fwd/CMakeLists.txt | 3 ++- .../device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp | 26 ++++++++++++++++--- .../gpu/device/device_gemm_dl.hpp | 2 +- .../gpu/device/device_gemm_xdl.hpp | 20 ++++++++++++-- 12 files changed, 62 insertions(+), 32 deletions(-) diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e458026c82..c03c454c68 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -4,4 +4,5 @@ add_example_executable(example_gemm_dl_int8 gemm_dl_int8.cpp) add_example_executable(example_gemm_xdl_fp16 gemm_xdl_fp16.cpp) add_example_executable(example_gemm_xdl_bf16 gemm_xdl_bf16.cpp) add_example_executable(example_gemm_xdl_int8 gemm_xdl_int8.cpp) -add_example_executable(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp) diff --git a/example/01_gemm/gemm_dl_fp16.cpp b/example/01_gemm/gemm_dl_fp16.cpp index 63d96a8e99..9a22628777 100644 --- a/example/01_gemm/gemm_dl_fp16.cpp +++ b/example/01_gemm/gemm_dl_fp16.cpp @@ -170,9 +170,7 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - std::cout << "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem" - << std::endl; + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; return 0; } diff --git a/example/01_gemm/gemm_dl_fp32.cpp b/example/01_gemm/gemm_dl_fp32.cpp index 20ca1a4d3d..32b183a3a1 100644 --- a/example/01_gemm/gemm_dl_fp32.cpp +++ b/example/01_gemm/gemm_dl_fp32.cpp @@ -169,9 +169,7 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - std::cout << "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem" - << std::endl; + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; return 0; } diff --git a/example/01_gemm/gemm_dl_int8.cpp b/example/01_gemm/gemm_dl_int8.cpp index caedb22537..16c9213104 100644 --- a/example/01_gemm/gemm_dl_int8.cpp +++ b/example/01_gemm/gemm_dl_int8.cpp @@ -167,9 +167,7 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - std::cout << "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem" - << std::endl; + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; return 0; } diff --git a/example/01_gemm/gemm_xdl_bf16.cpp b/example/01_gemm/gemm_xdl_bf16.cpp index 5bbfe96994..b126736be6 100644 --- a/example/01_gemm/gemm_xdl_bf16.cpp +++ b/example/01_gemm/gemm_xdl_bf16.cpp @@ -193,9 +193,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_xdl_fp16.cpp b/example/01_gemm/gemm_xdl_fp16.cpp index a17e64f174..003534f79a 100644 --- a/example/01_gemm/gemm_xdl_fp16.cpp +++ b/example/01_gemm/gemm_xdl_fp16.cpp @@ -166,9 +166,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/01_gemm/gemm_xdl_fp64.cpp b/example/01_gemm/gemm_xdl_fp64.cpp index 150d547264..7cea68c8b0 100644 --- a/example/01_gemm/gemm_xdl_fp64.cpp +++ b/example/01_gemm/gemm_xdl_fp64.cpp @@ -21,8 +21,6 @@ template using S = ck::Sequence; using F64 = double; -using F32 = float; -using F16 = ck::half_t; using ADataType = double; using BDataType = double; @@ -195,9 +193,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); @@ -233,7 +231,7 @@ int main(int argc, char* argv[]) show_2d_matrix(std::cout << "c_host :", c_m_n_host_result) << std::endl; } #endif - ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); + return ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData) ? 0 : 1; } return 0; diff --git a/example/01_gemm/gemm_xdl_int8.cpp b/example/01_gemm/gemm_xdl_int8.cpp index 094a12e4e7..27fcd62a2c 100644 --- a/example/01_gemm/gemm_xdl_int8.cpp +++ b/example/01_gemm/gemm_xdl_int8.cpp @@ -194,9 +194,9 @@ int main(int argc, char* argv[]) if(!gemm.IsSupportedArgument(argument)) { - throw std::runtime_error( - "wrong! device_gemm with the specified compilation parameters does " - "not support this GEMM problem"); + std::cout << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return 0; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index bb3c31abf2..1724e51f3f 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -1,7 +1,8 @@ add_example_executable(example_convnd_fwd_xdl_fp32 convnd_fwd_xdl_fp32.cpp) add_example_executable(example_convnd_fwd_xdl_int8 convnd_fwd_xdl_int8.cpp) add_example_executable(example_convnd_fwd_xdl_fp16 convnd_fwd_xdl_fp16.cpp) -add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) +# FIXME: re-enable this exampe as test when SWDEV-335738 is fixed +add_example_executable_no_testing(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) target_link_libraries(example_convnd_fwd_xdl_fp64 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_fp32 PRIVATE conv_util) target_link_libraries(example_convnd_fwd_xdl_int8 PRIVATE conv_util) diff --git a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp index 1678f9991e..c1ab44a28b 100644 --- a/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -1,5 +1,4 @@ -#ifndef DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP -#define DEVICE_CONVND_FWD_XDL_NHWC_KYXC_NHWK_HPP +#pragma once #include #include @@ -8,6 +7,7 @@ #include #include "device.hpp" +#include "device_prop.hpp" #include "device_base.hpp" #include "device_conv_fwd.hpp" #include "convolution_forward_specialization.hpp" @@ -858,6 +858,27 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K static bool IsSupportedArgument(const Argument& arg) { + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + // Input tensors can't be bigger than 2GB each. constexpr ck::long_index_t GB2 = (ck::long_index_t{1} << 31); @@ -1021,4 +1042,3 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K } // namespace device } // namespace tensor_operation } // namespace ck -#endif diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp index a6a059df77..8cd678fc1e 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp @@ -4,6 +4,7 @@ #include #include "device.hpp" +#include "device_prop.hpp" #include "device_base.hpp" #include "device_gemm.hpp" #include "common_header.hpp" @@ -13,7 +14,6 @@ #include "gemm_specialization.hpp" #include "element_wise_operation.hpp" #include "gridwise_gemm_dl_v1r3.hpp" -#include "device_prop.hpp" namespace ck { namespace tensor_operation { diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp index 31f354358f..3a8e1390e4 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp @@ -3,6 +3,7 @@ #include #include #include "device.hpp" +#include "device_prop.hpp" #include "device_base.hpp" #include "device_gemm.hpp" #include "common_header.hpp" @@ -11,7 +12,6 @@ #include "tensor_descriptor_helper.hpp" #include "gridwise_gemm_xdlops_v2r3.hpp" #include "gemm_specialization.hpp" -#include "device_prop.hpp" namespace ck { namespace tensor_operation { @@ -408,7 +408,23 @@ struct DeviceGemmXdl static bool IsSupportedArgument(const Argument& arg) { - if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a")) + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::get_device_name() == "gfx90a") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else { return false; }