diff --git a/CMakeLists.txt b/CMakeLists.txt index 95ef2df7e9..e5c82b9705 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,8 @@ cmake_minimum_required(VERSION 3.14) -cmake_policy(SET CMP0140 NEW) +if(POLICY CMP0140) + # policies CMP0140 not known to CMake until 3.25 + cmake_policy(SET CMP0140 NEW) +endif() # This has to be initialized before the project() command appears # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. MSVC_IDE does not use CMAKE_BUILD_TYPE diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 387499e584..1a518a5302 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -11,12 +11,12 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#ifdef CK_ENABLE_FP16 + namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector>>& instances); - +#endif // GEMM + Bilinear template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp index 9b6fbd0f66..6bf635f0cd 100644 --- a/test/batchnorm/batchnorm_fwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -87,10 +87,23 @@ class TestBatchNormFwdRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_infer_rank_4.cpp b/test/batchnorm/batchnorm_infer_rank_4.cpp index ecb4043b36..0165192acf 100644 --- a/test/batchnorm/batchnorm_infer_rank_4.cpp +++ b/test/batchnorm/batchnorm_infer_rank_4.cpp @@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes);