mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,16 +19,14 @@ struct Add
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -46,16 +45,14 @@ struct SquareAdd
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -66,48 +63,74 @@ struct SquareAdd
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, x);
|
||||
if(x > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::zero();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, abs(x));
|
||||
if(abs(x) > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
|
||||
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
|
||||
struct AccumulateWithIndex
|
||||
{
|
||||
template <typename ReduceOp, typename T, typename IndexType>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
|
||||
T& current_value,
|
||||
IndexType& current_index,
|
||||
const T& new_value,
|
||||
const IndexType& new_index) const
|
||||
{
|
||||
bool changed = false;
|
||||
current_value = reduce_func(current_value, new_value, changed);
|
||||
|
||||
if(changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
else if(new_index < current_index)
|
||||
{
|
||||
bool reverse_changed = false;
|
||||
reduce_func(new_value, current_value, reverse_changed);
|
||||
|
||||
if(!reverse_changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Accumulate
|
||||
{
|
||||
template <typename ReduceOp, typename T>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
|
||||
{
|
||||
current_value = reduce_func(current_value, new_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user