Add FP4 MX MFMA tests (#2151)

* Add conversion tests

* Fix ctor

* Fix nan logic

* Fix conversion logic

* Permute packed f4_t values

* Fix conversion to float, repack vector elements

* Fix device tests

* Permute elements in a vector

* Add a repro test

* Add a conversion for a repro test

* Update test vectors

* Update conversion

* Fix the test

* Update test vector generator

* Fix vector sr conversion

* Permute conversion args

* Update conversion

* Test

* Fix packing

* Simplify conversion function

* Pack conversion in a loop

* Pack conversion in a loop

* Pack another conversion in a loop

* Pack one more conversion in a loop

* Pack the last conversion in a loop

* Clean up

* Add ops

* Add tests

* Add missing utils

* Update reference mx gemm

* Add f4x2 init mode

* Update host tensor utils

* Update chunk size for f4x2

* Add non scaled ops

* Add a type utility

* Update non scaled reference kernel

* Add non scaled tests

* Debug mfma arguments

* Add more debug info

* Update chunk size

* Update data layout

* Add more debugging

* Fix B stride

* Fix reference gemm

* Fix build

* One more reference fix

* Add more debug info

* Disable some tests

* Enable tests

* Add fp4 dimensions

* Update reference kernels

* Temp edits

* Remove leftovers

* Fix conflicts

* Clean up

* More clean up

* Revert "More clean up"

This reverts commit d8d35a0846.

* Add layouts to tests

---------

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
This commit is contained in:
Rostyslav Geyyer
2025-05-06 09:24:00 -05:00
committed by GitHub
parent 4e9b76f88c
commit 8a0d659f92
8 changed files with 610 additions and 79 deletions

View File

@@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator
i4 = i4 - 8;
v_a = type_convert<ComputeTypeA>(i4);
}
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
else
v_a = type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
}
else
{
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
@@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator
i4 = i4 - 8;
v_b = type_convert<ComputeTypeB>(i4);
}
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
else
v_b = type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
}
else
{
arg.b_element_op_(v_b, arg.b_k_n_(k, n));

View File

@@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator
{
for(size_t k = 0; k < K; k++)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
else
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
}
}
@@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator
{
for(size_t k = 0; k < K; k++)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
else
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
}
}