mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
topk_softmax (#1592)
* topk_softmax * remove some file * fix atomix linear_offset * address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include <random>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
@@ -41,6 +42,73 @@ struct FillUniformDistribution
|
||||
}
|
||||
};
|
||||
|
||||
namespace impl {
|
||||
|
||||
// clang-format off
|
||||
template<index_t bytes> struct RawIntegerType_ {};
|
||||
template<> struct RawIntegerType_<1> { using type = uint8_t;};
|
||||
template<> struct RawIntegerType_<2> { using type = uint16_t;};
|
||||
template<> struct RawIntegerType_<4> { using type = uint32_t;};
|
||||
template<> struct RawIntegerType_<8> { using type = uint64_t;};
|
||||
// clang-format on
|
||||
|
||||
template <typename T>
|
||||
using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
|
||||
} // namespace impl
|
||||
|
||||
// Note: this struct will have no const-ness will generate random
|
||||
template <typename T>
|
||||
struct FillUniformDistribution_Unique
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
std::optional<uint32_t> seed_{11939};
|
||||
|
||||
std::mt19937 gen_{};
|
||||
std::unordered_set<impl::RawIntegerType<T>> set_{};
|
||||
|
||||
FillUniformDistribution_Unique(float a = -5.f,
|
||||
float b = 5.f,
|
||||
std::optional<uint32_t> seed = {11939})
|
||||
: a_(a),
|
||||
b_(b),
|
||||
seed_(seed),
|
||||
gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
|
||||
set_{}
|
||||
{
|
||||
}
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last)
|
||||
{
|
||||
std::mt19937& gen = gen_;
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
auto& set = set_;
|
||||
std::generate(first, last, [&dis, &gen, &set]() {
|
||||
T v = static_cast<T>(0);
|
||||
do
|
||||
{
|
||||
v = ck_tile::type_convert<T>(dis(gen));
|
||||
} while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
|
||||
set.insert(bit_cast<impl::RawIntegerType<T>>(v));
|
||||
|
||||
return v;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename ForwardRange>
|
||||
auto operator()(ForwardRange&& range)
|
||||
-> std::void_t<decltype(std::declval<FillUniformDistribution_Unique&>()(
|
||||
std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range))))>
|
||||
{
|
||||
(*this)(std::begin(std::forward<ForwardRange>(range)),
|
||||
std::end(std::forward<ForwardRange>(range)));
|
||||
}
|
||||
|
||||
void clear() { set_.clear(); }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillNormalDistribution
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user