mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
Re-use already-existing scales<> functor template
This commit is contained in:
@@ -323,14 +323,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto pcompute_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::scale{10.f};
|
||||
return ck_tile::scales{10.f};
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
auto oacc_element_func = [&]() {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp8_t>)
|
||||
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scale{0.1f});
|
||||
return ck_tile::compose(ck_tile::saturate_f8{}, ck_tile::scales{0.1f});
|
||||
else
|
||||
return ck_tile::identity{};
|
||||
}();
|
||||
|
||||
@@ -68,8 +68,8 @@ struct FmhaDefaultElementFunctions
|
||||
/// TODO: support specifying more elementwise functions for input/output tensors
|
||||
struct FmhaF8StaticQuantizationElementFunctions
|
||||
{
|
||||
using PComputeElementFunction = ck_tile::scale;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scale>;
|
||||
using PComputeElementFunction = ck_tile::scales<float>;
|
||||
using OAccElementFunction = ck_tile::composer<ck_tile::saturate_f8, ck_tile::scales<float>>;
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
Reference in New Issue
Block a user