Add FP16/BF16<->FP8/BF8 conversions (#2035)

* Move conversion functions and add missing conversions

* Add tests

* Add missing conversions

* Add missing conversions

* Add bf8 tests

* Update clipping for vectors

* Add missing conversions

* Add bf16 fp8 tests

* Add bf16 bf8 tests

* Fix device conversion

* Fix conversions

* Fix vector use

* Minor fix

* Add a workaround flag

* Add a workaround flag for bf16 conversion

* Add another workaround

* Add a workaround for fp16 to bf8 conversion

* Update type alias

* Add docstrings and missing wrappers

* Fix if defined macros

* Fix more if defined macros

* Add comments

* Remove __host__ specifier

* Add a gfx950 guard

* Update function naming

[ROCm/composable_kernel commit: 265af71a71]
This commit is contained in:
Rostyslav Geyyer
2025-04-03 12:42:03 -05:00
committed by GitHub
parent b7359bcfac
commit 7fbc128e83
7 changed files with 2628 additions and 110 deletions

View File

@@ -1,13 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_ocp_t;
using ck::bf8x2_ocp_t;
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::float2_t;
using ck::half2_t;
using ck::half_t;
using ck::type_convert;
@@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic)
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
}
// Number of values each conversion worker in this file produces: 256 scalar results (one per
// 8-bit encoding) followed by 6 results from the 2-wide vector conversion checks.
constexpr uint64_t test_size = 256 + 6;
// Converts every 8-bit bf8 (OCP) encoding to fp32 and exercises the bf8x2 <-> fp32x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<float>(bf8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the bf8x2 -> fp32x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to fp32), and the two lanes produced by
// f8_convert_sr (converted back to fp32).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](float value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(!store(type_convert<float>(bf8_ocp_t{bf8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // bf8x2 -> fp32x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16

    float2_t f32x2 = type_convert<float2_t>(bf8x2);
    if(!store(f32x2[0]) || !store(f32x2[1]))
    {
        return;
    }

    // fp32x2 -> bf8x2
    f32x2 = {-4.0f, 2.0f};

    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!store(type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!store(type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the bf8/fp32 conversion worker on the host and validates all test_size results.
TEST(BF8OCP, HostFP32BF8Convert)
{
    uint64_t completed = 0;
    std::vector<float> out(test_size, -1.0f);

    test_fp32_bf8_type_convert(test_size, out.data(), &completed);

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t nan_code : bf8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(out[nan_code]));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t bf8_id = 0; bf8_id < 256; ++bf8_id)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.count(bf8_uid) > 0)
        {
            continue;
        }

        const float expected = type_convert<float>(bf8_ocp_t{bf8_uid});
        ASSERT_FLOAT_EQ(out[bf8_uid], expected) << " bf8_id: " << bf8_id << std::endl << expected;
    }

    // Vector conversion results follow the 256 scalar entries.
    int slot = 256;

    // bf8x2 -> fp32x2
    EXPECT_EQ(out[slot++], -powf(2.0f, -14.0f));
    EXPECT_EQ(out[slot++], powf(2.0f, -16.0f));

    // fp32x2 -> bf8x2, RNE
    EXPECT_EQ(out[slot++], -4.0f);
    EXPECT_EQ(out[slot++], 2.0f);

    // fp32x2 -> bf8x2, SR
    EXPECT_EQ(out[slot++], -4.0f);
    EXPECT_EQ(out[slot++], 2.0f);

    // The worker must have produced, and this test consumed, exactly test_size values.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, slot);
}
// Device entry point for the bf8/fp32 conversion test: forwards its arguments unchanged to
// test_fp32_bf8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_fp32_bf8_type_convert(N, p_test, p_completed);
}
// Runs the bf8/fp32 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(BF8OCP, DeviceFP32BF8Convert)
{
    std::vector<float> out(test_size, -1.0f);

    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_fp32_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t bf8_nan_id : bf8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(out[bf8_nan_id]))
            << "idx: " << static_cast<int>(bf8_nan_id) << " out[idx]: " << out[bf8_nan_id];
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.find(bf8_uid) != bf8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[bf8_uid], type_convert<float>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(bf8_ocp_t{bf8_uid});
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // bf8x2 -> fp32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -16.0f));

    // fp32x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Converts every 8-bit bf8 (OCP) encoding to fp16 and exercises the bf8x2 <-> fp16x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<half_t>(bf8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the bf8x2 -> fp16x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to fp16), and the two lanes produced by
