Handle type conversions to a const datatype (#944)

* Handle type conversions to a const datatype

* Review: Handle X being const data type as well

* Review: Remove typo

[ROCm/composable_kernel commit: f4af5aed8b]
This commit is contained in:
Bartlomiej Wroblewski
2023-09-27 22:02:42 +02:00
committed by GitHub
parent be5cb244c0
commit bf38d27453
3 changed files with 112 additions and 2 deletions

View File

@@ -9,8 +9,10 @@
namespace ck {
// Convert X to Y
template <typename Y, typename X>
// Convert X to Y, both X and Y are non-const data types.
template <typename Y,
typename X,
std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
@@ -18,6 +20,19 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(x);
}
// Convert X to Y, either X or Y is a const data type.
template <typename Y,
typename X,
std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
using NonConstY = std::remove_const_t<Y>;
using NonConstX = std::remove_const_t<X>;
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)