mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Add pk_fp4 data type (#2422)
* [draft] Add pk_fp4 and test * Add hw conversion for fp4 * Refine test code and pk_fp4 constructor. * fix test indent * modify according to comment. * fix clang-format * modify according comments. --------- Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
213
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
213
include/ck_tile/core/numeric/mxfp_convert.hpp
Normal file
@@ -0,0 +1,213 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
// modify from include/ck/utility/mxfp_utils.hpp
|
||||
|
||||
template <typename T>
|
||||
struct numeric_utils : numeric_traits<T>
|
||||
{
|
||||
|
||||
using traits = numeric_traits<T>;
|
||||
using _numeric = numeric<T>;
|
||||
using raw_type = typename T::raw_type;
|
||||
|
||||
static constexpr int exp_mask = (1 << traits::exp) - 1;
|
||||
|
||||
static constexpr int get_exponent(raw_type x)
|
||||
{
|
||||
// TODO: check if repeated calls are optimized.
|
||||
return (x >> traits::mant) & exp_mask;
|
||||
}
|
||||
static constexpr bool is_positive(raw_type x)
|
||||
{
|
||||
return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
|
||||
}
|
||||
static constexpr bool is_subnormal(raw_type x)
|
||||
{
|
||||
return get_exponent(x) == _numeric::binary_zero;
|
||||
}
|
||||
// TODO: replace double with template arg?
|
||||
static constexpr double get_mantissa(raw_type x)
|
||||
{
|
||||
double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
|
||||
for(uint32_t i = 0; i < traits::mant; ++i)
|
||||
{
|
||||
mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
|
||||
x >>= 1;
|
||||
}
|
||||
return mantissa;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
|
||||
{
|
||||
using utils = numeric_utils<T>;
|
||||
static constexpr int e8m0_bias = 127; // TODO: make it generic.
|
||||
float sign = utils::is_positive(data) ? 1.0 : -1.0;
|
||||
int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
|
||||
float mant = utils::get_mantissa(data);
|
||||
|
||||
return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
|
||||
{
|
||||
using bitwise_type = typename numeric_traits<T>::bitwise_type;
|
||||
|
||||
if(std::abs(value) > float(numeric<T>::max()))
|
||||
{
|
||||
float max_value = numeric<T>::max();
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
|
||||
|
||||
// cppcheck-suppress redundantAssignment
|
||||
bitwise_type sign =
|
||||
bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
|
||||
bitwise_type exp =
|
||||
((max_bitwise >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask) -
|
||||
(numeric_traits<float>::bias - numeric_traits<T>::bias);
|
||||
bitwise_type mantissa =
|
||||
max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
|
||||
uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
|
||||
mant_prev--;
|
||||
|
||||
mant_prev <<= (numeric_traits<float>::mant - numeric_traits<T>::mant);
|
||||
uint32_t prev_bit =
|
||||
((max_bitwise >> numeric_traits<float>::mant) << numeric_traits<float>::mant) |
|
||||
mant_prev;
|
||||
|
||||
float prev_val = bit_cast<float>(prev_bit);
|
||||
float diff = max_value - prev_val;
|
||||
|
||||
float actual_max = max_value + (diff / 2);
|
||||
|
||||
if(std::abs(value) < actual_max)
|
||||
{
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(!numeric<T>::has_inf())
|
||||
{
|
||||
|
||||
return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
exp++;
|
||||
return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(exp << numeric_traits<T>::mant);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int mfmt = numeric_traits<float>::mant;
|
||||
uint32_t x;
|
||||
x = bit_cast<uint32_t>(value);
|
||||
|
||||
uint32_t head, mantissa;
|
||||
int32_t exponent, bias;
|
||||
uint32_t sign;
|
||||
|
||||
head = x & numeric_traits<float>::head_mask;
|
||||
mantissa = x & numeric_traits<float>::mant_mask;
|
||||
exponent = (head >> numeric_traits<float>::mant) & numeric_traits<float>::exp_mask;
|
||||
sign = head >> (numeric_traits<float>::mant + numeric_traits<float>::exp);
|
||||
bias = numeric_traits<float>::bias;
|
||||
|
||||
if(x == 0)
|
||||
{
|
||||
return 0b0;
|
||||
}
|
||||
|
||||
const int mini_bias = numeric_traits<T>::bias;
|
||||
const int mini_denormal_act_exponent = 1 - mini_bias;
|
||||
|
||||
int act_exponent, out_exponent, exponent_diff;
|
||||
|
||||
bool is_subnorm = false;
|
||||
|
||||
if(exponent == 0)
|
||||
{
|
||||
act_exponent = exponent - bias + 1;
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
act_exponent = exponent - bias;
|
||||
if(act_exponent <= mini_denormal_act_exponent)
|
||||
{
|
||||
exponent_diff = mini_denormal_act_exponent - act_exponent;
|
||||
is_subnorm = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
exponent_diff = 0;
|
||||
}
|
||||
mantissa += (1UL << mfmt);
|
||||
}
|
||||
|
||||
auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
|
||||
shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
|
||||
bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
|
||||
|
||||
float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
|
||||
|
||||
if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
|
||||
{
|
||||
// closer to 0
|
||||
if(std::abs(value) <= std::abs(min_subnorm - value))
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
else
|
||||
return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
|
||||
}
|
||||
|
||||
if(exponent_diff > 0)
|
||||
mantissa >>= exponent_diff;
|
||||
else if(exponent_diff == -1)
|
||||
mantissa <<= -exponent_diff;
|
||||
bool implicit_one = mantissa & (1 << mfmt);
|
||||
out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
|
||||
|
||||
uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
|
||||
bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
|
||||
mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
|
||||
|
||||
if(out_exponent == 0)
|
||||
{
|
||||
if((1UL << mfmt) & mantissa)
|
||||
{
|
||||
out_exponent = 1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if((1UL << (mfmt + 1)) & mantissa)
|
||||
{
|
||||
mantissa >>= 1;
|
||||
out_exponent++;
|
||||
}
|
||||
}
|
||||
|
||||
mantissa >>= (mfmt - numeric_traits<T>::mant);
|
||||
|
||||
if(out_exponent == 0 && mantissa == 0)
|
||||
{
|
||||
return sign << (numeric_traits<T>::exp + numeric_traits<T>::mant);
|
||||
}
|
||||
|
||||
mantissa &= (1UL << numeric_traits<T>::mant) - 1;
|
||||
return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
|
||||
(out_exponent << numeric_traits<T>::mant) | mantissa;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -103,94 +103,92 @@ struct numeric_traits<float>
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) == static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) != static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
|
||||
attr_ bool operator==(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return std::abs(static_cast<float>(x) - static_cast<float>(y)) < \
|
||||
static_cast<float>(numeric<type_>::epsilon()); \
|
||||
} \
|
||||
attr_ bool operator!=(const type_& x, const type_& y) { return not operator==(x, y); } \
|
||||
attr_ bool operator<(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) < static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator<=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) <= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) > static_cast<float>(y); \
|
||||
} \
|
||||
attr_ bool operator>=(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return static_cast<float>(x) >= static_cast<float>(y); \
|
||||
} \
|
||||
attr_ type_ operator+(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x) \
|
||||
{ \
|
||||
constexpr uint32_t bits = sizeof(type_) * 8; \
|
||||
constexpr uint32_t mask = 1 << (bits - 1); \
|
||||
type_ y = x; \
|
||||
y.data ^= static_cast<typename type_::raw_type>(mask); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator-(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator*(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_ operator/(const type_& x, const type_& y) \
|
||||
{ \
|
||||
return type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
} \
|
||||
attr_ type_& operator+=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator-=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator*=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator/=(type_& x, const type_& y) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator++(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_& operator--(type_& x) \
|
||||
{ \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return x; \
|
||||
} \
|
||||
attr_ type_ operator++(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) + 1.f); \
|
||||
return y; \
|
||||
} \
|
||||
attr_ type_ operator--(type_& x, int) \
|
||||
{ \
|
||||
type_ y(x); \
|
||||
x = type_(static_cast<float>(x) - 1.f); \
|
||||
return y; \
|
||||
}
|
||||
|
||||
324
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
324
include/ck_tile/core/numeric/pk_fp4.hpp
Normal file
@@ -0,0 +1,324 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
#if defined(__gfx950__)
|
||||
#define CK_TILE_FP4_CVT_DEVICE 1
|
||||
#else
|
||||
#define CK_TILE_FP4_CVT_DEVICE 0
|
||||
#endif
|
||||
|
||||
#define TEST_convert_with_table 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
|
||||
|
||||
// TODO: Add stochastic method
|
||||
struct pk_float4_e2m1_t
|
||||
{
|
||||
static constexpr int exponent = 2;
|
||||
static constexpr int mantissa = 1;
|
||||
static constexpr int bias = 1;
|
||||
// TODO: Can we merge raw_type and type?
|
||||
using raw_type = uint8_t;
|
||||
using type = raw_type;
|
||||
raw_type data;
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {}
|
||||
template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
|
||||
CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
|
||||
{
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
|
||||
CK_TILE_HOST_DEVICE constexpr operator float() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
|
||||
CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE raw_type unpack(number<I>) const;
|
||||
CK_TILE_HOST_DEVICE static pk_float4_e2m1_t pack(const type x0, const type x1)
|
||||
{
|
||||
return (x1 << 4) | (x0 & 0b00001111);
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table
|
||||
static constexpr float e2m1_to_fp32_table[16] = {
|
||||
0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6};
|
||||
static constexpr fp16_t e2m1_to_fp16_table[16] = {
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
|
||||
bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
|
||||
};
|
||||
#endif
|
||||
};
|
||||
|
||||
using pk_fp4_t = pk_float4_e2m1_t;
|
||||
using pk_fp4_raw_t = typename pk_fp4_t::raw_type;
|
||||
|
||||
template <>
|
||||
struct numeric_traits<pk_fp4_t>
|
||||
{
|
||||
using bitwise_type = pk_fp4_raw_t;
|
||||
|
||||
static constexpr int exp = 2;
|
||||
static constexpr int mant = 1;
|
||||
static constexpr int bias = 1;
|
||||
static constexpr int PackedSize = 2;
|
||||
};
|
||||
|
||||
// limits
|
||||
template <class T>
|
||||
struct numeric;
|
||||
|
||||
template <>
|
||||
struct numeric<pk_fp4_t>
|
||||
{
|
||||
static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
|
||||
static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
|
||||
static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
|
||||
static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
|
||||
static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
|
||||
CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
|
||||
// N/A
|
||||
CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
|
||||
{
|
||||
static_assert(I < 2, "Index is out of range.");
|
||||
if constexpr(I == 1)
|
||||
return (data >> 4);
|
||||
else
|
||||
return data & 0b00001111;
|
||||
}
|
||||
CK_TILE_ARITHMETIC_USING_FLOAT(CK_TILE_HOST_DEVICE, pk_fp4_t)
|
||||
// TODO: consider replace this macro to improve performance
|
||||
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
namespace impl {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
|
||||
{
|
||||
// TODO: check the order
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return T{};
|
||||
}
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
|
||||
{
|
||||
// TODO: check the order
|
||||
union
|
||||
{
|
||||
uint32_t u32;
|
||||
pk_fp4_raw_t pf4[4];
|
||||
} cvt{0};
|
||||
if constexpr(std::is_same_v<T, fp32_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp32x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, fp16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
|
||||
else if constexpr(std::is_same_v<T, bf16x2_t>)
|
||||
cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
|
||||
else
|
||||
static_assert(std::false_type::value, "Unsupported type.");
|
||||
return cvt.pf4[0];
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16_t>(data);
|
||||
#else
|
||||
return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<bf16x2_t>(data);
|
||||
#else
|
||||
return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
#endif
|
||||
}
|
||||
|
||||
// TODO: make float_to_e2m1 generic so that we can convert from directrly.
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return convert_to_type<pk_fp4_t>(x);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x) { return fp32x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x) { return fp16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x) { return bf16x2_t(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return float_to_e2m1(type_convert<float>(x));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
|
||||
float_to_e2m1(type_convert<float>(x[1])));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x);
|
||||
#else
|
||||
return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
|
||||
#endif
|
||||
}
|
||||
|
||||
#if TEST_convert_with_table == 0
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32_t>(data);
|
||||
#else
|
||||
return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp32x2_t>(data);
|
||||
#else
|
||||
return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
|
||||
convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16_t>(data);
|
||||
#else
|
||||
return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_from_f4<fp16x2_t>(data);
|
||||
#else
|
||||
return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
|
||||
type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
|
||||
{
|
||||
return e2m1_to_fp32_table[data & 0xf];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
|
||||
{
|
||||
return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
|
||||
{
|
||||
return e2m1_to_fp16_table[data & 0xf];
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
|
||||
{
|
||||
return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
|
||||
}
|
||||
#endif
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -11,6 +11,7 @@
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/float8.hpp"
|
||||
#include "ck_tile/core/numeric/int8.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_convert.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -64,6 +65,21 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
#include "ck_tile/core/numeric/pk_fp4.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(fp32x2_t, fp32x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16x2_t, fp16x2)
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16x2_t, bf16x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, pk_fp4_t, pk_fp4)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, float, float)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, bf16_t, bf16)
|
||||
CK_TILE_TYPE_CONVERT(pk_fp4_t, pk_fp4, fp16_t, fp16)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user