mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Implement hiprtc for codegen tests
This commit is contained in:
@@ -1,22 +1,29 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(composable_kernel_host LANGUAGES CXX HIP)
|
||||
find_package(ROCM)
|
||||
include(ROCMInstallTargets)
|
||||
include(ROCMTest)
|
||||
|
||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||
|
||||
|
||||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
|
||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
||||
set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..)
|
||||
|
||||
|
||||
add_compile_options(-std=c++17)
|
||||
find_package(hip)
|
||||
add_custom_target(codegen)
|
||||
|
||||
|
||||
# add include directories
|
||||
include_directories(BEFORE
|
||||
${PROJECT_BINARY_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/include
|
||||
${PROJECT_SOURCE_DIR}/library/include
|
||||
${HIP_INCLUDE_DIRS}
|
||||
${CK_ROOT}/include/
|
||||
)
|
||||
|
||||
|
||||
list(APPEND CMAKE_MODULE_PATH ${CK_ROOT}/cmake)
|
||||
include(Embed)
|
||||
file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
@@ -25,29 +32,31 @@ file(GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
|
||||
#message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
|
||||
#message(STATUS "RELATIVE: ${CK_ROOT}/include")
|
||||
add_embed_library(ck_headers ${KERNEL_FILES} RELATIVE ${CK_ROOT}/include)
|
||||
|
||||
|
||||
file(GLOB SOURCES CONFIGURE_DEPENDS src/*.cpp)
|
||||
|
||||
|
||||
##message(STATUS "SOURCE_FILES: ${SOURCES}")
|
||||
# TODO: Use object library
|
||||
add_library(ck_host STATIC ${SOURCES})
|
||||
target_link_libraries(ck_host PRIVATE ck_headers)
|
||||
|
||||
|
||||
set_target_properties(ck_host PROPERTIES
|
||||
LINKER_LANGUAGE CXX
|
||||
POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
|
||||
target_include_directories(ck_host PUBLIC
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/solution_instances>
|
||||
$<BUILD_INTERFACE:${CMAKE_CURRENT_BINARY_DIR}/embed/ck_headers/include>
|
||||
)
|
||||
|
||||
|
||||
add_executable(ck-template-driver driver/main.cpp)
|
||||
target_link_libraries(ck-template-driver ck_host)
|
||||
|
||||
|
||||
rocm_install(
|
||||
TARGETS ck_host ck_headers
|
||||
EXPORT ck_hostTargets
|
||||
)
|
||||
rocm_install(DIRECTORY include/ck DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
|
||||
|
||||
add_subdirectory(test)
|
||||
|
||||
add_subdirectory(test)
|
||||
@@ -100,5 +100,33 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f)
|
||||
return result;
|
||||
}
|
||||
|
||||
inline bool StartsWith(const std::string& value, const std::string& prefix)
|
||||
{
|
||||
if(prefix.size() > value.size())
|
||||
return false;
|
||||
else
|
||||
return std::equal(prefix.begin(), prefix.end(), value.begin());
|
||||
}
|
||||
|
||||
inline bool EndsWith(const std::string& value, const std::string& suffix)
|
||||
{
|
||||
if(suffix.size() > value.size())
|
||||
return false;
|
||||
else
|
||||
return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin());
|
||||
}
|
||||
|
||||
inline std::vector<std::string> SplitString(const std::string& s, char delim)
|
||||
{
|
||||
std::vector<std::string> elems;
|
||||
std::stringstream ss(s + delim);
|
||||
std::string item;
|
||||
while(std::getline(ss, item, delim))
|
||||
{
|
||||
elems.push_back(item);
|
||||
}
|
||||
return elems;
|
||||
}
|
||||
|
||||
} // namespace host
|
||||
} // namespace ck
|
||||
|
||||
@@ -8,18 +8,73 @@
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
#include <rtc/hip.hpp>
|
||||
#include <fstream>
|
||||
#include <unordered_set>
|
||||
#include "ck/host/headers.hpp"
|
||||
#include "rtc/hiprtc_enable_env.hpp"
|
||||
#include "ck/host/stringutils.hpp"
|
||||
|
||||
std::vector<rtc::src_file> get_headers_for_test()
|
||||
// NOLINTNEXTLINE
|
||||
const char* const disable_warning_pragma = R"__migraphx__(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
${content}
|
||||
#pragma clang diagnostic pop
|
||||
)__migraphx__";
|
||||
|
||||
template <class P>
|
||||
inline std::string ck_disable_warnings(P p)
|
||||
{
|
||||
return ck::host::InterpolateString(disable_warning_pragma,
|
||||
{{"content", std::string{p.data(), p.size()}}});
|
||||
}
|
||||
|
||||
inline std::vector<rtc::src_file> create_headers_for_hiprtc_test()
|
||||
{
|
||||
auto ck_headers = ck::host::GetHeaders();
|
||||
std::vector<rtc::src_file> result;
|
||||
|
||||
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
|
||||
return rtc::src_file{p.first, ck_disable_warnings(p.second)};
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_hiprtc_test()
|
||||
{
|
||||
static const std::vector<rtc::src_file> headers = create_headers_for_hiprtc_test();
|
||||
return headers;
|
||||
}
|
||||
|
||||
inline std::vector<rtc::src_file> create_headers_for_clang_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
return {p.first, {p.second.begin(), p.second.end()}};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_clang_test()
|
||||
{
|
||||
static const std::vector<rtc::src_file> headers = create_headers_for_clang_test();
|
||||
return headers;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_test()
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
|
||||
{
|
||||
return get_headers_for_hiprtc_test();
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_headers_for_clang_test();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename V>
|
||||
std::size_t GetSize(V mLens, V mStrides)
|
||||
{
|
||||
@@ -34,18 +89,24 @@ std::size_t GetSize(V mLens, V mStrides)
|
||||
return space;
|
||||
}
|
||||
|
||||
template <class T, typename V>
|
||||
rtc::buffer<T> generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
|
||||
template <class T>
|
||||
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
|
||||
{
|
||||
std::size_t space = GetSize(mLens, mStrides);
|
||||
rtc::buffer<T> result(space);
|
||||
rtc::buffer<T> result(n);
|
||||
std::mt19937 gen(seed);
|
||||
std::uniform_real_distribution<double> dis(-1.0);
|
||||
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
|
||||
// std::fill(result.begin(), result.end(), 1);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, typename V>
|
||||
std::enable_if_t<!std::is_integral_v<V>, rtc::buffer<T>>
|
||||
generate_buffer(V mLens, V mStrides, std::size_t seed = 0)
|
||||
{
|
||||
std::size_t space = GetSize(mLens, mStrides);
|
||||
return generate_buffer<T>(space, seed);
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
|
||||
{
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "common.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/problem.hpp"
|
||||
#include "ck/host/device_gemm_multiple_d/operation.hpp"
|
||||
#include "ck/host/headers.hpp"
|
||||
@@ -15,116 +16,6 @@
|
||||
using half = _Float16;
|
||||
// using half = __fp16;
|
||||
|
||||
std::vector<rtc::src_file> get_headers_for_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, p.second};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
rtc::buffer<T> generate_buffer(std::size_t n, std::size_t seed = 0)
|
||||
{
|
||||
rtc::buffer<T> result(n);
|
||||
std::mt19937 gen(seed);
|
||||
std::uniform_real_distribution<double> dis(-1.0);
|
||||
std::generate(result.begin(), result.end(), [&] { return dis(gen); });
|
||||
return result;
|
||||
}
|
||||
|
||||
template <class T, class U>
|
||||
bool allclose(const T& a, const U& b, double atol = 0.01, double rtol = 0.01)
|
||||
{
|
||||
return std::equal(a.begin(), a.end(), b.begin(), b.end(), [&](double x, double y) {
|
||||
return fabs(x - y) < atol + rtol * fabs(y);
|
||||
});
|
||||
}
|
||||
|
||||
std::string classify(double x)
|
||||
{
|
||||
switch(std::fpclassify(x))
|
||||
{
|
||||
case FP_INFINITE: return "inf";
|
||||
case FP_NAN: return "nan";
|
||||
case FP_NORMAL: return "normal";
|
||||
case FP_SUBNORMAL: return "subnormal";
|
||||
case FP_ZERO: return "zero";
|
||||
default: return "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_classification(const Buffer& x)
|
||||
{
|
||||
std::unordered_set<std::string> result;
|
||||
for(const auto& i : x)
|
||||
result.insert(classify(i));
|
||||
for(const auto& c : result)
|
||||
std::cout << c << ", ";
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_statistics(const Buffer& x)
|
||||
{
|
||||
std::cout << "Min value: " << *std::min_element(x.begin(), x.end()) << ", ";
|
||||
std::cout << "Max value: " << *std::max_element(x.begin(), x.end()) << ", ";
|
||||
double num_elements = x.size();
|
||||
auto mean =
|
||||
std::accumulate(x.begin(), x.end(), double{0.0}, std::plus<double>{}) / num_elements;
|
||||
auto stddev = std::sqrt(
|
||||
std::accumulate(x.begin(),
|
||||
x.end(),
|
||||
double{0.0},
|
||||
[&](double r, double v) { return r + std::pow((v - mean), 2.0); }) /
|
||||
num_elements);
|
||||
std::cout << "Mean: " << mean << ", ";
|
||||
std::cout << "StdDev: " << stddev << "\n";
|
||||
}
|
||||
|
||||
template <class Buffer>
|
||||
void print_preview(const Buffer& x)
|
||||
{
|
||||
if(x.size() <= 10)
|
||||
{
|
||||
std::for_each(x.begin(), x.end(), [&](double i) { std::cout << i << ", "; });
|
||||
}
|
||||
else
|
||||
{
|
||||
std::for_each(x.begin(), x.begin() + 5, [&](double i) { std::cout << i << ", "; });
|
||||
std::cout << "..., ";
|
||||
std::for_each(x.end() - 5, x.end(), [&](double i) { std::cout << i << ", "; });
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
struct check_all
|
||||
{
|
||||
rtc::buffer<T> data{};
|
||||
bool operator()(const rtc::buffer<T>& x)
|
||||
{
|
||||
if(data.empty())
|
||||
{
|
||||
data = x;
|
||||
return true;
|
||||
}
|
||||
if(std::any_of(x.begin(), x.end(), [](double y) { return std::isnan(y); }))
|
||||
return false;
|
||||
return allclose(data, x);
|
||||
}
|
||||
};
|
||||
|
||||
template <class Solution>
|
||||
auto report(const Solution& solution, bool pass)
|
||||
{
|
||||
return test::make_predicate(solution.ToTemplateString(), [=] { return pass; });
|
||||
}
|
||||
|
||||
const std::string gemm_compile_check = R"__ck__(
|
||||
#include <${include}>
|
||||
|
||||
@@ -163,23 +54,28 @@ TEST_CASE(test_problem_kernel)
|
||||
std::string epilogue = "";
|
||||
std::string prologue = "";
|
||||
|
||||
for(auto solution : prob.GetSolutions("gfx90a", prologue, epilogue))
|
||||
auto solutions = prob.GetSolutions("gfx90a", prologue, epilogue);
|
||||
std::cout << "Num solutions: " << solutions.size() << std::endl;
|
||||
|
||||
for(auto i = 0; i < solutions.size(); ++i)
|
||||
{
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
std::cout << "Testing solution " << std::to_string(i + 1) << std::endl;
|
||||
auto&& solution = solutions[i];
|
||||
auto src = ck::host::InterpolateString(gemm_compile_check,
|
||||
{{"include", prob.GetIncludeHeader()},
|
||||
{"template", solution.ToTemplateString()},
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
|
||||
rtc::compile_options options;
|
||||
options.kernel_name = "f";
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
options.kernel_name = "f";
|
||||
options.additional_src_files = get_headers_for_test();
|
||||
auto k = rtc::compile_kernel(src, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
auto grid_size = ck::host::integer_divide_ceil(prob.M, m_per_block) *
|
||||
ck::host::integer_divide_ceil(prob.N, n_per_block);
|
||||
k.launch(nullptr, grid_size * block_size, block_size)(a.data(), b.data(), c.data());
|
||||
|
||||
|
||||
@@ -4,24 +4,29 @@
|
||||
#include <rtc/kernel.hpp>
|
||||
#include <ck/filesystem.hpp>
|
||||
#include <string>
|
||||
#include <functional>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
struct src_file
|
||||
{
|
||||
CK::fs::path path;
|
||||
std::string_view content;
|
||||
std::string content;
|
||||
};
|
||||
|
||||
struct compile_options
|
||||
{
|
||||
std::string flags = "";
|
||||
std::string kernel_name = "main";
|
||||
std::vector<src_file> additional_src_files = {};
|
||||
std::string params = "";
|
||||
};
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& src,
|
||||
compile_options options = compile_options{});
|
||||
|
||||
kernel compile_kernel(const std::string& content, compile_options options = compile_options{});
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <hip/hip_runtime_api.h>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
|
||||
3
codegen/test/rtc/include/rtc/hiprtc_enable_env.hpp
Normal file
3
codegen/test/rtc/include/rtc/hiprtc_enable_env.hpp
Normal file
@@ -0,0 +1,3 @@
|
||||
#include <ck/utility/env.hpp>
|
||||
|
||||
CK_DECLARE_ENV_VAR_BOOL(CK_CODEGEN_TESTS_ENABLE_HIPRTC)
|
||||
@@ -1,10 +1,16 @@
|
||||
#include "rtc/hip.hpp"
|
||||
#include <rtc/compile_kernel.hpp>
|
||||
// TODO include only if USE_RTC is set?
|
||||
#include <hip/hiprtc.h>
|
||||
#include <rtc/tmp_dir.hpp>
|
||||
#include <stdexcept>
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <cassert>
|
||||
#include <numeric>
|
||||
#include <deque>
|
||||
#include <rtc/hiprtc_enable_env.hpp>
|
||||
#include <ck/host/stringutils.hpp>
|
||||
|
||||
namespace rtc {
|
||||
|
||||
@@ -59,7 +65,7 @@ std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip --cuda-device
|
||||
// TODO: undo after extracting the codeobj
|
||||
// std::string compiler() { return "/opt/rocm/llvm/bin/clang++ -x hip"; }
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
kernel clang_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
assert(not srcs.empty());
|
||||
tmp_dir td{"compile"};
|
||||
@@ -100,4 +106,289 @@ kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options
|
||||
return kernel{obj.data(), options.kernel_name};
|
||||
}
|
||||
|
||||
struct hiprtc_src_file
|
||||
{
|
||||
hiprtc_src_file() = default;
|
||||
hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
|
||||
std::string path;
|
||||
std::string content;
|
||||
template <class Self, class F>
|
||||
static auto reflect(Self& self, F f)
|
||||
{
|
||||
return pack(f(self.path, "path"), f(self.content, "content"));
|
||||
}
|
||||
};
|
||||
|
||||
std::string hiprtc_error(hiprtcResult err, const std::string& msg)
|
||||
{
|
||||
return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg));
|
||||
}
|
||||
|
||||
void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx)
|
||||
{
|
||||
if(err != HIPRTC_SUCCESS)
|
||||
throw std::runtime_error(hiprtc_error(err, msg));
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_HIPRTC(...) \
|
||||
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
|
||||
|
||||
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
|
||||
|
||||
template <class F, F f> // NOLINT
|
||||
struct manage_deleter
|
||||
{
|
||||
template <class T>
|
||||
void operator()(T* x) const
|
||||
{
|
||||
if(x != nullptr)
|
||||
{
|
||||
(void)f(x);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class T, class F, F f> // NOLINT
|
||||
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
|
||||
|
||||
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
|
||||
|
||||
// Workaround hiprtc's broken API
|
||||
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
|
||||
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
|
||||
|
||||
template <class... Ts>
|
||||
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
|
||||
{
|
||||
hiprtcProgram prog = nullptr;
|
||||
auto result = hiprtcCreateProgram(&prog, xs...);
|
||||
hiprtc_program_ptr p{prog};
|
||||
if(result != HIPRTC_SUCCESS)
|
||||
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
|
||||
return p;
|
||||
}
|
||||
|
||||
struct hiprtc_program
|
||||
{
|
||||
struct string_array
|
||||
{
|
||||
std::deque<std::string> strings{};
|
||||
std::vector<const char*> c_strs{};
|
||||
|
||||
string_array() {}
|
||||
string_array(const string_array&) = delete;
|
||||
|
||||
std::size_t size() const { return strings.size(); }
|
||||
|
||||
const char** data() { return c_strs.data(); }
|
||||
|
||||
void push_back(std::string s)
|
||||
{
|
||||
strings.push_back(std::move(s));
|
||||
c_strs.push_back(strings.back().c_str());
|
||||
}
|
||||
};
|
||||
|
||||
hiprtc_program_ptr prog = nullptr;
|
||||
string_array headers{};
|
||||
string_array include_names{};
|
||||
std::string cpp_src = "";
|
||||
std::string cpp_name = "";
|
||||
|
||||
hiprtc_program(const std::string& src, const std::string& name = "main.cpp")
|
||||
: cpp_src(src), cpp_name(name)
|
||||
{
|
||||
create_program();
|
||||
}
|
||||
|
||||
hiprtc_program(std::vector<src_file> srcs)
|
||||
{
|
||||
for(auto&& src : srcs)
|
||||
{
|
||||
if(ck::host::EndsWith(src.path, ".cpp"))
|
||||
{
|
||||
cpp_src = std::move(src.content);
|
||||
cpp_name = std::move(src.path);
|
||||
}
|
||||
else
|
||||
{
|
||||
headers.push_back(std::string(src.content.begin(), src.content.end()));
|
||||
include_names.push_back(std::move(src.path));
|
||||
}
|
||||
}
|
||||
create_program();
|
||||
}
|
||||
|
||||
void create_program()
|
||||
{
|
||||
assert(not cpp_src.empty());
|
||||
assert(not cpp_name.empty());
|
||||
assert(headers.size() == include_names.size());
|
||||
prog = hiprtc_program_create(cpp_src.c_str(),
|
||||
cpp_name.c_str(),
|
||||
headers.size(),
|
||||
headers.data(),
|
||||
include_names.data());
|
||||
}
|
||||
|
||||
void compile(const std::vector<std::string>& options, bool quiet = false) const
|
||||
{
|
||||
std::vector<const char*> c_options;
|
||||
std::transform(options.begin(),
|
||||
options.end(),
|
||||
std::back_inserter(c_options),
|
||||
[](const std::string& s) { return s.c_str(); });
|
||||
auto result = hiprtcCompileProgram(prog.get(), c_options.size(), c_options.data());
|
||||
auto prog_log = log();
|
||||
if(not prog_log.empty() and not quiet)
|
||||
{
|
||||
std::cerr << prog_log << std::endl;
|
||||
}
|
||||
if(result != HIPRTC_SUCCESS)
|
||||
throw std::runtime_error("Compilation failed.");
|
||||
}
|
||||
|
||||
std::string log() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
|
||||
if(n == 0)
|
||||
return {};
|
||||
std::string buffer(n, '\0');
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
|
||||
assert(buffer.back() != 0);
|
||||
return buffer;
|
||||
}
|
||||
|
||||
std::vector<char> get_code_obj() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
|
||||
std::vector<char> buffer(n);
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
|
||||
const std::string& params,
|
||||
const std::string& arch)
|
||||
{
|
||||
hiprtc_program prog(std::move(srcs));
|
||||
auto options = ck::host::SplitString(params, ' ');
|
||||
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
|
||||
if(true)
|
||||
{
|
||||
options.push_back("-DMIGRAPHX_HAS_DPP=0");
|
||||
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
|
||||
options.push_back("-Wno-reserved-identifier");
|
||||
options.push_back("-Wno-unused-parameter");
|
||||
options.push_back("-Wno-gnu-line-marker");
|
||||
options.push_back("-Wno-old-style-cast");
|
||||
}
|
||||
if(true)
|
||||
options.push_back("-DMIGRAPHX_DEBUG");
|
||||
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
|
||||
return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std=");
|
||||
}))
|
||||
options.push_back("-std=c++17");
|
||||
options.push_back("-fno-gpu-rdc");
|
||||
options.push_back("-O3");
|
||||
options.push_back("-Wno-cuda-compat");
|
||||
options.push_back("--offload-arch=" + arch);
|
||||
prog.compile(options);
|
||||
return {prog.get_code_obj()};
|
||||
}
|
||||
|
||||
bool hip_has_flags(const std::vector<std::string>& flags)
|
||||
{
|
||||
hiprtc_program prog{" "};
|
||||
try
|
||||
{
|
||||
prog.compile(flags, true);
|
||||
return true;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool hip_accept_non_uniform_wg()
|
||||
{
|
||||
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
|
||||
return non_uniform_wg;
|
||||
}
|
||||
|
||||
static std::vector<std::string> get_compiler_warnings()
|
||||
{
|
||||
std::vector<std::string> warnings = {
|
||||
"-Weverything",
|
||||
"-Wno-c++98-compat",
|
||||
"-Wno-c++98-compat-pedantic",
|
||||
"-Wno-conversion",
|
||||
"-Wno-double-promotion",
|
||||
"-Wno-exit-time-destructors",
|
||||
"-Wno-extra-semi",
|
||||
"-Wno-extra-semi-stmt",
|
||||
"-Wno-float-conversion",
|
||||
"-Wno-gnu-anonymous-struct",
|
||||
"-Wno-gnu-zero-variadic-macro-arguments",
|
||||
"-Wno-missing-prototypes",
|
||||
"-Wno-nested-anon-types",
|
||||
"-Wno-padded",
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-Wno-sign-conversion",
|
||||
"-Wno-sign-compare",
|
||||
"-Wno-unused-command-line-argument",
|
||||
"-Wno-weak-vtables",
|
||||
"-Wno-c99-extensions",
|
||||
};
|
||||
|
||||
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
|
||||
warnings.push_back("-Wno-unsafe-buffer-usage");
|
||||
return warnings;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& compiler_warnings()
|
||||
{
|
||||
static std::vector<std::string> warnings = get_compiler_warnings();
|
||||
return warnings;
|
||||
}
|
||||
|
||||
static kernel hiprtc_compile_kernel(const std::string& content, compile_options options)
|
||||
{
|
||||
std::vector<src_file> srcs = options.additional_src_files;
|
||||
srcs.push_back(src_file{std::string("main.cpp"), content});
|
||||
|
||||
options.params += " " + ck::host::JoinStrings(compiler_warnings(), " ");
|
||||
options.params += " -ftemplate-backtrace-limit=0";
|
||||
options.params += " -Werror";
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
|
||||
if(cos.size() != 1)
|
||||
std::runtime_error("No code object");
|
||||
auto& obj = cos.front();
|
||||
|
||||
return kernel{obj.data(), options.kernel_name};
|
||||
}
|
||||
|
||||
kernel compile_kernel(const std::string& content, compile_options options)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
|
||||
{
|
||||
return hiprtc_compile_kernel(content, options);
|
||||
}
|
||||
else
|
||||
{
|
||||
options.additional_src_files.push_back({"main.cpp", content});
|
||||
return clang_compile_kernel(options.additional_src_files, options);
|
||||
}
|
||||
}
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& src, compile_options options)
|
||||
{
|
||||
return clang_compile_kernel(src, options);
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
Reference in New Issue
Block a user