mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add FP4 MX MFMA tests (#2151)
* Add conversion tests
* Fix ctor
* Fix nan logic
* Fix conversion logic
* Permute packed f4_t values
* Fix conversion to float, repack vector elements
* Fix device tests
* Permute elements in a vector
* Add a repro test
* Add a conversion for a repro test
* Update test vectors
* Update conversion
* Fix the test
* Update test vector generator
* Fix vector sr conversion
* Permute conversion args
* Update conversion
* Test
* Fix packing
* Simplify conversion function
* Pack conversion in a loop
* Pack conversion in a loop
* Pack another conversion in a loop
* Pack one more conversion in a loop
* Pack the last conversion in a loop
* Clean up
* Add ops
* Add tests
* Add missing utils
* Update reference mx gemm
* Add f4x2 init mode
* Update host tensor utils
* Update chunk size for f4x2
* Add non scaled ops
* Add a type utility
* Update non scaled reference kernel
* Add non scaled tests
* Debug mfma arguments
* Add more debug info
* Update chunk size
* Update data layout
* Add more debugging
* Fix B stride
* Fix reference gemm
* Fix build
* One more reference fix
* Add more debug info
* Disable some tests
* Enable tests
* Add fp4 dimensions
* Update reference kernels
* Temp edits
* Remove leftovers
* Fix conflicts
* Clean up
* More clean up
* Revert "More clean up"
This reverts commit d8d35a0846.
* Add layouts to tests
---------
Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
This commit is contained in:
@@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
i4 = i4 - 8;
|
||||
v_a = type_convert<ComputeTypeA>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_a = type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_m_k_(m, k));
|
||||
@@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator
|
||||
i4 = i4 - 8;
|
||||
v_b = type_convert<ComputeTypeB>(i4);
|
||||
}
|
||||
else if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{})));
|
||||
else
|
||||
v_b = type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{})));
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_k_n_(k, n));
|
||||
|
||||
@@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for ColMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
else
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(
|
||||
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeA>(
|
||||
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_scaled(m, k) =
|
||||
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
|
||||
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator
|
||||
{
|
||||
for(size_t k = 0; k < K; k++)
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
|
||||
{
|
||||
// TODO: add support for RowMajor layout as well
|
||||
if(k % 2 == 1)
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
else
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(
|
||||
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
|
||||
type_convert<ComputeTypeB>(
|
||||
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_scaled(k, n) =
|
||||
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
|
||||
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user