Fix vector SR (stochastic rounding) conversion

This commit is contained in:
Rostyslav Geyyer
2025-02-18 19:56:20 +00:00
parent e323d613ff
commit 7daf21081e
2 changed files with 3 additions and 2 deletions

View File

@@ -1022,6 +1022,7 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f
float2_t floatx2_array[16];
float32_t floatx32_array;
} float_values{{0}};
+float_values.floatx32_array = x;
// TODO: pack in a loop
tmp_values.bitwise = __builtin_amdgcn_cvt_scalef32_sr_pk_fp4_f32(
tmp_values.bitwise, float_values.floatx2_array[0], rng, scale, 0);

View File

@@ -438,9 +438,9 @@ __global__ void test_mx_fp4x32_device_scaled_convert_sr(float* p_test, uint64_t*
ck::static_for<0, N / 2, 1>{}([&](auto ii) {
p_test[i++] = type_convert<float>(
-f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<0>{}));
+f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<0>{})));
p_test[i++] = type_convert<float>(
-f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<1>{}));
+f4_t(f4x32.AsType<f4x2_pk_t>()(ck::Number<ii>{}).template unpack<>(ck::Number<1>{})));
});
}