mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Split up data_type header. (#1996)
* split fp64 vector data type
* add missing header
* move e8m0 structs
* split off numeric_utils header
* fix typo
* split off numeric limits header
* update data_type header
* fix clang format
* split off vector type header
* fix clang format
* fix typo for binary_inf
[ROCm/composable_kernel commit: d2eab23958]
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
#define CK_AMD_INLINE_ASM_HPP
|
||||
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "dtype_vector.hpp"
|
||||
|
||||
// TODO: deprecate all amd_assembly_outer_product_xxx
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck/utility/dtype_fp64.hpp"
|
||||
|
||||
namespace ck {
|
||||
// Define the common macro for MI300 models
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
7
include/ck/utility/dtype_fp64.hpp
Normal file
7
include/ck/utility/dtype_fp64.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
namespace ck {
|
||||
// fp64
|
||||
using double2_t = typename vector_type<double, 2>::type;
|
||||
using double4_t = typename vector_type<double, 4>::type;
|
||||
} // namespace ck
|
||||
2152
include/ck/utility/dtype_vector.hpp
Normal file
2152
include/ck/utility/dtype_vector.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/numeric_utils.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
#include "data_type.hpp"
|
||||
#include "dtype_fp64.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "numeric_limits.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/numeric_limits.hpp"
|
||||
#include "ck/utility/mxfp_utils.hpp"
|
||||
|
||||
namespace ck::utils {
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/numeric_limits.hpp"
|
||||
#include "ck/utility/mxfp_utils.hpp"
|
||||
|
||||
namespace ck::utils {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
#include "ck/utility/data_type.hpp"
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/utility/numeric_limits.hpp"
|
||||
#include "ck/utility/mxfp_utils.hpp"
|
||||
|
||||
#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
|
||||
|
||||
555
include/ck/utility/numeric_limits.hpp
Normal file
555
include/ck/utility/numeric_limits.hpp
Normal file
@@ -0,0 +1,555 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC)
|
||||
template <typename T>
|
||||
struct NumericLimits;
|
||||
|
||||
template <>
|
||||
struct NumericLimits<int32_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Min() noexcept { return -2147483647 - 1; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
|
||||
};
|
||||
template <>
|
||||
struct NumericLimits<int16_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Min() noexcept { return -32768; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<int8_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Min() noexcept { return -128; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Max() noexcept { return 127; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<uint32_t>
|
||||
{
|
||||
__host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Min() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<uint16_t>
|
||||
{
|
||||
__host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Min() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }
|
||||
|
||||
__host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<float>
|
||||
{
|
||||
static constexpr unsigned int binary_min = 0x00800000;
|
||||
static constexpr unsigned int binary_max = 0x7F7FFFFF;
|
||||
static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
|
||||
static constexpr unsigned int binary_qnan = 0xFFC00001;
|
||||
static constexpr unsigned int binary_inf = 0x7F800000;
|
||||
|
||||
__host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }
|
||||
|
||||
__host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<half_t>
|
||||
{
|
||||
static constexpr unsigned short binary_min = 0x0400;
|
||||
static constexpr unsigned short binary_max = 0x7BFF;
|
||||
static constexpr unsigned short binary_lowest = 0xFBFF;
|
||||
static constexpr unsigned short binary_qnan = 0x7FFF;
|
||||
|
||||
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct NumericLimits<int4_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 8
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 7
|
||||
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 16
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 15
|
||||
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
|
||||
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
|
||||
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
|
||||
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
|
||||
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
|
||||
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
|
||||
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
|
||||
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f4_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
|
||||
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
|
||||
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
|
||||
|
||||
static constexpr float data_max_normal_number = 6;
|
||||
static constexpr float data_min_subnormal_number = 0.5;
|
||||
|
||||
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
|
||||
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
|
||||
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
|
||||
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
|
||||
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f6_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
|
||||
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
|
||||
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111
|
||||
|
||||
static constexpr float data_max_normal_number = 7.5;
|
||||
static constexpr float data_min_subnormal_number = 0.125;
|
||||
|
||||
__host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); }
|
||||
__host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); }
|
||||
__host__ __device__ static constexpr f6_t Lowest()
|
||||
{
|
||||
return f6_t(binary_lowest_normal & 0b111111);
|
||||
}
|
||||
__host__ __device__ static constexpr f6_t MinSubnorm()
|
||||
{
|
||||
return f6_t(binary_min_subnorm & 0b111111);
|
||||
}
|
||||
__host__ __device__ static constexpr f6_t MaxSubnorm()
|
||||
{
|
||||
return f6_t(binary_max_subnorm & 0b111111);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf6_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
|
||||
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
|
||||
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011
|
||||
|
||||
static constexpr float data_max_normal_number = 28;
|
||||
static constexpr float data_min_subnormal_number = 0.0625;
|
||||
|
||||
__host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }
|
||||
__host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }
|
||||
__host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }
|
||||
__host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }
|
||||
__host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
struct NumericLimits
|
||||
{
|
||||
__host__ __device__ static constexpr T Min() { return std::numeric_limits<T>::min(); }
|
||||
__host__ __device__ static constexpr T Max() { return std::numeric_limits<T>::max(); }
|
||||
__host__ __device__ static constexpr T Lowest() { return std::numeric_limits<T>::lowest(); }
|
||||
__host__ __device__ static constexpr T QuietNaN()
|
||||
{
|
||||
return std::numeric_limits<T>::quiet_NaN();
|
||||
}
|
||||
__host__ __device__ static constexpr T Infinity() { return std::numeric_limits<T>::infinity(); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<half_t>
|
||||
{
|
||||
static constexpr unsigned short binary_min = 0x0400;
|
||||
static constexpr unsigned short binary_max = 0x7BFF;
|
||||
static constexpr unsigned short binary_lowest = 0xFBFF;
|
||||
static constexpr unsigned short binary_qnan = 0x7FFF;
|
||||
|
||||
__host__ __device__ static constexpr half_t Min() { return bit_cast<half_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Max() { return bit_cast<half_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
|
||||
};
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
struct NumericLimits<int4_t>
|
||||
{
|
||||
__host__ __device__ static constexpr int4_t Min() { return int4_t(-8); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
|
||||
|
||||
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-8); }
|
||||
};
|
||||
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 8
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 7
|
||||
// static constexpr uint8_t binary_min = 0x08; // 0b00001000
|
||||
// static constexpr uint8_t binary_max = 0x77; // 0b01110111
|
||||
// static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=0
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Min() { return f8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Max() { return f8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t Lowest() { return f8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr f8_fnuz_t QuietNaN() { return f8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_fnuz_t>
|
||||
{
|
||||
// negative zero nan mode with exp bias = 16
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
static constexpr uint8_t binary_max = 0x7F; // 0b01111111
|
||||
static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
|
||||
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
|
||||
// ieee mode with exp bias = 15
|
||||
// static constexpr uint8_t binary_min = 0x04; // 0b00000100
|
||||
// static constexpr uint8_t binary_max = 0x7B; // 0b01111011
|
||||
// static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
|
||||
// static constexpr uint8_t binary_qnan = 0x79; // any sign, exp=1111, mant!=
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Min() { return bf8_fnuz_t(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Max() { return bf8_fnuz_t(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t Lowest() { return bf8_fnuz_t(binary_lowest); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_fnuz_t QuietNaN() { return bf8_fnuz_t(binary_qnan); }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x08; // 0b00001000 = 2^-6
|
||||
static constexpr uint8_t binary_max = 0x7E; // 0b01111110 = 448
|
||||
static constexpr uint8_t binary_lowest = 0xFE; // 0b11111110 = -448
|
||||
static constexpr uint8_t binary_qnan = 0x7F; // 0b01111111
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Min() { return bit_cast<f8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Max() { return bit_cast<f8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr f8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<f8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf8_ocp_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min = 0x04; // 0b00000100 = 2^-14
|
||||
static constexpr uint8_t binary_max = 0x7B; // 0b01111011 = 57344
|
||||
static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011 = -57344
|
||||
static constexpr uint8_t binary_qnan = 0x7D; // 0b01111101
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Min() { return bit_cast<bf8_ocp_t>(binary_min); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Max() { return bit_cast<bf8_ocp_t>(binary_max); }
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t Lowest()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_lowest);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bf8_ocp_t QuietNaN()
|
||||
{
|
||||
return bit_cast<bf8_ocp_t>(binary_qnan);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f4_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x2; // 0b0010
|
||||
static constexpr uint8_t binary_max_normal = 0x7; // 0b0111
|
||||
static constexpr uint8_t binary_lowest_normal = 0xF; // 0b1111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x1; // 0b0001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x1; // 0b0001
|
||||
|
||||
static constexpr float data_max_normal_number = 6;
|
||||
static constexpr float data_min_subnormal_number = 0.5;
|
||||
|
||||
__host__ __device__ static constexpr f4_t Min() { return f4_t(binary_min_normal); }
|
||||
__host__ __device__ static constexpr f4_t Max() { return f4_t(binary_max_normal); }
|
||||
__host__ __device__ static constexpr f4_t Lowest() { return f4_t(binary_lowest_normal); }
|
||||
__host__ __device__ static constexpr f4_t MinSubnorm() { return f4_t(binary_min_subnorm); }
|
||||
__host__ __device__ static constexpr f4_t MaxSubnorm() { return f4_t(binary_max_subnorm); }
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<f6_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
|
||||
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
|
||||
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x07; // 0b000111
|
||||
|
||||
static constexpr float data_max_normal_number = 7.5;
|
||||
static constexpr float data_min_subnormal_number = 0.125;
|
||||
|
||||
__host__ __device__ static constexpr f6_t Min() { return f6_t(binary_min_normal & 0b111111); }
|
||||
__host__ __device__ static constexpr f6_t Max() { return f6_t(binary_max_normal & 0b111111); }
|
||||
__host__ __device__ static constexpr f6_t Lowest()
|
||||
{
|
||||
return f6_t(binary_lowest_normal & 0b111111);
|
||||
}
|
||||
__host__ __device__ static constexpr f6_t MinSubnorm()
|
||||
{
|
||||
return f6_t(binary_min_subnorm & 0b111111);
|
||||
}
|
||||
__host__ __device__ static constexpr f6_t MaxSubnorm()
|
||||
{
|
||||
return f6_t(binary_max_subnorm & 0b111111);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericLimits<bf6_t>
|
||||
{
|
||||
static constexpr uint8_t binary_min_normal = 0x08; // 0b001000
|
||||
static constexpr uint8_t binary_max_normal = 0x1F; // 0b011111
|
||||
static constexpr uint8_t binary_lowest_normal = 0x3F; // 0b111111
|
||||
static constexpr uint8_t binary_min_subnorm = 0x01; // 0b000001
|
||||
static constexpr uint8_t binary_max_subnorm = 0x03; // 0b000011
|
||||
|
||||
static constexpr float data_max_normal_number = 28;
|
||||
static constexpr float data_min_subnormal_number = 0.0625;
|
||||
|
||||
__host__ __device__ static constexpr bf6_t Min() { return bf6_t(binary_min_normal); }
|
||||
__host__ __device__ static constexpr bf6_t Max() { return bf6_t(binary_max_normal); }
|
||||
__host__ __device__ static constexpr bf6_t Lowest() { return bf6_t(binary_lowest_normal); }
|
||||
__host__ __device__ static constexpr bf6_t MinSubnorm() { return bf6_t(binary_min_subnorm); }
|
||||
__host__ __device__ static constexpr bf6_t MaxSubnorm() { return bf6_t(binary_max_subnorm); }
|
||||
|
||||
__host__ __device__ static constexpr float DataMaxNorm() { return data_max_normal_number; }
|
||||
__host__ __device__ static constexpr float DataMinSubnorm()
|
||||
{
|
||||
return data_min_subnormal_number;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct NumericLimits<e8m0_bexp_t>
|
||||
{
|
||||
static constexpr e8m0_bexp_t binary_min = 0x00; // 0b00000000
|
||||
static constexpr e8m0_bexp_t binary_max = 0xFE; // 0b11111110
|
||||
static constexpr e8m0_bexp_t binary_qnan = 0xFF; // 0b11111111
|
||||
static constexpr e8m0_bexp_t binary_1 = 0x7F; // 0b01111111
|
||||
static constexpr e8m0_bexp_t binary_2 = 0x80; // 0b10000000
|
||||
static constexpr e8m0_bexp_t binary_3 = 0x82; // 0b10000010
|
||||
static constexpr e8m0_bexp_t binary_135 = 0x87; // 0b10000111
|
||||
static constexpr e8m0_bexp_t binary_142 = 0x8E; // 0b10001110
|
||||
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Min() { return e8m0_bexp_t(binary_min); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Max() { return e8m0_bexp_t(binary_max); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t QuietNaN() { return e8m0_bexp_t(binary_qnan); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Binary_1() { return e8m0_bexp_t(binary_1); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Binary_2() { return e8m0_bexp_t(binary_2); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Binary_3() { return e8m0_bexp_t(binary_3); }
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Binary_135()
|
||||
{
|
||||
return e8m0_bexp_t(binary_135);
|
||||
}
|
||||
__host__ __device__ static constexpr e8m0_bexp_t Binary_142()
|
||||
{
|
||||
return e8m0_bexp_t(binary_142);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
199
include/ck/utility/numeric_utils.hpp
Normal file
199
include/ck/utility/numeric_utils.hpp
Normal file
@@ -0,0 +1,199 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Primary template: intentionally empty. Per-type layout constants are supplied
// only by the explicit specializations.
template <typename T>
struct NumericUtils
{
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<e8m0_bexp_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 0;
|
||||
static constexpr int bias = 127;
|
||||
|
||||
static constexpr int unbiased_exp_min = -127;
|
||||
static constexpr int unbiased_exp_max = 127;
|
||||
static constexpr int biased_exp_min = 0;
|
||||
static constexpr int biased_exp_max = 254;
|
||||
|
||||
using bitwise_type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<float>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 23;
|
||||
static constexpr int bias = 127;
|
||||
static constexpr uint32_t nan_mask = 0x7F800000;
|
||||
static constexpr uint32_t head_mask = 0xFF800000;
|
||||
static constexpr uint32_t mant_mask = 0x7FFFFF;
|
||||
static constexpr uint32_t exp_mask = 0xFF;
|
||||
static constexpr uint32_t Inf = 0x7F800000;
|
||||
static constexpr uint32_t NegInf = 0xFF800000;
|
||||
static constexpr uint32_t NaN = 0x7F800001;
|
||||
static constexpr uint32_t Neg0 = 0x80000000;
|
||||
static constexpr bool has_inf = true;
|
||||
using bitwise_type = uint32_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<half_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 10;
|
||||
static constexpr int bias = 15;
|
||||
static constexpr uint16_t nan_mask = 0x7C00;
|
||||
static constexpr uint16_t head_mask = 0xFC00;
|
||||
static constexpr uint16_t mant_mask = 0x3FF;
|
||||
static constexpr uint16_t exp_mask = 0x1F;
|
||||
static constexpr uint32_t Inf = 0x7C00;
|
||||
static constexpr uint32_t NegInf = 0xFC00;
|
||||
static constexpr uint32_t NaN = 0x7C01;
|
||||
static constexpr uint32_t Neg0 = 0x8000;
|
||||
static constexpr bool has_inf = true;
|
||||
using bitwise_type = uint16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bhalf_t>
|
||||
{
|
||||
static constexpr int exp = 8;
|
||||
static constexpr int mant = 7;
|
||||
static constexpr int bias = 128; // negative zero nan mode
|
||||
// static constexpr int bias = 127; // ieee mode
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<f8_fnuz_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
static constexpr int bias = 8; // negative zero nan mode
|
||||
// static constexpr int bias = 7; // ieee mode
|
||||
static constexpr bool has_inf = false;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bf8_fnuz_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 16; // negative zero nan mode
|
||||
// static constexpr int bias = 15; // ieee mode
|
||||
static constexpr bool has_inf = false;
|
||||
};
|
||||
template <>
|
||||
struct NumericUtils<f8_ocp_t>
|
||||
{
|
||||
static constexpr int exp = 4;
|
||||
static constexpr int mant = 3;
|
||||
static constexpr int bias = 7;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bf8_ocp_t>
|
||||
{
|
||||
static constexpr int exp = 5;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 15;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<f4_t>
|
||||
{
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 1;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr uint32_t sr_shift = 10;
|
||||
|
||||
static constexpr int unbiased_exp_min = 0;
|
||||
static constexpr int unbiased_exp_max = 2;
|
||||
static constexpr int biased_exp_min = 1;
|
||||
static constexpr int biased_exp_max = 3;
|
||||
|
||||
static constexpr uint8_t positive_zero_mask = 0b0000;
|
||||
static constexpr uint8_t negative_zero_mask = 0b1000;
|
||||
|
||||
static constexpr uint8_t one_mask = 0b0010;
|
||||
static constexpr uint8_t set_sign_mask = 0b0111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_normal_mask = 0b0111;
|
||||
static constexpr uint8_t data_max_negative_normal_mask = 0b1111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_subnormal_mask = 0b0001;
|
||||
static constexpr uint8_t data_max_negative_subnormal_mask = 0b1001;
|
||||
|
||||
static constexpr bool has_inf = false;
|
||||
|
||||
using bitwise_type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<f6_t>
|
||||
{
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 3;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr uint32_t sr_shift = 12;
|
||||
|
||||
static constexpr int unbiased_exp_min = 0;
|
||||
static constexpr int unbiased_exp_max = 2;
|
||||
static constexpr int biased_exp_min = 1;
|
||||
static constexpr int biased_exp_max = 3;
|
||||
|
||||
static constexpr uint8_t positive_zero_mask = 0b000000;
|
||||
static constexpr uint8_t negative_zero_mask = 0b100000;
|
||||
|
||||
static constexpr uint8_t set_sign_mask = 0b011111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_normal_mask = 0b011111;
|
||||
static constexpr uint8_t data_max_negative_normal_mask = 0b111111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_subnormal_mask = 0b000111;
|
||||
static constexpr uint8_t data_max_negative_subnormal_mask = 0b100111;
|
||||
|
||||
static constexpr bool has_inf = false;
|
||||
static constexpr bool has_nan = false;
|
||||
static constexpr bool has_zero = true;
|
||||
|
||||
using bitwise_type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumericUtils<bf6_t>
|
||||
{
|
||||
static constexpr int exp = 3;
|
||||
static constexpr int mant = 2;
|
||||
static constexpr int bias = 3;
|
||||
static constexpr uint32_t sr_shift = 11;
|
||||
|
||||
static constexpr int unbiased_exp_min = -2;
|
||||
static constexpr int unbiased_exp_max = 4;
|
||||
static constexpr int biased_exp_min = 1;
|
||||
static constexpr int biased_exp_max = 7;
|
||||
|
||||
static constexpr uint8_t positive_zero_mask = 0b000000;
|
||||
static constexpr uint8_t negative_zero_mask = 0b100000;
|
||||
|
||||
static constexpr uint8_t set_sign_mask = 0b011111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_normal_mask = 0b011111;
|
||||
static constexpr uint8_t data_max_negative_normal_mask = 0b111111;
|
||||
|
||||
static constexpr uint8_t data_max_positive_subnormal_mask = 0b000011;
|
||||
static constexpr uint8_t data_max_negative_subnormal_mask = 0b100011;
|
||||
|
||||
static constexpr bool has_inf = false;
|
||||
static constexpr bool has_nan = false;
|
||||
static constexpr bool has_zero = true;
|
||||
|
||||
using bitwise_type = uint8_t;
|
||||
};
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user