// f8_convert_sr (converted back to fp16).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](half_t value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(!store(type_convert<half_t>(bf8_ocp_t{bf8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // bf8x2 -> fp16x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16

    half2_t f16x2 = type_convert<half2_t>(bf8x2);
    if(!store(f16x2[0]) || !store(f16x2[1]))
    {
        return;
    }

    // fp16x2 -> bf8x2
    f16x2 = {-4.0f, 2.0f};

    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!store(type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!store(type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the bf8/fp16 conversion worker on the host and validates all test_size results.
TEST(BF8OCP, HostFP16BF8Convert)
{
    uint64_t completed = 0;
    std::vector<half_t> out(test_size, -1.0f);

    test_fp16_bf8_type_convert(test_size, out.data(), &completed);

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t nan_code : bf8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(type_convert<float>(out[nan_code])));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t bf8_id = 0; bf8_id < 256; ++bf8_id)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.count(bf8_uid) > 0)
        {
            continue;
        }

        const half_t expected = type_convert<half_t>(bf8_ocp_t{bf8_uid});
        ASSERT_FLOAT_EQ(out[bf8_uid], expected) << " bf8_id: " << bf8_id << std::endl
                                                << type_convert<float>(expected);
    }

    // Vector conversion results follow the 256 scalar entries.
    int slot = 256;

    // bf8x2 -> fp16x2
    EXPECT_EQ(out[slot++], type_convert<half_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[slot++], type_convert<half_t>(powf(2.0f, -16.0f)));

    // fp16x2 -> bf8x2, RNE
    EXPECT_EQ(out[slot++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[slot++], type_convert<half_t>(2.0f));

    // fp16x2 -> bf8x2, SR
    EXPECT_EQ(out[slot++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[slot++], type_convert<half_t>(2.0f));

    // The worker must have produced, and this test consumed, exactly test_size values.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, slot);
}
// Device entry point for the bf8/fp16 conversion test: forwards its arguments unchanged to
// test_fp16_bf8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
test_fp16_bf8_type_convert(N, p_test, p_completed);
}
// Runs the bf8/fp16 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(BF8OCP, DeviceFP16BF8Convert)
{
    std::vector<half_t> out(test_size, -1.0f);

    DeviceMem device_out(test_size * sizeof(half_t));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_fp16_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<half_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t bf8_nan_id : bf8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(type_convert<float>(out[bf8_nan_id])))
            << "idx: " << static_cast<int>(bf8_nan_id)
            << " out[idx]: " << type_convert<float>(out[bf8_nan_id]);
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.find(bf8_uid) != bf8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[bf8_uid], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // bf8x2 -> fp16x2
    EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));

    // fp16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Converts every 8-bit bf8 (OCP) encoding to bf16 and exercises the bf8x2 <-> bf16x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<bhalf_t>(bf8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the bf8x2 -> bf16x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to bf16), and the two lanes produced by
