topk_softmax (#1592)

* topk_softmax

* remove some file

* fix atomix linear_offset

* address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
carlushuang
2024-10-26 23:52:49 +08:00
committed by GitHub
parent 31bf253aeb
commit b098b71b05
41 changed files with 5603 additions and 226 deletions

View File

@@ -623,7 +623,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] += x[i]; });
return y;
@@ -635,7 +635,7 @@ template <typename... Ys,
false>
CK_TILE_HOST_DEVICE constexpr auto operator-=(tuple<Ys...>& y, const X& x)
{
static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
static_assert(X::size() == sizeof...(Ys), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Ys);
static_for<0, NSize, 1>{}([&](auto i) { y[i] -= x[i]; });
return y;
@@ -647,7 +647,7 @@ template <typename... Xs,
false>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -655,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] + y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -669,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return r;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] - y[i]; }, number<NSize>{});
}
template <typename... Xs,
typename Y,
std::enable_if_t<!std::is_integral<Y>::value && !std::is_floating_point<Y>::value, bool> =
false>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const Y& y)
{
static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
static_assert(Y::size() == sizeof...(Xs), "wrong! size not the same");
constexpr index_t NSize = sizeof...(Xs);
tuple<Xs...> r;
@@ -706,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return a * x;
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, const tuple<Ys...>& y)
{
static_assert(sizeof...(Xs) == sizeof...(Ys), "wrong!");
constexpr index_t NSize = sizeof...(Xs);
return generate_tuple([&](auto i) { return x[i] * y[i]; }, number<NSize>{});
}
template <typename... Xs, typename... Ys>
CK_TILE_HOST_DEVICE constexpr auto operator/(const tuple<Xs...>& x, const tuple<Ys...>& y)
{