mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add MX FP4 device conversion tests (#1889)
* Add conversion tests
* Fix ctor
* Fix nan logic
* Fix conversion logic
* Permute packed f4_t values
* Fix conversion to float, repack vector elements
* Fix device tests
* Permute elements in a vector
* Add a repro test
* Add a conversion for a repro test
* Update test vectors
* Update conversion
* Fix the test
* Update test vector generator
* Fix vector sr conversion
* Permute conversion args
* Update conversion
* Test
* Fix packing
* Simplify conversion function
* Pack conversion in a loop
* Pack conversion in a loop
* Pack another conversion in a loop
* Pack one more conversion in a loop
* Pack the last conversion in a loop
* Clean up
* Add printf to fix intrinsic
* Add a sw-based workaround
[ROCm/composable_kernel commit: 441343a23d]
This commit is contained in:
@@ -245,6 +245,9 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
|
||||
// workaround: compiler issue on gfx908
|
||||
#define CK_WORKAROUND_SWDEV_388832 1
|
||||
|
||||
// workaround: compiler issue on gfx950
|
||||
#define CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION 1
|
||||
|
||||
// denorm test fix, necessary for gfx90a
|
||||
#ifndef CK_GFX90A_DENORM_WORKAROUND
|
||||
#define CK_GFX90A_DENORM_WORKAROUND 0
|
||||
|
||||
@@ -36,22 +36,22 @@ struct f4x2_pk_t
|
||||
{
|
||||
using type = uint8_t;
|
||||
type data;
|
||||
f4x2_pk_t() : data{type{}} {}
|
||||
f4x2_pk_t(type init) : data{init} {}
|
||||
__host__ __device__ f4x2_pk_t() : data{type{}} {}
|
||||
__host__ __device__ f4x2_pk_t(type init) : data{init} {}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ inline type unpack(Number<I>) const
|
||||
{
|
||||
static_assert(I < 2, "Index is out of range.");
|
||||
if constexpr(I == 0)
|
||||
return data & 0b00001111;
|
||||
else
|
||||
return (data >> 4);
|
||||
else
|
||||
return data & 0b00001111;
|
||||
}
|
||||
|
||||
__host__ __device__ inline type pack(const type x0, const type x1)
|
||||
{
|
||||
return (x1 << 4) | (x0 & 0b00001111);
|
||||
return (x0 << 4) | (x1 & 0b00001111);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
#pragma once
|
||||
@@ -14,7 +14,7 @@ __host__ __device__ inline bool is_nan<f4_t>(e8m0_bexp_t const scale,
|
||||
f4_t const dataBytes [[maybe_unused]])
|
||||
{
|
||||
// no need to check for data as it does not have NaN representation
|
||||
return scale == NumericLimits<e8m0_bexp_t>::QuietNaN();
|
||||
return scale.is_nan();
|
||||
}
|
||||
|
||||
// no infinity representation in ocp_e2m1_mxfp4 will always return false
|
||||
@@ -27,11 +27,9 @@ __host__ __device__ inline bool is_inf<f4_t>(e8m0_bexp_t const scale [[maybe_unu
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale, f4_t const data)
|
||||
__host__ __device__ inline bool is_zero<f4_t>(e8m0_bexp_t const scale [[maybe_unused]],
|
||||
f4_t const data)
|
||||
{
|
||||
if(is_nan<f4_t>(scale, data))
|
||||
return false;
|
||||
|
||||
// no need to check for scale as it does not have a 0 representation
|
||||
f4_t result = (data & 0b00001111) & NumericUtils<f4_t>::set_sign_mask;
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ template <typename T>
|
||||
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed);
|
||||
|
||||
template <typename T>
|
||||
inline T convert_to_type(float value)
|
||||
__host__ __device__ inline T convert_to_type(float value)
|
||||
{
|
||||
using bitwise_type = typename NumericUtils<T>::bitwise_type;
|
||||
|
||||
@@ -258,7 +258,7 @@ inline T convert_to_type(float value)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T convert_to_type_sr(float value, uint32_t seed)
|
||||
__host__ __device__ inline T convert_to_type_sr(float value, uint32_t seed)
|
||||
{
|
||||
if(std::abs(value) > NumericLimits<T>::Max())
|
||||
{
|
||||
|
||||
@@ -377,12 +377,15 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{};
|
||||
value.f4x2_array[0] = x;
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
|
||||
float2_t tmp =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
return float2_t{tmp[1], tmp[0]};
|
||||
#else
|
||||
float2_t ret{utils::to_float<f4_t>(
|
||||
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
|
||||
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
|
||||
utils::to_float<f4_t>(
|
||||
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
|
||||
scale, x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
@@ -398,109 +401,16 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
|
||||
f4x32_t f4x32_array;
|
||||
f4x2_t fp4x2[16];
|
||||
} value{x};
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} bitwise_value{};
|
||||
float2_t op;
|
||||
float32_t ret;
|
||||
// TODO: pack in a loop
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[0];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[0] = op[0];
|
||||
ret[1] = op[1];
|
||||
float f_scale = type_convert<float>(scale);
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[1];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[2] = op[0];
|
||||
ret[3] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[2];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[4] = op[0];
|
||||
ret[5] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[3];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[6] = op[0];
|
||||
ret[7] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[4];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[8] = op[0];
|
||||
ret[9] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[5];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[10] = op[0];
|
||||
ret[11] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[6];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[12] = op[0];
|
||||
ret[13] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[7];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[14] = op[0];
|
||||
ret[15] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[8];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[16] = op[0];
|
||||
ret[17] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[9];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[18] = op[0];
|
||||
ret[19] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[10];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[20] = op[0];
|
||||
ret[21] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[11];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[22] = op[0];
|
||||
ret[23] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[12];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[24] = op[0];
|
||||
ret[25] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[13];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[26] = op[0];
|
||||
ret[27] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[14];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[28] = op[0];
|
||||
ret[29] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[15];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[30] = op[0];
|
||||
ret[31] = op[1];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], f_scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
ret[2 * idx] = op[1];
|
||||
ret[2 * idx + 1] = op[0];
|
||||
});
|
||||
|
||||
return ret;
|
||||
#else
|
||||
@@ -515,106 +425,18 @@ inline __host__ __device__ float32_t scaled_type_convert<float32_t, f4x32_t>(e8m
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{bit_cast<__uint128_t>(x)};
|
||||
// TODO: pack in a loop
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
float_values.float_array[2 * idx] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<0>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
|
||||
scale,
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<1>{}));
|
||||
});
|
||||
|
||||
return float_values.float32_array;
|
||||
#endif
|
||||
|
||||
@@ -732,7 +732,8 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[0], x[1], scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0);
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
union
|
||||
@@ -757,58 +758,13 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{}, tmp_values{};
|
||||
// TODO: pack in a loop
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[0], x[1], scale, 0);
|
||||
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[2], x[3], scale, 0);
|
||||
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[4], x[5], scale, 0);
|
||||
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[6], x[7], scale, 0);
|
||||
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[8], x[9], scale, 0);
|
||||
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[10], x[11], scale, 0);
|
||||
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[12], x[13], scale, 0);
|
||||
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[14], x[15], scale, 0);
|
||||
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[16], x[17], scale, 0);
|
||||
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[18], x[19], scale, 0);
|
||||
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[20], x[21], scale, 0);
|
||||
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[22], x[23], scale, 0);
|
||||
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[24], x[25], scale, 0);
|
||||
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[26], x[27], scale, 0);
|
||||
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[28], x[29], scale, 0);
|
||||
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_fp4_f32(tmp_values.bitwise, x[30], x[31], scale, 0);
|
||||
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
|
||||
tmp_values.bitwise, x[2 * idx + 1], x[2 * idx], scale, 0);
|
||||
f4_values.f4x2_array[idx] = tmp_values.f4x2_array[0];
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#else
|
||||
@@ -818,106 +774,14 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{};
|
||||
// TODO: pack in a loop
|
||||
auto tmp = utils::sat_convert_to_type<f4_t>(x[0] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[1] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[2] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[3] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[4] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[5] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[6] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[7] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[8] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[9] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[10] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[11] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[12] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[13] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[14] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[15] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
f4_t tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[16] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[17] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[18] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[19] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[20] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[21] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[22] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[23] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[24] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[25] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[26] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[27] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[28] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[29] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[30] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[31] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
ck::static_for<0, 32, 1>{}([&](auto idx) {
|
||||
tmp = utils::sat_convert_to_type<f4_t>(x[static_cast<int>(idx)] / scale);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#endif
|
||||
@@ -967,7 +831,16 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(value.bitwise, x, rng, scale, 0);
|
||||
// apply a temporary workaround for gfx950
|
||||
#if CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
|
||||
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
value.bitwise = (h << 4) | l;
|
||||
#else
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
|
||||
#endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
union
|
||||
@@ -997,64 +870,23 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
__uint128_t bitwise;
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{0}, tmp_values{0};
|
||||
} f4_values{0};
|
||||
union
|
||||
{
|
||||
float2_t floatx2_array[16];
|
||||
float32_t floatx32_array;
|
||||
} float_values{{0}};
|
||||
// TODO: pack in a loop
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);
|
||||
f4_values.f4x2_array[0] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[1], rng, scale, 0);
|
||||
f4_values.f4x2_array[1] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[2], rng, scale, 0);
|
||||
f4_values.f4x2_array[2] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[3], rng, scale, 0);
|
||||
f4_values.f4x2_array[3] = tmp_values.f4x2_array[0];
|
||||
float_values.floatx32_array = x;
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[4], rng, scale, 0);
|
||||
f4_values.f4x2_array[4] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[5], rng, scale, 0);
|
||||
f4_values.f4x2_array[5] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[6], rng, scale, 0);
|
||||
f4_values.f4x2_array[6] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[7], rng, scale, 0);
|
||||
f4_values.f4x2_array[7] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[8], rng, scale, 0);
|
||||
f4_values.f4x2_array[8] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[9], rng, scale, 0);
|
||||
f4_values.f4x2_array[9] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[10], rng, scale, 0);
|
||||
f4_values.f4x2_array[10] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[11], rng, scale, 0);
|
||||
f4_values.f4x2_array[11] = tmp_values.f4x2_array[0];
|
||||
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[12], rng, scale, 0);
|
||||
f4_values.f4x2_array[12] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[13], rng, scale, 0);
|
||||
f4_values.f4x2_array[13] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[14], rng, scale, 0);
|
||||
f4_values.f4x2_array[14] = tmp_values.f4x2_array[0];
|
||||
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
tmp_values.bitwise, float_values.floatx2_array[15], rng, scale, 0);
|
||||
f4_values.f4x2_array[15] = tmp_values.f4x2_array[0];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
f4_values.f4x2_array[idx] = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
f4_values.bitwise,
|
||||
float2_t{float_values.floatx2_array[idx][1], float_values.floatx2_array[idx][0]},
|
||||
rng,
|
||||
scale,
|
||||
0);
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#else
|
||||
@@ -1064,106 +896,14 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{0};
|
||||
// TODO: pack in a loop
|
||||
auto tmp = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[2] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[3] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[4] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[5] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[6] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[7] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[8] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[9] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[10] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[11] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[12] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[13] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[14] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[15] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
f4_t tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[16] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[17] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[18] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[19] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[20] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[21] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[22] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[23] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[24] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[25] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[26] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[27] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[28] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[29] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[30] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[31] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
ck::static_for<0, 32, 1>{}([&](auto idx) {
|
||||
tmp = utils::sat_convert_to_type_sr<f4_t>(x[static_cast<int>(idx)] / scale, rng);
|
||||
f4_values.bitwise <<= 4;
|
||||
f4_values.bitwise |= tmp;
|
||||
});
|
||||
|
||||
return f4_values.f4x32_array;
|
||||
#endif
|
||||
@@ -1232,13 +972,15 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
|
||||
} value{};
|
||||
value.f4x2_array[0] = x;
|
||||
float scale = 1.0f;
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
|
||||
float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
return float2_t{tmp[1], tmp[0]};
|
||||
#else
|
||||
float2_t ret{
|
||||
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
|
||||
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
@@ -1253,110 +995,16 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
|
||||
f4x32_t f4x32_array;
|
||||
f4x2_t fp4x2[16];
|
||||
} value{x};
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} bitwise_value{};
|
||||
float2_t op;
|
||||
float32_t ret;
|
||||
float scale = 1.0f;
|
||||
// TODO: pack in a loop
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[0];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[0] = op[0];
|
||||
ret[1] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[1];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[2] = op[0];
|
||||
ret[3] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[2];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[4] = op[0];
|
||||
ret[5] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[3];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[6] = op[0];
|
||||
ret[7] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[4];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[8] = op[0];
|
||||
ret[9] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[5];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[10] = op[0];
|
||||
ret[11] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[6];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[12] = op[0];
|
||||
ret[13] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[7];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[14] = op[0];
|
||||
ret[15] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[8];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[16] = op[0];
|
||||
ret[17] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[9];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[18] = op[0];
|
||||
ret[19] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[10];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[20] = op[0];
|
||||
ret[21] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[11];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[22] = op[0];
|
||||
ret[23] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[12];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[24] = op[0];
|
||||
ret[25] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[13];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[26] = op[0];
|
||||
ret[27] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[14];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[28] = op[0];
|
||||
ret[29] = op[1];
|
||||
|
||||
bitwise_value.f4x2_array[0] = value.fp4x2[15];
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(
|
||||
bitwise_value.bitwise, type_convert<float>(scale), 0);
|
||||
ret[30] = op[0];
|
||||
ret[31] = op[1];
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
op = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.fp4x2[idx], scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
ret[2 * idx] = op[1];
|
||||
ret[2 * idx + 1] = op[0];
|
||||
});
|
||||
|
||||
return ret;
|
||||
#else
|
||||
@@ -1371,106 +1019,18 @@ inline __host__ __device__ float32_t type_convert<float32_t, f4x32_t>(f4x32_t x)
|
||||
f4x2_t f4x2_array[16];
|
||||
f4x32_t f4x32_array;
|
||||
} f4_values{bit_cast<__uint128_t>(x)};
|
||||
// TODO: pack in a loop
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[0].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[1].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[2].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[3].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[4].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[5].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[6].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[7].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
ck::static_for<0, 32 / 2, 1>{}([&](auto idx) {
|
||||
float_values.float_array[2 * idx] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<0>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[8].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[9].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[10].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[11].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
|
||||
float_values.float_array[0] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[12].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[3] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[13].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[4] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[5] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[14].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[6] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}));
|
||||
float_values.float_array[7] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[15].template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}));
|
||||
float_values.float_array[2 * idx + 1] = utils::to_float<f4_t>(
|
||||
NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
f4_values.f4x2_array[idx].template AsType<f4x2_pk_t>()[Number<0>{}].template unpack<>(
|
||||
Number<1>{}));
|
||||
});
|
||||
|
||||
return float_values.float32_array;
|
||||
#endif
|
||||
|
||||
@@ -75,6 +75,12 @@ if(GPU_TARGETS MATCHES "gfx950")
|
||||
endif()
|
||||
add_dependencies(test_mx_data_types test_mx_bf8)
|
||||
|
||||
add_gtest_executable(test_mx_fp4 test_mx_fp4.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_mx_fp4 PRIVATE utility)
|
||||
endif()
|
||||
add_dependencies(test_mx_data_types test_mx_fp4)
|
||||
|
||||
add_gtest_executable(test_e8m0 test_e8m0.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_e8m0 PRIVATE utility)
|
||||
|
||||
541
test/data_type/test_mx_fp4.cpp
Normal file
541
test/data_type/test_mx_fp4.cpp
Normal file
@@ -0,0 +1,541 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::float16_t;
|
||||
using ck::float2_t;
|
||||
using ck::float32_t;
|
||||
using ck::scaled_type_convert;
|
||||
using ck::type_convert;
|
||||
|
||||
using ck::f4_convert_rne;
|
||||
using ck::f4_convert_sr;
|
||||
using ck::f4_t;
|
||||
using ck::f4x16_t;
|
||||
using ck::f4x2_pk_t;
|
||||
using ck::f4x2_t;
|
||||
using ck::f4x32_t;
|
||||
|
||||
constexpr uint64_t test_size = 256 * 16 + 2 + 4 + 6;
|
||||
|
||||
/**
|
||||
* @brief Tests conversion of FP4 values to float using E8M0 exponent scaling.
|
||||
*
|
||||
* This function performs a series of conversions from FP4 values to float values using
|
||||
* E8M0 exponent scaling. It handles all possible combinations of E8M0 and FP4 values,
|
||||
* as well as specific vector and rounding conversions.
|
||||
*
|
||||
* @param N The maximum number of conversions to perform.
|
||||
* @param p_test Pointer to the output array where the converted float values will be stored.
|
||||
* @param p_completed Pointer to a variable that tracks the number of completed conversions.
|
||||
*
|
||||
* @note If either p_test or p_completed is nullptr, the function will return immediately.
|
||||
* @note The function will stop converting if the number of conversions reaches N.
|
||||
* @note First 256*16 conversions are for all possible combinations of E8M0 and FP4 values that are
|
||||
* stored in memory sequentially with FP4 values varying faster.
|
||||
*
|
||||
* The function performs the following conversions:
|
||||
* - All possible combinations of E8M0 and FP4 values. [256x16]
|
||||
* - Vector conversions f4x2 -> f32x2. [2]
|
||||
* - Vector conversions f32x2 -> f4x2 rne. [2]
|
||||
* - Vector conversions f32x2 -> f4x2 sr. [2]
|
||||
* - Round to nearest even conversions for specific float values. [6]
|
||||
*
|
||||
* The results are stored in the p_test array, and the number of completed conversions
|
||||
* is updated in the p_completed variable.
|
||||
*/
|
||||
__host__ __device__ void
|
||||
test_mx_fp4_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// All possible combinations of E8M0 and FP4
|
||||
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
|
||||
{
|
||||
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
|
||||
{
|
||||
uint8_t fp4_uid = static_cast<uint8_t>(fp4_id);
|
||||
auto v = scaled_type_convert<float>(e8m0_bexp_t(exp_id), f4_t(fp4_uid & 0b00001111));
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
// f4x2 -> f32x2
|
||||
f4x2_t f4x2{f4x2_t::data_v{0b00011100}}; // 0b0001(=0.5) and 0b1100(=-2.0)
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, f4x2);
|
||||
p_test[i++] = f32x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f32x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// f32x2 -> f4x2
|
||||
f32x2 = {1.0f, -4.0f};
|
||||
f4x2 = f4_convert_rne(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
|
||||
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
f4x2 = f4_convert_sr(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
|
||||
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
/// Test round to nearest even
|
||||
|
||||
p_test[i++] = type_convert<float>(f4_convert_rne(24.0f, 4.0f)); // 24/4
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_convert_rne(std::numeric_limits<float>::quiet_NaN(), 4.0f)); // => NaN
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Inf/2 > 6.0 => 6.0 on device
|
||||
p_test[i++] = type_convert<float>(f4_convert_rne(std::numeric_limits<float>::infinity(), 2.0f));
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// 256/0.5 > 6.0 => 6.0 on device
|
||||
p_test[i++] = type_convert<float>(f4_convert_rne(256.0f, 0.5f));
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// -256/0.5 < -6.0 => -6.0 on device
|
||||
p_test[i++] = type_convert<float>(f4_convert_rne(-256.0f, 0.5f));
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// proper scale selection
|
||||
p_test[i++] = type_convert<float>(f4_convert_rne(20.0f, 4.0f)); // 20.0/4.0 = 5.0
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(MXFP4, HostScaledConvert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_mx_fp4_scaled_convert(test_size, out.data(), &completed);
|
||||
|
||||
// V = X * P; X - E8M0 scale, P - FP4
|
||||
|
||||
// If X = NaN, then V = NaN regardless of P
|
||||
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
|
||||
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
|
||||
{
|
||||
auto idx = e8m0_nan_id * 16 + fp4_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx]));
|
||||
}
|
||||
|
||||
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
|
||||
{
|
||||
if(exp_id == e8m0_nan_id)
|
||||
continue;
|
||||
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
|
||||
{
|
||||
uint8_t fp4_uid = static_cast<uint8_t>(fp4_id);
|
||||
auto idx = exp_id * 16 + fp4_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx],
|
||||
type_convert<float>(e8m0_bexp_t(exp_id)) *
|
||||
type_convert<float>(f4_t(fp4_uid & 0b00001111)))
|
||||
<< "exp_id: " << exp_id << " fp4_id: " << fp4_id << std::endl
|
||||
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
|
||||
<< type_convert<float>(f4_t(fp4_uid & 0b00001111));
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256 * 16;
|
||||
|
||||
// f4x2 -> f32x2
|
||||
EXPECT_EQ(out[i++], 1.0f);
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
|
||||
// f32x2 -> f4x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], 0.5f);
|
||||
EXPECT_EQ(out[i++], -2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], 0.5f);
|
||||
EXPECT_EQ(out[i++], -2.0f);
|
||||
|
||||
/// Test round to nearest even
|
||||
EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Lowest()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(5.0f)))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void test_mx_fp4_device_scaled_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_mx_fp4_scaled_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(MXFP4, DeviceScaledConvert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
test_mx_fp4_device_scaled_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
// V = X * P; X - E8M0 scale, P - FP4
|
||||
|
||||
// If X = NaN, then V = NaN regardless of P
|
||||
uint8_t e8m0_nan_id = ck::NumericLimits<e8m0_bexp_t>::QuietNaN().data;
|
||||
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
|
||||
{
|
||||
auto idx = e8m0_nan_id * 16 + fp4_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
|
||||
}
|
||||
|
||||
for(ck::index_t exp_id = 0; exp_id < 256; exp_id++)
|
||||
{
|
||||
if(exp_id == e8m0_nan_id)
|
||||
continue;
|
||||
for(ck::index_t fp4_id = 0; fp4_id < 16; fp4_id++)
|
||||
{
|
||||
uint8_t fp4_uid = static_cast<uint8_t>(fp4_id);
|
||||
auto idx = exp_id * 16 + fp4_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx],
|
||||
type_convert<float>(e8m0_bexp_t(exp_id)) *
|
||||
type_convert<float>(f4_t(fp4_uid & 0b00001111)))
|
||||
<< "exp_id: " << exp_id << " fp4_id: " << fp4_id << std::endl
|
||||
<< type_convert<float>(e8m0_bexp_t(exp_id)) << " * "
|
||||
<< type_convert<float>(f4_t(fp4_uid & 0b00001111));
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256 * 16;
|
||||
|
||||
// f4x2 -> f32x2
|
||||
EXPECT_EQ(out[i++], 1.0f);
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
|
||||
// f32x2 -> f4x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], 0.5f);
|
||||
EXPECT_EQ(out[i++], -2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], 0.5f);
|
||||
EXPECT_EQ(out[i++], -2.0f);
|
||||
|
||||
/// Test round to nearest even
|
||||
EXPECT_EQ(out[i++], 24.0f / 4.0f) << "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Max()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(ck::NumericLimits<f4_t>::Lowest()))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
EXPECT_EQ(out[i++], type_convert<float>(type_convert<f4_t>(5.0f)))
|
||||
<< "out[i-1]: " << out[i - 1];
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ float vec16_generator(ck::index_t i, float scale)
|
||||
{
|
||||
return scale * type_convert<float>(f4_t(i & 0b00001111));
|
||||
}
|
||||
|
||||
__host__ __device__ float vec32_generator(ck::index_t i, float scale)
|
||||
{
|
||||
if(i < 16)
|
||||
{
|
||||
return vec16_generator(
|
||||
i, scale); // all positive values, then all negative values in ascending order
|
||||
}
|
||||
else
|
||||
{
|
||||
return vec16_generator(
|
||||
15 - (i % 16),
|
||||
scale); // all negative values, then all positive values in descending order
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void test_mx_fp4x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
constexpr int N = 32;
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
f4x32_t f4x32{};
|
||||
float32_t float32{};
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
float32[static_cast<int>(ii)] = vec32_generator(ii, type_convert<float>(scale2));
|
||||
});
|
||||
|
||||
f4x32 = f4_convert_rne(float32, type_convert<float>(scale2));
|
||||
|
||||
ck::static_for<0, N / 2, 1>{}([&](auto ii) {
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<0>{})));
|
||||
p_test[i++] = type_convert<float>(
|
||||
f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<1>{})));
|
||||
});
|
||||
}
|
||||
|
||||
TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert)
|
||||
{
|
||||
constexpr int N = 32;
|
||||
std::vector<float> out(N, -1.0f);
|
||||
|
||||
DeviceMem device_out(N * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
test_mx_fp4x32_device_scaled_convert<<<1, 1>>>(
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
auto i = 0;
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
EXPECT_EQ(out[i++],
|
||||
vec32_generator(ii, type_convert<float>(scale2)) / type_convert<float>(scale2))
|
||||
<< "ii: " << ii << std::endl;
|
||||
});
|
||||
|
||||
EXPECT_EQ(N, completed);
|
||||
EXPECT_EQ(N, i);
|
||||
}
|
||||
|
||||
// Device kernel: converts a 32-wide float vector into packed fp4 (f4x32_t)
// using stochastic rounding with a scale of 2, then unpacks every nibble back
// to float into p_test. *p_completed counts elements written so the host can
// verify the kernel ran to completion.
__global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;
    // Without the progress counter nothing can be reported; bail out early.
    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);

    f4x32_t f4x32{};
    float32_t float32{};
    // Fill the input with the shared generator pattern; the host test
    // recomputes the same values to check the round trip.
    ck::static_for<0, N, 1>{}([&](auto ii) {
        float32[static_cast<int>(ii)] = vec32_generator(ii, type_convert<float>(scale2));
    });

    // Scaled float32 -> fp4 conversion, stochastic-rounding variant.
    f4x32 = f4_convert_sr(float32, type_convert<float>(scale2));

    // Unpack each packed byte into its two fp4 nibbles and widen to float.
    // Per f4x2_pk_t::unpack, index 0 selects the upper nibble and index 1 the
    // lower nibble of the packed byte.
    ck::static_for<0, N / 2, 1>{}([&](auto ii) {
        p_test[i++] = type_convert<float>(
            f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<0>{})));
        p_test[i++] = type_convert<float>(
            f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<1>{})));
    });
}
|
||||
|
||||
// Host-side check of the device stochastic-rounding scaled conversion kernel.
// NOTE(review): exact EXPECT_EQ on an SR conversion presumably relies on the
// generated inputs being exactly representable in fp4 (so all rounding modes
// agree) — verify against vec32_generator.
TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR)
{
    constexpr int N = 32;

    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison both buffers so stale device data cannot masquerade as a pass.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_fp4x32_device_scaled_convert_sr<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    std::vector<float> host_result(N, -1.0f);
    uint64_t kernel_progress = 0;
    device_out.FromDevice(host_result.data());
    device_completed.FromDevice(&kernel_progress);

    const auto scale = e8m0_bexp_t(2.0f);
    int idx          = 0;

    // Each element must equal the generated value with the scale divided out.
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(host_result[idx++],
                  vec32_generator(ii, type_convert<float>(scale)) / type_convert<float>(scale))
            << "ii: " << ii << std::endl;
    });

    // The kernel-side counter proves the kernel wrote all N elements.
    EXPECT_EQ(N, kernel_progress);
    EXPECT_EQ(N, idx);
}
|
||||
|
||||
// Device kernel for the reverse direction: packs pre-scaled fp4 values into an
// f4x32_t vector, runs the scaled fp4 -> float32 conversion, and writes the
// resulting floats to p_test. *p_completed counts elements written so the host
// can verify the kernel ran to completion.
__global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_completed)
{
    constexpr int N = 32;
    // Without the progress counter nothing can be reported; bail out early.
    if(p_completed == nullptr)
    {
        return;
    }

    uint64_t& i = *p_completed;
    i = 0;

    if(p_test == nullptr)
    {
        return;
    }

    auto scale2 = e8m0_bexp_t(2.0f);

    f4x32_t f4x32{};
    float32_t float32{};
    // Build the packed input: each byte holds two fp4 values (elements 2*ii and
    // 2*ii+1), generated with the scale divided out so the scaled conversion
    // below restores the original generator values.
    ck::static_for<0, N / 2, 1>{}([&](auto ii) {
        f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}) = f4x2_pk_t{}.pack(
            type_convert<f4_t>(vec32_generator(2 * ii, type_convert<float>(scale2)) /
                               type_convert<float>(scale2)),
            type_convert<f4_t>(vec32_generator(2 * ii + 1, type_convert<float>(scale2)) /
                               type_convert<float>(scale2)));
    });

    // Scaled fp4 -> float32 conversion; the host test expects each output to
    // equal the unscaled generator value, i.e. the scale is multiplied back in.
    float32 = scaled_type_convert<float32_t>(scale2, f4x32);

    ck::static_for<0, N, 1>{}([&](auto ii) { p_test[i++] = float32[static_cast<int>(ii)]; });
}
|
||||
|
||||
// Host-side check of the device fp4 -> float32 scaled conversion kernel: the
// kernel packs pre-scaled fp4 values, converts them back with scale 2, and the
// result must reproduce the generator values exactly.
TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
{
    constexpr int N = 32;

    DeviceMem device_out(N * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Poison both buffers so stale device data cannot masquerade as a pass.
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    test_mx_f32x32_device_scaled_convert<<<1, 1>>>(
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    std::vector<float> host_result(N, -1.0f);
    uint64_t kernel_progress = 0;
    device_out.FromDevice(host_result.data());
    device_completed.FromDevice(&kernel_progress);

    const auto scale = e8m0_bexp_t(2.0f);
    int idx          = 0;

    // Each element must equal the raw generated value (scale multiplied back).
    ck::static_for<0, N, 1>{}([&](auto ii) {
        EXPECT_EQ(host_result[idx++], vec32_generator(ii, type_convert<float>(scale)))
            << "ii: " << ii << std::endl;
    });

    // The kernel-side counter proves the kernel wrote all N elements.
    EXPECT_EQ(N, kernel_progress);
    EXPECT_EQ(N, idx);
}
|
||||
Reference in New Issue
Block a user