// f8_convert_sr (converted back to bf16).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](bhalf_t value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(!store(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // bf8x2 -> bf16x2
    bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16

    bhalf2_t bf16x2 = type_convert<bhalf2_t>(bf8x2);
    if(!store(bf16x2[0]) || !store(bf16x2[1]))
    {
        return;
    }

    // bf16x2 -> bf8x2
    bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};

    bf8x2 = f8_convert_rne<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!store(type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    bf8x2 = f8_convert_sr<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!store(type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the bf8/bf16 conversion worker on the host and validates all test_size results.
TEST(BF8OCP, HostBF16BF8Convert)
{
    // Build the fill value with type_convert, like every other bhalf_t value in this file,
    // instead of relying on an implicit float -> bhalf_t conversion of a negative number.
    // NOTE(review): if bhalf_t is an unsigned 16-bit storage type, that implicit conversion is
    // undefined behavior - confirm against the bhalf_t definition.
    std::vector<bhalf_t> out(test_size, type_convert<bhalf_t>(-1.0f));
    uint64_t completed = 0;

    test_bf16_bf8_type_convert(test_size, out.data(), &completed);

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t bf8_nan_id : bf8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(type_convert<float>(out[bf8_nan_id])));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.find(bf8_uid) != bf8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[bf8_uid], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // bf8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));

    // bf16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Device entry point for the bf8/bf16 conversion test: forwards its arguments unchanged to
// test_bf16_bf8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void
device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
test_bf16_bf8_type_convert(N, p_test, p_completed);
}
// Runs the bf8/bf16 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(BF8OCP, DeviceBF16BF8Convert)
{
    // Build the fill value with type_convert, like every other bhalf_t value in this file,
    // instead of relying on an implicit float -> bhalf_t conversion of a negative number.
    // NOTE(review): if bhalf_t is an unsigned 16-bit storage type, that implicit conversion is
    // undefined behavior - confirm against the bhalf_t definition.
    std::vector<bhalf_t> out(test_size, type_convert<bhalf_t>(-1.0f));

    DeviceMem device_out(test_size * sizeof(bhalf_t));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_bf16_bf8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // bf8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> bf8_nan_ids{
        0b01111101, 0b01111110, 0b01111111, 0b11111101, 0b11111110, 0b11111111};

    for(const uint8_t bf8_nan_id : bf8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(type_convert<float>(out[bf8_nan_id])))
            << "idx: " << static_cast<int>(bf8_nan_id)
            << " out[idx]: " << type_convert<float>(out[bf8_nan_id]);
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
    {
        const uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
        if(bf8_nan_ids.find(bf8_uid) != bf8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[bf8_uid], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
            << " bf8_id: " << bf8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // bf8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));

    // bf16x2 -> bf8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}

View File

@@ -1,13 +1,19 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bhalf2_t;
using ck::bhalf_t;
using ck::f8_convert_rne;
using ck::f8_convert_sr;
using ck::f8_ocp_t;
using ck::f8x2_ocp_t;
using ck::float2_t;
using ck::half2_t;
using ck::half_t;
using ck::type_convert;
@@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic)
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
}
// Number of values each conversion worker in this file produces: 256 scalar results (one per
// 8-bit encoding) followed by 6 results from the 2-wide vector conversion checks.
constexpr uint64_t test_size = 256 + 6;
// Converts every 8-bit fp8 (OCP) encoding to fp32 and exercises the fp8x2 <-> fp32x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<float>(f8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the fp8x2 -> fp32x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to fp32), and the two lanes produced by
// f8_convert_sr (converted back to fp32).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](float value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(!store(type_convert<float>(f8_ocp_t{fp8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // fp8x2 -> fp32x2
    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9

    float2_t f32x2 = type_convert<float2_t>(fp8x2);
    if(!store(f32x2[0]) || !store(f32x2[1]))
    {
        return;
    }

    // fp32x2 -> fp8x2
    f32x2 = {-4.0f, 2.0f};

    fp8x2 = f8_convert_rne<f8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!store(type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    fp8x2 = f8_convert_sr<f8x2_ocp_t>(f32x2); // expect {-4, 2}
    if(!store(type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the fp8/fp32 conversion worker on the host and validates all test_size results.
TEST(FP8OCP, HostFP32FP8Convert)
{
    uint64_t completed = 0;
    std::vector<float> out(test_size, -1.0f);

    test_fp32_fp8_type_convert(test_size, out.data(), &completed);

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t nan_code : fp8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(out[nan_code]));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t fp8_id = 0; fp8_id < 256; ++fp8_id)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.count(fp8_uid) > 0)
        {
            continue;
        }

        const float expected = type_convert<float>(f8_ocp_t{fp8_uid});
        ASSERT_FLOAT_EQ(out[fp8_uid], expected) << " fp8_id: " << fp8_id << std::endl << expected;
    }

    // Vector conversion results follow the 256 scalar entries.
    int slot = 256;

    // fp8x2 -> fp32x2
    EXPECT_EQ(out[slot++], -powf(2.0f, -6.0f));
    EXPECT_EQ(out[slot++], powf(2.0f, -9.0f));

    // fp32x2 -> fp8x2, RNE
    EXPECT_EQ(out[slot++], -4.0f);
    EXPECT_EQ(out[slot++], 2.0f);

    // fp32x2 -> fp8x2, SR
    EXPECT_EQ(out[slot++], -4.0f);
    EXPECT_EQ(out[slot++], 2.0f);

    // The worker must have produced, and this test consumed, exactly test_size values.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, slot);
}
// Device entry point for the fp8/fp32 conversion test: forwards its arguments unchanged to
// test_fp32_fp8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
{
test_fp32_fp8_type_convert(N, p_test, p_completed);
}
// Runs the fp8/fp32 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(FP8OCP, DeviceFP32FP8Convert)
{
    std::vector<float> out(test_size, -1.0f);

    DeviceMem device_out(test_size * sizeof(float));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_fp32_fp8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<float*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t fp8_nan_id : fp8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(out[fp8_nan_id]))
            << "idx: " << static_cast<int>(fp8_nan_id) << " out[idx]: " << out[fp8_nan_id];
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.find(fp8_uid) != fp8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[fp8_uid], type_convert<float>(f8_ocp_t{fp8_uid}))
            << " fp8_id: " << fp8_id << std::endl
            << type_convert<float>(f8_ocp_t{fp8_uid});
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // fp8x2 -> fp32x2
    EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
    EXPECT_EQ(out[i++], powf(2.0f, -9.0f));

    // fp32x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);
    // SR
    EXPECT_EQ(out[i++], -4.0f);
    EXPECT_EQ(out[i++], 2.0f);

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Converts every 8-bit fp8 (OCP) encoding to fp16 and exercises the fp8x2 <-> fp16x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<half_t>(f8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the fp8x2 -> fp16x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to fp16), and the two lanes produced by
// f8_convert_sr (converted back to fp16).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](half_t value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(!store(type_convert<half_t>(f8_ocp_t{fp8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // fp8x2 -> fp16x2
    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9

    half2_t f16x2 = type_convert<half2_t>(fp8x2);
    if(!store(f16x2[0]) || !store(f16x2[1]))
    {
        return;
    }

    // fp16x2 -> fp8x2
    f16x2 = {-4.0f, 2.0f};

    fp8x2 = f8_convert_rne<f8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!store(type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    fp8x2 = f8_convert_sr<f8x2_ocp_t>(f16x2); // expect {-4, 2}
    if(!store(type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the fp8/fp16 conversion worker on the host and validates all test_size results.
TEST(FP8OCP, HostFP16FP8Convert)
{
    uint64_t completed = 0;
    std::vector<half_t> out(test_size, -1.0f);

    test_fp16_fp8_type_convert(test_size, out.data(), &completed);

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t nan_code : fp8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(type_convert<float>(out[nan_code])));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t fp8_id = 0; fp8_id < 256; ++fp8_id)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.count(fp8_uid) > 0)
        {
            continue;
        }

        const half_t expected = type_convert<half_t>(f8_ocp_t{fp8_uid});
        ASSERT_FLOAT_EQ(out[fp8_uid], expected) << " fp8_id: " << fp8_id << std::endl
                                                << type_convert<float>(expected);
    }

    // Vector conversion results follow the 256 scalar entries.
    int slot = 256;

    // fp8x2 -> fp16x2
    EXPECT_EQ(out[slot++], type_convert<half_t>(-powf(2.0f, -6.0f)));
    EXPECT_EQ(out[slot++], type_convert<half_t>(powf(2.0f, -9.0f)));

    // fp16x2 -> fp8x2, RNE
    EXPECT_EQ(out[slot++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[slot++], type_convert<half_t>(2.0f));

    // fp16x2 -> fp8x2, SR
    EXPECT_EQ(out[slot++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[slot++], type_convert<half_t>(2.0f));

    // The worker must have produced, and this test consumed, exactly test_size values.
    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, slot);
}
// Device entry point for the fp8/fp16 conversion test: forwards its arguments unchanged to
// test_fp16_fp8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
{
test_fp16_fp8_type_convert(N, p_test, p_completed);
}
// Runs the fp8/fp16 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(FP8OCP, DeviceFP16FP8Convert)
{
    std::vector<half_t> out(test_size, -1.0f);

    DeviceMem device_out(test_size * sizeof(half_t));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_fp16_fp8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<half_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t fp8_nan_id : fp8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(type_convert<float>(out[fp8_nan_id])))
            << "idx: " << static_cast<int>(fp8_nan_id)
            << " out[idx]: " << type_convert<float>(out[fp8_nan_id]);
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.find(fp8_uid) != fp8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[fp8_uid], type_convert<half_t>(f8_ocp_t{fp8_uid}))
            << " fp8_id: " << fp8_id << std::endl
            << type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // fp8x2 -> fp16x2
    EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
    EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));

    // fp16x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Converts every 8-bit fp8 (OCP) encoding to bf16 and exercises the fp8x2 <-> bf16x2 vector
