diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 70619ee0a5..935926070d 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -9,8 +9,10 @@ namespace ck { -// Convert X to Y -template +// Convert X to Y, both X and Y are non-const data types. +template || std::is_const_v), bool> = false> __host__ __device__ constexpr Y type_convert(X x) { static_assert(!std::is_reference_v && !std::is_reference_v); @@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x) return static_cast(x); } +// Convert X to Y, either X or Y is a const data type. +template || std::is_const_v, bool> = false> +__host__ __device__ constexpr Y type_convert(X x) +{ + static_assert(!std::is_reference_v && !std::is_reference_v); + + using NonConstY = std::remove_const_t; + using NonConstX = std::remove_const_t; + return static_cast(type_convert(x)); +} + // convert bfp16 to fp32 template <> inline __host__ __device__ constexpr float type_convert(bhalf_t x) diff --git a/test/data_type/CMakeLists.txt b/test/data_type/CMakeLists.txt index b1606d2a75..2409ca05c2 100644 --- a/test/data_type/CMakeLists.txt +++ b/test/data_type/CMakeLists.txt @@ -13,3 +13,5 @@ add_gtest_executable(test_bf8 bf8.cpp) if(result EQUAL 0) target_link_libraries(test_bf8 PRIVATE utility) endif() + +add_gtest_executable(test_type_convert_const type_convert_const.cpp) diff --git a/test/data_type/type_convert_const.cpp b/test/data_type/type_convert_const.cpp new file mode 100644 index 0000000000..8b9c34861a --- /dev/null +++ b/test/data_type/type_convert_const.cpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type_convert.hpp" + +using ck::bhalf_t; +using ck::type_convert; + +TEST(TypeConvertConst, ConvertToConst) +{ + constexpr float bf16_epsilon = 0.0078125; + constexpr float rel_tol = 2 * bf16_epsilon; + + const std::vector cases = {0.0, -123.f, 3.981323f, 0.2429f}; + + for(float x : cases) + { + const float abs_tol = std::abs(rel_tol * x); + { + bhalf_t y = type_convert(x); + // Test non-const bhalf to const float. + const float y_float = type_convert(y); + ASSERT_NEAR(y_float, x, abs_tol); + } + { + // Test non-const float to const bhalf. + const bhalf_t y = type_convert(x); + // Remove the constness manually to not rely on const casts anymore since the + // possible issue could hide after two casts. + bhalf_t& y_nonconst = const_cast(y); + float y_float = type_convert(y_nonconst); + ASSERT_NEAR(y_float, x, abs_tol); + } + } +} + +TEST(TypeConvertConst, ConvertFromConst) +{ + constexpr float bf16_epsilon = 0.0078125; + constexpr float rel_tol = 2 * bf16_epsilon; + + const std::vector cases = {0.0, -123.f, 3.981323f, 0.2429f}; + + for(const float x : cases) + { + const float abs_tol = std::abs(rel_tol * x); + { + // Test const float to const bhalf_t. + const bhalf_t y = type_convert(x); + // Remove the constness manually to not rely on const casts anymore since the + // possible issue could hide after two casts. + bhalf_t& y_nonconst = const_cast(y); + float y_float = type_convert(y_nonconst); + ASSERT_NEAR(y_float, x, abs_tol); + } + { + // Test const float to non-const bhalf. + bhalf_t y = type_convert(x); + float y_float = type_convert(y); + ASSERT_NEAR(y_float, x, abs_tol); + } + { + const bhalf_t y = type_convert(x); + // Test const bhalf to non-const float. + float y_float = type_convert(y); + ASSERT_NEAR(y_float, x, abs_tol); + } + // Tests with full type specializations for X. + { + // Test const float to const bhalf_t. + const bhalf_t y = type_convert(x); + // Remove the constness manually to not rely on const casts anymore since the + // possible issue could hide after two casts. + bhalf_t& y_nonconst = const_cast(y); + float y_float = type_convert(y_nonconst); + ASSERT_NEAR(y_float, x, abs_tol); + } + { + // Test const float to non-const bhalf. + bhalf_t y = type_convert(x); + float y_float = type_convert(y); + ASSERT_NEAR(y_float, x, abs_tol); + } + { + const bhalf_t y = type_convert(x); + // Test const bhalf to non-const float. + float y_float = type_convert(y); + ASSERT_NEAR(y_float, x, abs_tol); + } + } +}