mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
Add FP16/BF16<->FP8/BF8 conversions (#2035)
* Move conversion functions and add missing conversions
* Add tests
* Add missing conversions
* Add missing conversions
* Add bf8 tests
* Update clipping for vectors
* Add missing conversions
* Add bf16 fp8 tests
* Add bf16 bf8 tests
* Fix device conversion
* Fix conversions
* Fix vector use
* Minor fix
* Add a workaround flag
* Add a workaround flag for bf16 conversion
* Add another workaround
* Add a workaround for fp16 to bf8 conversion
* Update type alias
* Add docstrings and missing wrappers
* Fix if defined macros
* Fix more if defined macros
* Add comments
* Remove __host__ specifier
* Add a gfx950 guard
* Update function naming
[ROCm/composable_kernel commit: 265af71a71]
This commit is contained in:
@@ -1,13 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bf8_ocp_t;
|
||||
using ck::bf8x2_ocp_t;
|
||||
using ck::bhalf2_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::float2_t;
|
||||
using ck::half2_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
@@ -266,3 +272,590 @@ TEST(BF8OCP, ConvertFP16Stochastic)
|
||||
const auto bf8_nan = f8_convert_sr<bf8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_bf8_is_nan(bf8_nan.data));
|
||||
}
|
||||
|
||||
constexpr uint64_t test_size = 256 + 6;
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> fp32x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
float2_t f32x2 = type_convert<float2_t>(bf8x2);
|
||||
p_test[i++] = f32x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f32x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
f32x2 = {-4.0f, 2.0f};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostFP32BF8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp32_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx]));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp32_bf8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp32_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceFP32BF8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp32_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(bf8_ocp_t{bf8_uid});
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -14.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -16.0f));
|
||||
|
||||
// fp32x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<half_t>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> fp16x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
half2_t f16x2 = type_convert<half2_t>(bf8x2);
|
||||
p_test[i++] = f16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
f16x2 = {-4.0f, 2.0f};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostFP16BF8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp16_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp16_bf8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp16_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceFP16BF8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(half_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp16_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<half_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// fp16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto v = type_convert<bhalf_t>(bf8_ocp_t{bf8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// bf8x2 -> bf16x2
|
||||
bf8x2_ocp_t bf8x2{bf8x2_ocp_t::data_v{0b10000100, 0b00000001}}; //-2^-14, 2^-16
|
||||
|
||||
bhalf2_t bf16x2 = type_convert<bhalf2_t>(bf8x2);
|
||||
p_test[i++] = bf16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = bf16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
|
||||
bf8x2 = f8_convert_rne<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bf8x2 = f8_convert_sr<bf8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(bf8x2.AsType<bf8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(BF8OCP, HostBF16BF8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_bf16_bf8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void
|
||||
device_test_bf16_bf8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_bf16_bf8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(BF8OCP, DeviceBF16BF8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(bhalf_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_bf16_bf8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> bf8_nan_ids;
|
||||
bf8_nan_ids.insert(0b11111111);
|
||||
bf8_nan_ids.insert(0b01111111);
|
||||
bf8_nan_ids.insert(0b11111101);
|
||||
bf8_nan_ids.insert(0b01111101);
|
||||
bf8_nan_ids.insert(0b11111110);
|
||||
bf8_nan_ids.insert(0b01111110);
|
||||
for(auto bf8_nan_id : bf8_nan_ids)
|
||||
{
|
||||
auto idx = bf8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t bf8_id = 0; bf8_id < 256; bf8_id++)
|
||||
{
|
||||
if(bf8_nan_ids.find(bf8_id) != bf8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t bf8_uid = static_cast<uint8_t>(bf8_id);
|
||||
auto idx = bf8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}))
|
||||
<< " bf8_id: " << bf8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(bf8_ocp_t{bf8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// bf8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -14.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -16.0f)));
|
||||
|
||||
// bf16x2 -> bf8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/type_convert.hpp"
|
||||
|
||||
using ck::bhalf2_t;
|
||||
using ck::bhalf_t;
|
||||
using ck::f8_convert_rne;
|
||||
using ck::f8_convert_sr;
|
||||
using ck::f8_ocp_t;
|
||||
using ck::f8x2_ocp_t;
|
||||
using ck::float2_t;
|
||||
using ck::half2_t;
|
||||
using ck::half_t;
|
||||
using ck::type_convert;
|
||||
|
||||
@@ -248,3 +254,566 @@ TEST(FP8OCP, ConvertFP16Stochastic)
|
||||
auto f8_nan = f8_convert_sr<f8_ocp_t>(ck::NumericLimits<half_t>::QuietNaN());
|
||||
ASSERT_TRUE(ck::fp8_impl::ocp_f8_is_nan(f8_nan.data));
|
||||
}
|
||||
|
||||
constexpr uint64_t test_size = 256 + 6;
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> fp32x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
float2_t f32x2 = type_convert<float2_t>(fp8x2);
|
||||
p_test[i++] = f32x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f32x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
f32x2 = {-4.0f, 2.0f};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(f32x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<float>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostFP32FP8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp32_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx]));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -9.0f));
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp32_fp8_type_convert(uint64_t N, float* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp32_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceFP32FP8Convert)
|
||||
{
|
||||
std::vector<float> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(float));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp32_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<float*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(out[idx])) << "idx: " << idx << " out[idx]: " << out[idx];
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<float>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(f8_ocp_t{fp8_uid});
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp32x2
|
||||
EXPECT_EQ(out[i++], -powf(2.0f, -6.0f));
|
||||
EXPECT_EQ(out[i++], powf(2.0f, -9.0f));
|
||||
|
||||
// fp32x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], -4.0f);
|
||||
EXPECT_EQ(out[i++], 2.0f);
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<half_t>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> fp16x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
half2_t f16x2 = type_convert<half2_t>(fp8x2);
|
||||
p_test[i++] = f16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = f16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
f16x2 = {-4.0f, 2.0f};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(f16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<half_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostFP16FP8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_fp16_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void device_test_fp16_fp8_type_convert(uint64_t N, half_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_fp16_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceFP16FP8Convert)
|
||||
{
|
||||
std::vector<half_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(half_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_fp16_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<half_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<half_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<half_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> fp16x2
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// fp16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<half_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ void
|
||||
test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
if(p_completed == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
uint64_t& i = *p_completed;
|
||||
i = 0;
|
||||
|
||||
if(p_test == nullptr)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto v = type_convert<bhalf_t>(f8_ocp_t{fp8_uid});
|
||||
p_test[i] = v;
|
||||
i++;
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
/// Test vector conversion
|
||||
// fp8x2 -> bf16x2
|
||||
f8x2_ocp_t fp8x2{f8x2_ocp_t::data_v{0b10001000, 0b00000001}}; //-2^-6, 2^-9
|
||||
|
||||
bhalf2_t bf16x2 = type_convert<bhalf2_t>(fp8x2);
|
||||
p_test[i++] = bf16x2[0];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = bf16x2[1];
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
bf16x2 = {type_convert<bhalf_t>(-4.0f), type_convert<bhalf_t>(2.0f)};
|
||||
fp8x2 = f8_convert_rne<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
fp8x2 = f8_convert_sr<f8x2_ocp_t>(bf16x2); // expect {-4, 2}
|
||||
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<0>{})); //-4f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
p_test[i++] = type_convert<bhalf_t>(fp8x2.AsType<f8_ocp_t>()(ck::Number<1>{})); // 2f
|
||||
if(i >= N)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
TEST(FP8OCP, HostBF16FP8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
uint64_t completed = 0;
|
||||
|
||||
test_bf16_fp8_type_convert(test_size, out.data(), &completed);
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])));
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
// /// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__global__ void
|
||||
device_test_bf16_fp8_type_convert(uint64_t N, bhalf_t* p_test, uint64_t* p_completed)
|
||||
{
|
||||
test_bf16_fp8_type_convert(N, p_test, p_completed);
|
||||
}
|
||||
|
||||
TEST(FP8OCP, DeviceBF16FP8Convert)
|
||||
{
|
||||
std::vector<bhalf_t> out(test_size, -1.0f);
|
||||
|
||||
DeviceMem device_out(test_size * sizeof(bhalf_t));
|
||||
DeviceMem device_completed(sizeof(uint64_t));
|
||||
|
||||
device_out.SetValue(-21.0f);
|
||||
device_completed.SetValue(-21.0f);
|
||||
|
||||
device_test_bf16_fp8_type_convert<<<1, 1>>>(
|
||||
test_size,
|
||||
static_cast<bhalf_t*>(device_out.GetDeviceBuffer()),
|
||||
static_cast<uint64_t*>(device_completed.GetDeviceBuffer()));
|
||||
|
||||
uint64_t completed = 0;
|
||||
device_completed.FromDevice(&completed);
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
std::set<uint8_t> fp8_nan_ids;
|
||||
fp8_nan_ids.insert(0b11111111); //-NaN
|
||||
fp8_nan_ids.insert(0b01111111); // +NaN
|
||||
for(auto fp8_nan_id : fp8_nan_ids)
|
||||
{
|
||||
auto idx = fp8_nan_id;
|
||||
ASSERT_TRUE(std::isnan(type_convert<float>(out[idx])))
|
||||
<< "idx: " << idx << " out[idx]: " << type_convert<float>(out[idx]);
|
||||
}
|
||||
|
||||
for(ck::index_t fp8_id = 0; fp8_id < 256; fp8_id++)
|
||||
{
|
||||
if(fp8_nan_ids.find(fp8_id) != fp8_nan_ids.end())
|
||||
continue;
|
||||
|
||||
uint8_t fp8_uid = static_cast<uint8_t>(fp8_id);
|
||||
auto idx = fp8_uid;
|
||||
ASSERT_FLOAT_EQ(out[idx], type_convert<bhalf_t>(f8_ocp_t{fp8_uid}))
|
||||
<< " fp8_id: " << fp8_id << std::endl
|
||||
<< type_convert<float>(type_convert<bhalf_t>(f8_ocp_t{fp8_uid}));
|
||||
}
|
||||
|
||||
/// Test vector conversions
|
||||
|
||||
auto i = 256;
|
||||
|
||||
// fp8x2 -> bf16x2
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-powf(2.0f, -6.0f)));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(powf(2.0f, -9.0f)));
|
||||
|
||||
// bf16x2 -> fp8x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
// SR
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(-4.0f));
|
||||
EXPECT_EQ(out[i++], type_convert<bhalf_t>(2.0f));
|
||||
|
||||
EXPECT_EQ(test_size, completed);
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user