// conversions, storing each result in p_test.
//
// Output layout: entries [0, 256) hold type_convert<bhalf_t>(f8_ocp_t{id}) for id = 0..255.
// The next six entries hold, in order: the two lanes of the fp8x2 -> bf16x2 conversion, the two
// lanes produced by f8_convert_rne (converted back to bf16), and the two lanes produced by
// f8_convert_sr (converted back to bf16).
//
// N           - capacity of p_test in elements; the function stops once N values are stored.
// p_test      - output buffer; if null, only the counter is reset.
// p_completed - receives the number of values stored; if null, nothing is done.
__host__ __device__ void
test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
    if(p_completed == nullptr)
    {
        return;
    }

    // The counter is updated in place so the caller can see how far a partial run got.
    uint64_t& i = *p_completed;
    i           = 0;

    // A zero-capacity buffer must be rejected up front: the helper below stores first and
    // tests the remaining capacity afterwards.
    if(p_test == nullptr || N == 0)
    {
        return;
    }

    // Stores one result and reports whether there is still room for another.
    const auto store = [&](bhalf_t value) {
        p_test[i++] = value;
        return i < N;
    };

    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(!store(type_convert<bhalf_t>(f8_ocp_t{fp8_uid})))
        {
            return;
        }
    }

    /// Test vector conversion
    // fp8x2 -> bf16x2
    f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9

    bhalf2_t bf16x2 = type_convert<bhalf2_t>(fp8x2);
    if(!store(bf16x2[0]) || !store(bf16x2[1]))
    {
        return;
    }

    // bf16x2 -> fp8x2
    bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};

    fp8x2 = f8_convert_rne<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!store(type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{}))) || //-4f
       !store(type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))))   // 2f
    {
        return;
    }

    fp8x2 = f8_convert_sr<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
    if(!store(type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})))) //-4f
    {
        return;
    }
    store(type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{}))); // 2f
}
// Runs the fp8/bf16 conversion worker on the host and validates all test_size results.
TEST(FP8OCP, HostBF16FP8Convert)
{
    // Build the fill value with type_convert, like every other bhalf_t value in this file,
    // instead of relying on an implicit float -> bhalf_t conversion of a negative number.
    // NOTE(review): if bhalf_t is an unsigned 16-bit storage type, that implicit conversion is
    // undefined behavior - confirm against the bhalf_t definition.
    std::vector<bhalf_t> out(test_size, type_convert<bhalf_t>(-1.0f));
    uint64_t completed = 0;

    test_bf16_fp8_type_convert(test_size, out.data(), &completed);

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t fp8_nan_id : fp8_nan_ids)
    {
        ASSERT_TRUE(std::isnan(type_convert<float>(out[fp8_nan_id])));
    }

    // Every other encoding must match the scalar conversion evaluated here.
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.find(fp8_uid) != fp8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[fp8_uid], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
            << " fp8_id: " << fp8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // fp8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));

    // bf16x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}
