mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
Handle type conversions to a const datatype (#944)
* Handle type conversions to a const datatype
* Review: Handle X being const data type as well
* Review: Remove typo
[ROCm/composable_kernel commit: f4af5aed8b]
This commit is contained in:
committed by
GitHub
parent
9bc92adde3
commit
b50a087d91
@@ -9,8 +9,10 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Convert X to Y
|
||||
template <typename Y, typename X>
|
||||
// Convert X to Y, both X and Y are non-const data types.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
|
||||
return static_cast<Y>(x);
|
||||
}
|
||||
|
||||
// Convert X to Y, either X or Y is a const data type.
|
||||
template <typename Y,
|
||||
typename X,
|
||||
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
|
||||
__host__ __device__ constexpr Y type_convert(X x)
|
||||
{
|
||||
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
|
||||
|
||||
using NonConstY = std::remove_const_t<Y>;
|
||||
using NonConstX = std::remove_const_t<X>;
|
||||
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
|
||||
}
|
||||
|
||||
// convert bfp16 to fp32
|
||||
template <>
|
||||
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
|
||||
|
||||
@@ -13,3 +13,5 @@ add_gtest_executable(test_bf8 bf8.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_bf8 PRIVATE utility)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_type_convert_const type_convert_const.cpp)
|
||||
|
||||
93
test/data_type/type_convert_const.cpp
Normal file
93
test/data_type/type_convert_const.cpp
Normal file
@@ -0,0 +1,93 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bhalf_t;
|
||||
using ck::type_convert;
|
||||
|
||||
TEST(TypeConvertConst, ConvertToConst)
|
||||
{
|
||||
constexpr float bf16_epsilon = 0.0078125;
|
||||
constexpr float rel_tol = 2 * bf16_epsilon;
|
||||
|
||||
const std::vector<float> cases = {0.0, -123.f, 3.981323f, 0.2429f};
|
||||
|
||||
for(float x : cases)
|
||||
{
|
||||
const float abs_tol = std::abs(rel_tol * x);
|
||||
{
|
||||
bhalf_t y = type_convert<bhalf_t>(x);
|
||||
// Test non-const bhalf to const float.
|
||||
const float y_float = type_convert<const float>(y);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
{
|
||||
// Test non-const float to const bhalf.
|
||||
const bhalf_t y = type_convert<const bhalf_t>(x);
|
||||
// Remove the constness manually to not rely on const casts anymore since the
|
||||
// possible issue could hide after two casts.
|
||||
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
|
||||
float y_float = type_convert<float>(y_nonconst);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TypeConvertConst, ConvertFromConst)
|
||||
{
|
||||
constexpr float bf16_epsilon = 0.0078125;
|
||||
constexpr float rel_tol = 2 * bf16_epsilon;
|
||||
|
||||
const std::vector<float> cases = {0.0, -123.f, 3.981323f, 0.2429f};
|
||||
|
||||
for(const float x : cases)
|
||||
{
|
||||
const float abs_tol = std::abs(rel_tol * x);
|
||||
{
|
||||
// Test const float to const bhalf_t.
|
||||
const bhalf_t y = type_convert<const bhalf_t>(x);
|
||||
// Remove the constness manually to not rely on const casts anymore since the
|
||||
// possible issue could hide after two casts.
|
||||
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
|
||||
float y_float = type_convert<float>(y_nonconst);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
{
|
||||
// Test const float to non-const bhalf.
|
||||
bhalf_t y = type_convert<bhalf_t>(x);
|
||||
float y_float = type_convert<float>(y);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
{
|
||||
const bhalf_t y = type_convert<const bhalf_t>(x);
|
||||
// Test const bhalf to non-const float.
|
||||
float y_float = type_convert<float>(y);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
// Tests with full type specializations for X.
|
||||
{
|
||||
// Test const float to const bhalf_t.
|
||||
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
|
||||
// Remove the constness manually to not rely on const casts anymore since the
|
||||
// possible issue could hide after two casts.
|
||||
bhalf_t& y_nonconst = const_cast<bhalf_t&>(y);
|
||||
float y_float = type_convert<float>(y_nonconst);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
{
|
||||
// Test const float to non-const bhalf.
|
||||
bhalf_t y = type_convert<bhalf_t, const float>(x);
|
||||
float y_float = type_convert<float>(y);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
{
|
||||
const bhalf_t y = type_convert<const bhalf_t, const float>(x);
|
||||
// Test const bhalf to non-const float.
|
||||
float y_float = type_convert<float, const bhalf_t>(y);
|
||||
ASSERT_NEAR(y_float, x, abs_tol);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user