Replace the float-only ck_tile::scale element functor with ck_tile::scales.

ck_tile::scale (unary_element_function.hpp) stored a float factor and only
defined operator() for float.  This patch removes it, turns ck_tile::scales
(math.hpp) into a functor that stores a runtime multiplier of any
copy-constructible type and multiplies it with whatever operand it is called
on, keeps the former compile-time-constant form under the new name scales_c,
and switches the FMHA fp8 example over to ck_tile::scales.

NOTE(review): this patch text was recovered from a copy whose line breaks were
collapsed and whose angle-bracketed template lists were missing.  Hunk line
counts were re-checked against the @@ headers, but the following pieces are
inferred rather than read from the damaged text -- confirm against the
upstream commit before applying:
  * std::is_same_v<DataType, ck_tile::fp8_t>           (fmha_fwd.cpp, twice)
  * ck_tile::scales<float> / composer<..., ...>         (fmha_fwd.hpp)
  * std::declval<const Scale&>()                        (math.hpp)
  * template <typename T> on the `struct plus` context  (math.hpp)
  * composer<F...>(f...) and operator()<float>          (unary_element_function.hpp)
  * all indentation and alignment whitespace

diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp
index 8e8e19f089..ccb960ad83 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.cpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.cpp
@@ -323,14 +323,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
 
     auto pcompute_element_func = [&]() {
         if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::scale{10.f};
+            return ck_tile::scales{10.f};
         else
             return ck_tile::identity{};
     }();
 
     auto oacc_element_func = [&]() {
         if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
-            return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
+            return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
         else
             return ck_tile::identity{};
     }();
diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp
index 54bcf2c1ee..8bdebf5769 100644
--- a/example/ck_tile/01_fmha/fmha_fwd.hpp
+++ b/example/ck_tile/01_fmha/fmha_fwd.hpp
@@ -68,8 +68,8 @@ struct FmhaDefaultElementFunctions
 /// TODO: support specifying more elementwise functions for input/output tensors
 struct FmhaF8StaticQuantizationElementFunctions
 {
-    using PComputeElementFunction = ck_tile::scale;
-    using OAccElementFunction     = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
+    using PComputeElementFunction = ck_tile::scales<float>;
+    using OAccElementFunction     = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
 };
 
 template <>
diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp
index 868fff05b7..bfc0c9c010 100644
--- a/include/ck_tile/core/numeric/math.hpp
+++ b/include/ck_tile/core/numeric/math.hpp
@@ -13,12 +13,37 @@
 
 namespace ck_tile {
 
-template <typename T, T s>
+template <typename Scale, Scale lhs>
+struct scales_c
+{
+    template <typename Right>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
+    {
+        return lhs * rhs;
+    }
+};
+
+template <typename Scale>
 struct scales
 {
-    CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
+    static_assert(std::is_copy_constructible_v<Scale>);
+
+    CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
+
+    template <typename Right>
+    CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
+        -> decltype(std::declval<const Scale&>() * rhs)
+    {
+        return lhs_ * rhs;
+    }
+
+    private:
+    Scale lhs_;
 };
 
+template <typename Scale>
+__host__ __device__ scales(Scale)->scales<Scale>;
+
 template <typename T>
 struct plus
 {
diff --git a/include/ck_tile/core/utility/unary_element_function.hpp b/include/ck_tile/core/utility/unary_element_function.hpp
index 7eb77c01da..85a3acd450 100644
--- a/include/ck_tile/core/utility/unary_element_function.hpp
+++ b/include/ck_tile/core/utility/unary_element_function.hpp
@@ -42,24 +42,6 @@ CK_TILE_HOST auto compose(F... f)
     return composer<F...>(f...);
 }
 
-// start of unary element function
-
-struct scale
-{
-    CK_TILE_HOST_DEVICE scale(float factor) : factor_(factor) {}
-
-    template <typename T>
-    CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const;
-
-    template <>
-    CK_TILE_HOST_DEVICE constexpr float operator()<float>(const float& x) const
-    {
-        return factor_ * x;
-    };
-
-    float factor_;
-};
-
 // TODO: Overload numeric::min() and numeric::max()
 struct saturate_f8
 {