// Device entry point for the fp8/bf16 conversion test: forwards its arguments unchanged to
// test_bf16_fp8_type_convert. The body uses no thread or block index, so every thread that runs
// this kernel executes the whole worker and writes through the same p_test / p_completed
// pointers; the test in this file launches it with a single thread (<<<1, 1>>>).
__global__ void
device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
{
test_bf16_fp8_type_convert(N, p_test, p_completed);
}
// Runs the fp8/bf16 conversion worker in a single-thread kernel and validates, on the host,
// all test_size results copied back from the device.
TEST(FP8OCP, DeviceBF16FP8Convert)
{
    // Build the fill value with type_convert, like every other bhalf_t value in this file,
    // instead of relying on an implicit float -> bhalf_t conversion of a negative number.
    // NOTE(review): if bhalf_t is an unsigned 16-bit storage type, that implicit conversion is
    // undefined behavior - confirm against the bhalf_t definition.
    std::vector<bhalf_t> out(test_size, type_convert<bhalf_t>(-1.0f));

    DeviceMem device_out(test_size * sizeof(bhalf_t));
    DeviceMem device_completed(sizeof(uint64_t));

    // Pre-fill both device buffers with a recognizable pattern (presumably so that anything the
    // kernel never wrote stands out in the checks below).
    device_out.SetValue(-21.0f);
    device_completed.SetValue(-21.0f);

    device_test_bf16_fp8_type_convert<<<1, 1>>>(
        test_size,
        static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
        static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));

    uint64_t completed = 0;
    device_completed.FromDevice(&completed);
    device_out.FromDevice(out.data());

    // fp8 encodings that are expected to convert to NaN.
    const std::set<uint8_t> fp8_nan_ids{0b01111111 /* +NaN */, 0b11111111 /* -NaN */};

    for(const uint8_t fp8_nan_id : fp8_nan_ids)
    {
        // Widen the uint8_t index before streaming it; a uint8_t would print as a raw character.
        ASSERT_TRUE(std::isnan(type_convert<float>(out[fp8_nan_id])))
            << "idx: " << static_cast<int>(fp8_nan_id)
            << " out[idx]: " << type_convert<float>(out[fp8_nan_id]);
    }

    // Every other encoding must match the same conversion evaluated on the host.
    for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
    {
        const uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
        if(fp8_nan_ids.find(fp8_uid) != fp8_nan_ids.end())
        {
            continue;
        }

        ASSERT_FLOAT_EQ(out[fp8_uid], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
            << " fp8_id: " << fp8_id << std::endl
            << type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
    }

    /// Test vector conversions
    // Same unsigned type as test_size, so the final count comparison is not mixed-sign.
    uint64_t i = 256;

    // fp8x2 -> bf16x2
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));

    // bf16x2 -> fp8x2
    // RNE
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
    // SR
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
    EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));

    EXPECT_EQ(test_size, completed);
    EXPECT_EQ(test_size, i);
}