mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Re-use already-existing scales<> functor template
This commit is contained in:
@@ -323,14 +323,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto pcompute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::scale{10.f};
|
||||
return ck_tile::scales{10.f};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
|
||||
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
@@ -68,8 +68,8 @@ struct FmhaDefaultElementFunctions
|
||||
/// TODO: support specifying more elementwise functions for input/output tensors
|
||||
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using PComputeElementFunction = ck_tile::scale;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
|
||||
using PComputeElementFunction = ck_tile::scales<float>;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
@@ -13,12 +13,37 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, T s>
|
||||
template <typename Scale, Scale lhs>
|
||||
struct scales_c
|
||||
{
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
struct scales
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
|
||||
static_assert(std::is_copy_constructible_v<Scale>);
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
|
||||
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
|
||||
-> decltype(std::declval<const Scale&>() * rhs)
|
||||
{
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
|
||||
private:
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale)->scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
{
|
||||
|
||||
@@ -42,24 +42,6 @@ CK_TILE_HOST auto compose(F... f)
|
||||
return composer<F...>(f...);
|
||||
}
|
||||
|
||||
// start of unary element function
|
||||
|
||||
struct scale
|
||||
{
|
||||
CK_TILE_HOST_DEVICE scale(float factor) : factor_(factor) {}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const;
|
||||
|
||||
template <>
|
||||
CK_TILE_HOST_DEVICE constexpr float operator()<float>(const float& x) const
|
||||
{
|
||||
return factor_ * x;
|
||||
};
|
||||
|
||||
float factor_;
|
||||
};
|
||||
|
||||
// TODO: Overload numeric::min() and numeric::max()
|
||||
struct saturate_f8
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user