diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index ccb960ad83..6126cc9059 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -330,7 +330,7 @@ bool run(const ck_tile::ArgParser& arg_parser) auto oacc_element_func = [&]() { if constexpr(std::is_same_v) - return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f}); + return ck_tile::compose(ck_tile::saturates{}, ck_tile::scales{0.1f}); else return ck_tile::identity{}; }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 8bdebf5769..6f21b9bffc 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -69,7 +69,8 @@ struct FmhaDefaultElementFunctions struct FmhaF8StaticQuantizationElementFunctions { using PComputeElementFunction = ck_tile::scales; - using OAccElementFunction = ck_tile::composer>; + using OAccElementFunction = + ck_tile::composer, ck_tile::scales>; }; template <> diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp index 85a3acd450..199d225ba4 100644 --- a/include/ck_tile/core/utility/unary_element_function.hpp +++ b/include/ck_tile/core/utility/unary_element_function.hpp @@ -42,18 +42,27 @@ CK_TILE_HOST auto compose(F... f) return composer(f...); } -// TODO: Overload numeric::min() and numeric::max() -struct saturate_f8 +template +struct saturates { - template - CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const + template + CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const + -> std::enable_if_t, From> { - static_assert(std::is_same_v || std::is_same_v || - std::is_same_v, - "Data type is not supported by this operation!"); - - T y = clamp(x, static_cast(-448), static_cast(448)); - return y; + if constexpr(std::is_floating_point_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) + { + return clamp(from, + type_convert(numeric::lowest()), + type_convert(numeric::max())); + } + else + { + return clamp(from, + type_convert(numeric::min()), + type_convert(numeric::max())); + } } };