From e38b4a3327f9bed5266e1f50e30c16ebf08ca306 Mon Sep 17 00:00:00 2001
From: Rostyslav Geyyer
Date: Fri, 14 Feb 2025 21:31:40 +0000
Subject: [PATCH] Add a conversion for a repro test

Add f4_convert_sr_repro, a float2_t -> f4x2_t stochastic-rounding
conversion. On gfx950 it hands the two inputs to
__builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32 in swapped order
(float2_t{x[1], x[0]}); on other targets it converts x[0] / scale and
x[1] / scale separately and packs x[0] into the high nibble and x[1]
into the low nibble.

Add the MXFP4.FP32ToFP4SRFailing test, which runs the conversion in a
single-thread kernel on {1.0f, -4.0f} with a scale built from
e8m0_bexp_t(2.0f) and expects {0.5f, -2.0f} back on the host.

---
 include/ck/utility/type_convert.hpp  | 27 +++++++++++++++++++++
 test/data_type/test_mx_fp4_repro.cpp | 35 ++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+)

diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp
index 9c6df8d76b..0759ae8b34 100644
--- a/include/ck/utility/type_convert.hpp
+++ b/include/ck/utility/type_convert.hpp
@@ -978,6 +978,33 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f)
 #endif
 }
 
+// repro variant: convert 2 fp32 to 2 fp4 with sr, passing {x[1], x[0]} to the gfx950 builtin
+inline __host__ __device__ f4x2_t f4_convert_sr_repro(float2_t x, float scale = 1.0f)
+{
+    constexpr int seed = 1254739;
+    uint32_t rng       = prand_generator<float, seed>(reinterpret_cast<uintptr_t>(&x), x[0]);
+#if defined(__gfx950__)
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    value.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
+        value.bitwise, float2_t{x[1], x[0]}, rng, scale, 0);
+    return value.f4x2_array[0];
+#else
+    union
+    {
+        uint32_t bitwise;
+        f4x2_t f4x2_array[4];
+    } value{0};
+    uint8_t l     = utils::sat_convert_to_type_sr<f4_t>(x[1] / scale, rng);
+    uint8_t h     = utils::sat_convert_to_type_sr<f4_t>(x[0] / scale, rng);
+    value.bitwise = (h << 4) | l;
+    return value.f4x2_array[0];
+#endif
+}
+
 // convert vector of 32 fp32 to vector of 32 fp4 with sr
 inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f)
 {
diff --git a/test/data_type/test_mx_fp4_repro.cpp b/test/data_type/test_mx_fp4_repro.cpp
index 5210ca9dc9..a53c39ebdb 100644
--- a/test/data_type/test_mx_fp4_repro.cpp
+++ b/test/data_type/test_mx_fp4_repro.cpp
@@ -63,6 +63,23 @@ __host__ __device__ void test_mx_fp32_to_fp4_sr(float* p_test)
 
 __global__ void run_test_mx_fp32_to_fp4_sr(float* p_test) { test_mx_fp32_to_fp4_sr(p_test); }
 
+__host__ __device__ void test_mx_fp32_to_fp4_sr_failing(float* p_test)
+{
+    float2_t f32x2 = {1.0f, -4.0f};
+    auto scale2    = e8m0_bexp_t(2.0f);
+    f4x2_t f4x2    = ck::f4_convert_sr_repro(f32x2, type_convert<float>(scale2)); // expect {0.5, -2}
+
+    p_test[0] = type_convert<float>(
+        f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<0>{}))); // 0.5f
+    p_test[1] = type_convert<float>(
+        f4_t(f4x2.AsType<f4x2_pk_t>()(ck::Number<0>{}).unpack<>(ck::Number<1>{}))); // -2.0f
+}
+
+__global__ void run_test_mx_fp32_to_fp4_sr_failing(float* p_test)
+{
+    test_mx_fp32_to_fp4_sr_failing(p_test);
+}
+
 TEST(MXFP4, FP4ToFP32)
 {
     std::vector<float> out(2, -1.0f);
@@ -120,3 +137,21 @@ TEST(MXFP4, FP32ToFP4SR)
     EXPECT_EQ(out[0], 0.5f);
     EXPECT_EQ(out[1], -2.0f);
 }
+
+TEST(MXFP4, FP32ToFP4SRFailing)
+{
+    std::vector<float> out(2, -1.0f);
+
+    DeviceMem device_out(2 * sizeof(float));
+    // DeviceMem device_completed(sizeof(uint64_t));
+
+    run_test_mx_fp32_to_fp4_sr_failing<<<1, 1>>>(static_cast<float*>(device_out.GetDeviceBuffer()));
+
+    // uint64_t completed = 0;
+    // device_completed.FromDevice(&completed);
+    device_out.FromDevice(out.data());
+
+    // SR
+    EXPECT_EQ(out[0], 0.5f);
+    EXPECT_EQ(out[1], -2.0f);
+}