mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Fix wrong value produced by saturating
This commit is contained in:
@@ -330,7 +330,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
|
||||
return ck_tile::compose(ck_tile::saturates<ck_tile::fp8_t>{}, ck_tile::scales{0.1f});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
@@ -69,7 +69,8 @@ struct FmhaDefaultElementFunctions
|
||||
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using PComputeElementFunction = ck_tile::scales<float>;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
|
||||
using OAccElementFunction =
|
||||
ck_tile::composer<ck_tile::saturates<ck_tile::fp8_t>, ck_tile::scales<float>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -42,18 +42,27 @@ CK_TILE_HOST auto compose(F... f)
|
||||
return composer<F...>(f...);
|
||||
}
|
||||
|
||||
// TODO: Overload numeric::min() and numeric::max()
|
||||
struct saturate_f8
|
||||
template <typename To>
|
||||
struct saturates
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const
|
||||
template <typename From>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
|
||||
-> std::enable_if_t<std::is_arithmetic_v<From>, From>
|
||||
{
|
||||
static_assert(std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t>,
|
||||
"Data type is not supported by this operation!");
|
||||
|
||||
T y = clamp(x, static_cast<T>(-448), static_cast<T>(448));
|
||||
return y;
|
||||
if constexpr(std::is_floating_point_v<To> || std::is_same_v<To, half_t> ||
|
||||
std::is_same_v<To, bfloat16_t> || std::is_same_v<To, fp8_t> ||
|
||||
std::is_same_v<To, bf8_t>)
|
||||
{
|
||||
return clamp(from,
|
||||
type_convert<From>(numeric<To>::lowest()),
|
||||
type_convert<From>(numeric<To>::max()));
|
||||
}
|
||||
else
|
||||
{
|
||||
return clamp(from,
|
||||
type_convert<From>(numeric<To>::min()),
|
||||
type_convert<From>(numeric<To>::max()));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user