merge moe sorting

This commit is contained in:
coderfeli
2025-02-25 05:08:21 +00:00
parent d5b2c900b9
commit 7ca2d03e82
257 changed files with 11031 additions and 1958 deletions

View File

@@ -2,27 +2,38 @@
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/host/headers.hpp"
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <test.hpp>
#include <algorithm>
#include <cmath>
#include <iterator>
#include <numeric>
#include <random>
#include <test.hpp>
#include <rtc/compile_kernel.hpp>
#include <rtc/hip.hpp>
#include <fstream>
#include <unordered_set>
std::vector<rtc::src_file> get_headers_for_test()
inline std::vector<rtc::src_file> create_headers_for_test()
{
auto ck_headers = ck::host::GetHeaders();
std::vector<rtc::src_file> result;
auto hs = ck::host::GetHeaders();
std::transform(
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
return {p.first, p.second};
});
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [](auto& p) {
std::string content;
content.reserve(p.second.size() + 1);
content.push_back(' '); // We need a whitespace before the content for hipRTC to work
content.append(p.second.data(), p.second.size());
return rtc::src_file{p.first, std::move(content)};
});
return result;
}
inline const std::vector<rtc::src_file>& get_headers_for_test()
{
static const std::vector<rtc::src_file> headers = create_headers_for_test();
return headers;
}
template <typename V>
std::size_t GetSize(V mLens, V mStrides)
{
@@ -37,18 +48,24 @@ std::size_t GetSize(V mLens, V mStrides)
return space;
}
template <class T, typename V>
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
template <class T>
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
rtc::buffer<T> result(space);
rtc::buffer<T> result(n);
std::mt19937 gen(seed);
std::uniform_real_distribution<double> dis(-1.0);
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
// std::fill(result.begin(), result.end(), 1);
return result;
}
template <class T, typename V>
std::enable_if_t<!std::is_integral_v<V>, rtc::buffer<T>>
generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
{
std::size_t space = GetSize(mLens, mStrides);
return generate_buffer<T>(space, seed);
}
template <class T, class U>
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
{
@@ -57,7 +74,7 @@ bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
});
}
std::string classify(double x)
inline std::string classify(double x)
{
switch(std::fpclassify(x))
{