mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
topk_softmax (#1592)
* topk_softmax * remove some file * fix atomix linear_offset * address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
@@ -623,7 +623,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
|
||||
return y;
|
||||
@@ -635,7 +635,7 @@ template <typename... Ys,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
|
||||
{
|
||||
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
|
||||
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Ys);
|
||||
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
|
||||
return y;
|
||||
@@ -647,7 +647,7 @@ template <typename... Xs,
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
|
||||
return r;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs,
|
||||
typename Y,
|
||||
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
|
||||
false>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
|
||||
{
|
||||
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
|
||||
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
|
||||
tuple<Xs...> r;
|
||||
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
|
||||
return a * x;
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
|
||||
constexpr index_t NSize = sizeof...(Xs);
|
||||
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
|
||||
}
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user