Re-use already-existing scales<> functor template

This commit is contained in:
Po Yen Chen
2024-04-09 08:06:38 +00:00
parent ad45cf8613
commit 5d0ebdbfe4
4 changed files with 31 additions and 24 deletions

View File

@@ -323,14 +323,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto pcompute_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::scale{10.f};
return ck_tile::scales{10.f};
else
return ck_tile::identity{};
}();
auto oacc_element_func = [&]() {
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
else
return ck_tile::identity{};
}();

View File

@@ -68,8 +68,8 @@ struct FmhaDefaultElementFunctions
/// TODO: support specifying more elementwise functions for input/output tensors
struct FmhaF8StaticQuantizationElementFunctions
{
using PComputeElementFunction = ck_tile::scale;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
using PComputeElementFunction = ck_tile::scales<float>;
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
};
template <>