Add FP4 MX MFMA tests (#2151)

* Add conversion tests

* Fix ctor

* Fix nan logic

* Fix conversion logic

* Permute packed f4_t values

* Fix conversion to float, repack vector elements

* Fix device tests

* Permute elements in a vector

* Add a repro test

* Add a conversion for a repro test

* Update test vectors

* Update conversion

* Fix the test

* Update test vector generator

* Fix vector sr conversion

* Permute conversion args

* Update conversion

* Test

* Fix packing

* Simplify conversion function

* Pack conversion in a loop

* Pack conversion in a loop

* Pack another conversion in a loop

* Pack one more conversion in a loop

* Pack the last conversion in a loop

* Clean up

* Add ops

* Add tests

* Add missing utils

* Update reference mx gemm

* Add f4x2 init mode

* Update host tensor utils

* Update chunk size for f4x2

* Add non scaled ops

* Add a type utility

* Update non scaled reference kernel

* Add non scaled tests

* Debug mfma arguments

* Add more debug info

* Update chunk size

* Update data layout

* Add more debugging

* Fix B stride

* Fix reference gemm

* Fix build

* One more reference fix

* Add more debug info

* Disable some tests

* Enable tests

* Add fp4 dimensions

* Update reference kernels

* Temp edits

* Remove leftovers

* Fix conflicts

* Clean up

* More clean up

* Revert "More clean up"

This reverts commit d8d35a0846.

* Add layouts to tests

---------

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
This commit is contained in:
Rostyslav Geyyer
2025-05-06 09:24:00 -05:00
committed by GitHub
parent 4e9b76f88c
commit 8a0d659f92
8 changed files with 610 additions and 79 deletions

View File

@@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t> ||
std::is_same_v<RangeType, ck::f4x2_pk_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
@@ -359,7 +360,8 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
@@ -514,7 +516,8 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
@@ -527,7 +530,8 @@ struct Tensor
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -540,7 +544,8 @@ struct Tensor
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -552,7 +557,8 @@ struct Tensor
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
@@ -564,7 +570,8 @@ struct Tensor
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -81,6 +81,18 @@ struct GeneratorTensor_1<ck::f4_t>
}
};
/// Constant generator specialization for the packed FP4 pair type ck::f4x2_pk_t.
template <>
struct GeneratorTensor_1<ck::f4x2_pk_t>
{
    // Float value that is converted and written to both halves of each generated element.
    float value = 1.0;

    /// Ignores the (unnamed) index arguments and returns a ck::f4x2_pk_t built from
    /// a ck::float2_t holding `value` twice, converted to ck::f4x2_t.
    template <typename... Is>
    ck::f4x2_pk_t operator()(Is...)
    {
        ck::float2_t both_lanes{value, value};
        const auto converted = ck::type_convert<ck::f4x2_t>(both_lanes);
        return ck::f4x2_pk_t{converted};
    }
};
template <>
struct GeneratorTensor_1<int8_t>
{
@@ -209,6 +221,21 @@ struct GeneratorTensor_2<ck::f4_t>
}
};
/// Random-integer generator specialization for the packed FP4 pair type ck::f4x2_pk_t.
template <>
struct GeneratorTensor_2<ck::f4x2_pk_t>
{
    // Bounds used for the draw: each draw is std::rand() % (max_value - min_value) + min_value.
    // NOTE(review): max_value == min_value makes the modulo divisor zero — callers
    // presumably always pass max_value > min_value; confirm.
    int min_value = 0;
    int max_value = 1;

    /// Ignores the (unnamed) index arguments. Draws two integers with std::rand(),
    /// stores them as floats in a ck::float2_t (first draw first), converts that to
    /// ck::f4x2_t and returns it wrapped in a ck::f4x2_pk_t.
    template <typename... Is>
    ck::f4x2_pk_t operator()(Is...)
    {
        // One integer draw, widened to float; invoked twice in sequence so the
        // order of std::rand() calls matches element order.
        const auto draw = [this]() -> float {
            return static_cast<float>((std::rand() % (max_value - min_value)) + min_value);
        };

        const float first  = draw();
        const float second = draw();

        ck::float2_t pair{first, second};
        return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(pair)};
    }
};
template <typename T>
struct GeneratorTensor_3
{
@@ -296,6 +323,25 @@ struct GeneratorTensor_3<ck::f4_t>
}
};
/// Random-float generator specialization for the packed FP4 pair type ck::f4x2_pk_t.
template <>
struct GeneratorTensor_3<ck::f4x2_pk_t>
{
    // Endpoints of the interval the float samples are scaled into before conversion.
    float min_value = 0;
    float max_value = 1;

    /// Ignores the (unnamed) index arguments. Takes two std::rand() samples, scales
    /// each as min_value + (rand / RAND_MAX) * (max_value - min_value), packs them
    /// into a ck::float2_t (first sample first), converts to ck::f4x2_t and returns
    /// the result wrapped in a ck::f4x2_pk_t.
    template <typename... Is>
    ck::f4x2_pk_t operator()(Is...)
    {
        // Both raw samples are taken before any scaling, first lane first.
        const float unit_first  = float(std::rand()) / float(RAND_MAX);
        const float unit_second = float(std::rand()) / float(RAND_MAX);

        // Affine map of a unit sample onto the configured interval.
        const auto rescale = [this](float unit) -> float {
            return min_value + unit * (max_value - min_value);
        };

        ck::float2_t scaled_pair{rescale(unit_first), rescale(unit_second)};
        return ck::f4x2_pk_t{ck::type_convert<ck::f4x2_t>(scaled_pair)};
    }
};
template <typename T>
struct GeneratorTensor_4
{