mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add a conversion for a repro test
This commit is contained in:
@@ -978,6 +978,33 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert vector of 2 fp32 to vector of 2 fp4 with sr
|
||||
inline __host__ __device__ f4x2_t f4_convert_sr_repro(float2_t x, float scale = 1.0f)
|
||||
{
|
||||
constexpr int seed = 1254739;
|
||||
uint32_t rng = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
|
||||
#if defined(__gfx950__)
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
|
||||
value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
|
||||
return value.f4x2_array[0];
|
||||
#else
|
||||
union
|
||||
{
|
||||
uint32_t bitwise;
|
||||
f4x2_t f4x2_array[4];
|
||||
} value{0};
|
||||
uint8_t l = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
|
||||
uint8_t h = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
|
||||
value.bitwise = (h << 4) | l;
|
||||
return value.f4x2_array[0];
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert vector of 32 fp32 to vector of 32 fp4 with sr
|
||||
inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
|
||||
{
|
||||
|
||||
@@ -63,6 +63,23 @@ __host__ __device__ void test_mx_fp32_to_fp4_sr(float* p_test)
|
||||
|
||||
// Kernel entry point that forwards to test_mx_fp32_to_fp4_sr so the same body can run on device.
__global__ void run_test_mx_fp32_to_fp4_sr(float* p_test)
{
    test_mx_fp32_to_fp4_sr(p_test);
}
|
||||
|
||||
// Repro body: converts {1.0f, -4.0f} to packed fp4 via ck::f4_convert_sr_repro using a scale
// built from e8m0_bexp_t(2.0f), then writes the two unpacked values as floats to p_test[0]
// and p_test[1]. p_test must point to at least two writable floats.
__host__ __device__ void test_mx_fp32_to_fp4_sr_failing(float* p_test)
{
    auto scale2       = e8m0_bexp_t(2.0f);
    const float scale = type_convert<float>(scale2);
    float2_t f32x2    = {1.0f, -4.0f};

    // expect {0.5, -2}
    f4x2_t f4x2 = ck::f4_convert_sr_repro(f32x2, scale);

    // Element 0 of the packed pair, expected 0.5f.
    p_test[0] = type_convert<float>(
        f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{})));

    // Element 1 of the packed pair, expected -2.0f.
    p_test[1] = type_convert<float>(
        f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{})));
}
|
||||
|
||||
// Kernel entry point that forwards to test_mx_fp32_to_fp4_sr_failing.
__global__ void run_test_mx_fp32_to_fp4_sr_failing(float* p_test) { test_mx_fp32_to_fp4_sr_failing(p_test); }
|
||||
|
||||
TEST(MXFP4, FP4ToFP32)
|
||||
{
|
||||
std::vector<float> out(2, -1.0f);
|
||||
@@ -120,3 +137,21 @@ TEST(MXFP4, FP32ToFP4SR)
|
||||
EXPECT_EQ(out[0], 0.5f);
|
||||
EXPECT_EQ(out[1], -2.0f);
|
||||
}
|
||||
|
||||
// Runs the SR repro kernel with a single thread, copies the two results back to the host
// and compares them against the expected values 0.5f and -2.0f.
TEST(MXFP4, FP32ToFP4SRFailing)
{
    constexpr std::size_t kNumValues = 2;

    // Host buffer pre-filled with a sentinel so untouched entries are easy to spot.
    std::vector<float> out(kNumValues, -1.0f);

    DeviceMem device_out(kNumValues * sizeof(float));
    float* p_device_out = static_cast<float*>(device_out.GetDeviceBuffer());

    run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(p_device_out);

    device_out.FromDevice(out.data());

    // SR
    EXPECT_EQ(out[0], 0.5f);
    EXPECT_EQ(out[1], -2.0f);
}
|
||||
|
||||
Reference in New Issue
Block a user