mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Re-use already-existing scales<> functor template
This commit is contained in:
@@ -13,12 +13,37 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, T s>
|
||||
template <typename Scale, Scale lhs>
|
||||
struct scales_c
|
||||
{
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const -> decltype(lhs * rhs)
|
||||
{
|
||||
return lhs * rhs;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
struct scales
|
||||
{
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const { return s * a; }
|
||||
static_assert(std::is_copy_constructible_v<Scale>);
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr explicit scales(Scale lhs) : lhs_(lhs) {}
|
||||
|
||||
template <typename Right>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(const Right& rhs) const
|
||||
-> decltype(std::declval<const Scale&>() * rhs)
|
||||
{
|
||||
return lhs_ * rhs;
|
||||
}
|
||||
|
||||
private:
|
||||
Scale lhs_;
|
||||
};
|
||||
|
||||
template <typename Scale>
|
||||
__host__ __device__ scales(Scale)->scales<Scale>;
|
||||
|
||||
template <typename Left = void, typename Right = Left>
|
||||
struct plus
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user