mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Update test vector generator
This commit is contained in:
@@ -324,22 +324,23 @@ TEST(MXFP4, DeviceScaledConvert)
|
||||
EXPECT_EQ(test_size, i);
|
||||
}
|
||||
|
||||
__host__ __device__ float vec16_generator(ck::index_t i)
|
||||
__host__ __device__ float vec16_generator(ck::index_t i, float scale)
|
||||
{
|
||||
return type_convert<float>(f4_t(i & 0b00001111));
|
||||
return scale * type_convert<float>(f4_t(i & 0b00001111));
|
||||
}
|
||||
|
||||
__host__ __device__ float vec32_generator(ck::index_t i)
|
||||
__host__ __device__ float vec32_generator(ck::index_t i, float scale)
|
||||
{
|
||||
if(i < 16)
|
||||
{
|
||||
return vec16_generator(
|
||||
i); // all positive values, then all negative values in ascending order
|
||||
i, scale); // all positive values, then all negative values in ascending order
|
||||
}
|
||||
else
|
||||
{
|
||||
return vec16_generator(
|
||||
15 - (i % 16)); // all negative values, then all positive values in descending order
|
||||
15 - (i % 16),
|
||||
scale); // all negative values, then all positive values in descending order
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,8 +364,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert(float* p_test, uint64_t* p_
|
||||
|
||||
f4x32_t f4x32{};
|
||||
float32_t float32{};
|
||||
ck::static_for<0, N, 1>{}(
|
||||
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
float32[static_cast<int>(ii)] = vec32_generator(ii, type_convert<float>(scale2));
|
||||
});
|
||||
|
||||
f4x32 = f4_convert_rne(float32, type_convert<float>(scale2));
|
||||
|
||||
@@ -399,7 +401,8 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvert)
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
|
||||
EXPECT_EQ(out[i++],
|
||||
vec32_generator(ii, type_convert<float>(scale2)) / type_convert<float>(scale2))
|
||||
<< "ii: " << ii << std::endl;
|
||||
});
|
||||
|
||||
@@ -427,8 +430,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t*
|
||||
|
||||
f4x32_t f4x32{};
|
||||
float32_t float32{};
|
||||
ck::static_for<0, N, 1>{}(
|
||||
[&](auto ii) { float32[static_cast<int>(ii)] = vec32_generator(ii); });
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
float32[static_cast<int>(ii)] = vec32_generator(ii, type_convert<float>(scale2));
|
||||
});
|
||||
|
||||
f4x32 = f4_convert_sr(float32, type_convert<float>(scale2));
|
||||
|
||||
@@ -463,7 +467,8 @@ TEST(MXFP4, DeviceF32x32ToF4x32ScaledConvertSR)
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
|
||||
EXPECT_EQ(out[i++],
|
||||
vec32_generator(ii, type_convert<float>(scale2)) / type_convert<float>(scale2))
|
||||
<< "ii: " << ii << std::endl;
|
||||
});
|
||||
|
||||
@@ -493,8 +498,10 @@ __global__ void test_mx_f32x32_device_scaled_convert(float* p_test, uint64_t* p_
|
||||
float32_t float32{};
|
||||
ck::static_for<0, N / 2, 1>{}([&](auto ii) {
|
||||
f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{})
|
||||
.pack(type_convert<f4_t>(vec32_generator(2 * ii) / type_convert<float>(scale2)),
|
||||
type_convert<f4_t>(vec32_generator(2 * ii + 1) / type_convert<float>(scale2)));
|
||||
.pack(type_convert<f4_t>(vec32_generator(2 * ii, type_convert<float>(scale2)) /
|
||||
type_convert<float>(scale2)),
|
||||
type_convert<f4_t>(vec32_generator(2 * ii + 1, type_convert<float>(scale2)) /
|
||||
type_convert<float>(scale2)));
|
||||
});
|
||||
|
||||
float32 = scaled_type_convert<float32_t>(scale2, f4x32);
|
||||
@@ -525,7 +532,8 @@ TEST(MXFP4, DeviceF4x32ToF32x32ScaledConvert)
|
||||
auto scale2 = e8m0_bexp_t(2.0f);
|
||||
|
||||
ck::static_for<0, N, 1>{}([&](auto ii) {
|
||||
EXPECT_EQ(out[i++], vec32_generator(ii) / type_convert<float>(scale2))
|
||||
EXPECT_EQ(out[i++],
|
||||
vec32_generator(ii, type_convert<float>(scale2)) / type_convert<float>(scale2))
|
||||
<< "ii: " << ii << std::endl;
|
||||
});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user