mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Clean up
This commit is contained in:
@@ -379,8 +379,7 @@ inline __host__ __device__ float2_t scaled_type_convert<float2_t, f4x2_t>(e8m0_b
|
||||
value.f4x2_array[0] = x;
|
||||
float2_t tmp =
|
||||
__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, type_convert<float>(scale), 0);
|
||||
// intrinsic packs vector as {element1, element0}, so we should repack it as {element0,
|
||||
// element1}
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
return float2_t{tmp[1], tmp[0]};
|
||||
#else
|
||||
float2_t ret{utils::to_float<f4_t>(
|
||||
|
||||
@@ -734,6 +734,7 @@ inline __host__ __device__ f4x2_t f4_convert_rne(float2_t x, float scale = 1.0f)
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(value.bitwise, x[1], x[0], scale, 0);
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
@@ -824,34 +825,7 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
printf("%f, %f\n", x[0], x[1]);
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
value.bitwise = (h << 4) | l;
|
||||
return value.f4x2_array[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert vector of 2 fp32 to vector of 2 fp4 with sr
|
||||
inline __host__ __device__ f4x2_t f4_convert_sr_repro(float2_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
|
||||
return value.f4x2_array[0];
|
||||
@@ -981,13 +955,15 @@ inline __host__ __device__ float2_t type_convert<float2_t, f4x2_t>(f4x2_t x)
|
||||
} value{};
|
||||
value.f4x2_array[0] = x;
|
||||
float scale = 1.0f;
|
||||
return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
|
||||
float2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(value.bitwise, scale, 0);
|
||||
// permute high bits and low bits to match the order of the original vector
|
||||
return float2_t{tmp[1], tmp[0]};
|
||||
#else
|
||||
float2_t ret{
|
||||
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{})),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{})),
|
||||
utils::to_float<f4_t>(NumericLimits<e8m0_bexp_t>::Binary_1(),
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<0>{}))};
|
||||
x.template AsType<f4x2_pk_t>()[Number<0>{}].unpack<>(Number<1>{}))};
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -81,12 +81,6 @@ if(GPU_TARGETS MATCHES "gfx950")
|
||||
endif()
|
||||
add_dependencies(test_mx_data_types test_mx_fp4)
|
||||
|
||||
add_gtest_executable(test_mx_fp4_repro test_mx_fp4_repro.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_mx_fp4_repro PRIVATE utility)
|
||||
endif()
|
||||
add_dependencies(test_mx_data_types test_mx_fp4_repro)
|
||||
|
||||
add_gtest_executable(test_e8m0 test_e8m0.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_e8m0 PRIVATE utility)
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/utility/scaled_type_convert.hpp"
|
||||
|
||||
using ck::e8m0_bexp_t;
|
||||
using ck::float16_t;
|
||||
using ck::float2_t;
|
||||
using ck::float32_t;
|
||||
using ck::scaled_type_convert;
|
||||
using ck::type_convert;
|
||||
|
||||
using ck::f4_convert_rne;
|
||||
using ck::f4_convert_sr;
|
||||
using ck::f4_t;
|
||||
using ck::f4x16_t;
|
||||
using ck::f4x2_pk_t;
|
||||
using ck::f4x2_t;
|
||||
using ck::f4x32_t;
|
||||
|
||||
__host__ __device__ void test_mx_fp4_to_fp32(float* p_test)
|
||||
{
|
||||
/// Test vector conversions
|
||||
// f4x2 -> f32x2
|
||||
f4x2_t f4x2{f4x2_t::data_v{0b00011100}}; // 0b0001(=0.5) and 0b1100(=-2.0)
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
float2_t f32x2 = scaled_type_convert<float2_t>(scale2, f4x2);
|
||||
p_test[0] = f32x2[0];
|
||||
p_test[1] = f32x2[1];
|
||||
}
|
||||
|
||||
__global__ void run_test_mx_fp4_to_fp32(float* p_test) { test_mx_fp4_to_fp32(p_test); }
|
||||
|
||||
__host__ __device__ void test_mx_fp32_to_fp4_rne(float* p_test)
|
||||
{
|
||||
// f32x2 -> f4x2
|
||||
float2_t f32x2 = {1.0f, -4.0f};
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
f4x2_t f4x2 = f4_convert_rne(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
|
||||
|
||||
p_test[0] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
|
||||
p_test[1] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
|
||||
}
|
||||
|
||||
__global__ void run_test_mx_fp32_to_fp4_rne(float* p_test) { test_mx_fp32_to_fp4_rne(p_test); }
|
||||
|
||||
__host__ __device__ void test_mx_fp32_to_fp4_sr(float* p_test)
|
||||
{
|
||||
float2_t f32x2 = {1.0f, -4.0f};
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
f4x2_t f4x2 = f4_convert_sr(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
|
||||
|
||||
p_test[0] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
|
||||
p_test[1] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
|
||||
}
|
||||
|
||||
__global__ void run_test_mx_fp32_to_fp4_sr(float* p_test) { test_mx_fp32_to_fp4_sr(p_test); }
|
||||
|
||||
__host__ __device__ void test_mx_fp32_to_fp4_sr_failing(float* p_test)
|
||||
{
|
||||
float2_t f32x2 = {1.0f, -4.0f};
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
f4x2_t f4x2 = ck::f4_convert_sr_repro(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
|
||||
|
||||
p_test[0] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
|
||||
p_test[1] = type_convert<float>(
|
||||
f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
|
||||
}
|
||||
|
||||
__global__ void run_test_mx_fp32_to_fp4_sr_failing(float* p_test)
|
||||
{
|
||||
test_mx_fp32_to_fp4_sr_failing(p_test);
|
||||
}
|
||||
|
||||
TEST(MXFP4, FP4ToFP32)
|
||||
{
|
||||
std::vector<float> out(2, -1.0f);
|
||||
|
||||
DeviceMem device_out(2 * sizeof(float));
|
||||
|
||||
run_test_mx_fp4_to_fp32<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
|
||||
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
// f4x2 -> f32x2
|
||||
EXPECT_EQ(out[0], 1.0f);
|
||||
EXPECT_EQ(out[1], -4.0f);
|
||||
}
|
||||
|
||||
TEST(MXFP4, FP32ToFP4RNE)
|
||||
{
|
||||
std::vector<float> out(2, -1.0f);
|
||||
|
||||
DeviceMem device_out(2 * sizeof(float));
|
||||
|
||||
run_test_mx_fp32_to_fp4_rne<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
|
||||
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
// f32x2 -> f4x2
|
||||
// RNE
|
||||
EXPECT_EQ(out[0], 0.5f);
|
||||
EXPECT_EQ(out[1], -2.0f);
|
||||
}
|
||||
|
||||
TEST(MXFP4, FP32ToFP4SR)
|
||||
{
|
||||
std::vector<float> out(2, -1.0f);
|
||||
|
||||
DeviceMem device_out(2 * sizeof(float));
|
||||
|
||||
run_test_mx_fp32_to_fp4_sr<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
|
||||
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
// SR
|
||||
EXPECT_EQ(out[0], 0.5f);
|
||||
EXPECT_EQ(out[1], -2.0f);
|
||||
}
|
||||
|
||||
TEST(MXFP4, FP32ToFP4SRFailing)
|
||||
{
|
||||
std::vector<float> out(2, -1.0f);
|
||||
|
||||
DeviceMem device_out(2 * sizeof(float));
|
||||
|
||||
run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
|
||||
|
||||
device_out.FromDevice(out.data());
|
||||
|
||||
// SR
|
||||
EXPECT_EQ(out[0], 0.5f);
|
||||
EXPECT_EQ(out[1], -2.0f);
|
||||
}
|
||||
Reference in New Issue
Block a user