Re-use already-existing scales<> functor template

This commit is contained in:
Po Yen Chen
2024-04-09 08:06:38 +00:00
parent ad45cf8613
commit 5d0ebdbfe4
4 changed files with 31 additions and 24 deletions

View File

@@ -323,14 +323,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto pcompute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scale{10.f};
return ck_tile::scales{10.f};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
else
return ck_tile::identity{};
}();

View File

@@ -68,8 +68,8 @@ struct FmhaDefaultElementFunctions
/// TODO: support specifying more elementwise functions for input/output tensors
struct FmhaF8StaticQuantizationElementFunctions
{
using PComputeElementFunction = ck_tile::scale;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
using PComputeElementFunction = ck_tile::scales<float>;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
};
template <>

View File

@@ -13,12 +13,37 @@
namespace ck_tile {
template <typename T, T s>
template <typename Scale, Scale lhs>
struct scales_c
{
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
{
return lhs * rhs;
}
};
template <typename Scale>
struct scales
{
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
static_assert(std::is_copy_constructible_v<Scale>);
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
template <typename Right>
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
-> decltype(std::declval<const Scale&>() * rhs)
{
return lhs_ * rhs;
}
private:
Scale lhs_;
};
template <typename Scale>
__host__ __device__ scales(Scale)->scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
{

View File

@@ -42,24 +42,6 @@ CK_TILE_HOST auto compose(F... f)
return composer<F...>(f...);
}
// start of unary element function
struct scale
{
CK_TILE_HOST_DEVICE scale(float factor) : factor_(factor) {}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T operator()(const T& x) const;
template <>
CK_TILE_HOST_DEVICE constexpr float operator()<float>(const float& x) const
{
return factor_ * x;
};
float factor_;
};
// TODO: Overload numeric::min() and numeric::max()
struct saturate_f8
{