From 99024ff3715906e150c8fe2f7498f982d5cfc32a Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 26 Sep 2023 08:39:11 -0700 Subject: [PATCH] Resolve some data type issues and cmake policy. (#940) * split the types in gemm_bilinear instances, add condition to cmake policy * fix syntax * split the data types in batchnorm examples * fix the batchnorm_bwd test * fix types in the batchnorm_bwd test [ROCm/composable_kernel commit: 2ea75bd6d7f0a061f78e4bb007a840a394a74ba9] --- CMakeLists.txt | 5 ++++- .../gpu/gemm_bilinear.hpp | 20 ++++++++++-------- test/batchnorm/batchnorm_bwd_rank_4.cpp | 21 +++++++++++++++---- test/batchnorm/batchnorm_fwd_rank_4.cpp | 21 +++++++++++++++---- test/batchnorm/batchnorm_infer_rank_4.cpp | 21 +++++++++++++++---- 5 files changed, 66 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 95ef2df7e9..e5c82b9705 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,8 @@ cmake_minimum_required(VERSION 3.14) -cmake_policy(SET CMP0140 NEW) +if(POLICY CMP0140) + # policy CMP0140 is not known to CMake before 3.25 + cmake_policy(SET CMP0140 NEW) +endif() # This has to be initialized before the project() command appears # Set the default of CMAKE_BUILD_TYPE to be release, unless user specifies with -D. 
MSVC_IDE does not use CMAKE_BUILD_TYPE diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp index 387499e584..1a518a5302 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_bilinear.hpp @@ -11,12 +11,12 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" -#ifdef CK_ENABLE_FP16 + namespace ck { namespace tensor_operation { namespace device { namespace instance { - +#ifdef CK_ENABLE_FP16 void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances( std::vector>>& instances); - +#endif +#ifdef CK_ENABLE_INT8 void add_device_gemm_bilinear_wmma_c_shuffle_i8_i8_i8_i8_mk_kn_mn_mn_instances( std::vector>>& instances); - +#endif // GEMM + Bilinear template > op_ptrs; - +#ifdef CK_ENABLE_FP16 if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) { @@ -187,8 +188,10 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) +#endif +#ifdef CK_ENABLE_INT8 + if constexpr(is_same_v && is_same_v && + is_same_v && is_same_v) { if constexpr(is_same_v && is_same_v && is_same_v && is_same_v) @@ -211,7 +214,7 @@ struct DeviceOperationInstanceFactory, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormBwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_fwd_rank_4.cpp b/test/batchnorm/batchnorm_fwd_rank_4.cpp index 9b6fbd0f66..6bf635f0cd 100644 --- a/test/batchnorm/batchnorm_fwd_rank_4.cpp +++ b/test/batchnorm/batchnorm_fwd_rank_4.cpp @@ -87,10 +87,23 @@ 
class TestBatchNormFwdRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormFwdRank4, KernelTypes); diff --git a/test/batchnorm/batchnorm_infer_rank_4.cpp b/test/batchnorm/batchnorm_infer_rank_4.cpp index ecb4043b36..0165192acf 100644 --- a/test/batchnorm/batchnorm_infer_rank_4.cpp +++ b/test/batchnorm/batchnorm_infer_rank_4.cpp @@ -67,10 +67,23 @@ class TestBatchNormInferRank4 : public ::testing::Test } }; -using KernelTypes = ::testing::Types, - std::tuple, - std::tuple, - std::tuple>; +using KernelTypes = ::testing::Types< +#ifdef CK_ENABLE_FP16 + std::tuple +#endif +#ifdef CK_ENABLE_FP32 + , + std::tuple +#endif +#ifdef CK_ENABLE_BF16 + , + std::tuple +#endif +#ifdef CK_ENABLE_FP64 + , + std::tuple +#endif + >; TYPED_TEST_SUITE(TestBatchNormInferRank4, KernelTypes);