mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Address PR comments.
This commit is contained in:
@@ -14,65 +14,33 @@
|
||||
#include "ck/host/stringutils.hpp"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
const char* const disable_warning_pragma = R"__migraphx__(
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Weverything"
|
||||
const char* const content_wrapper = R"__ck__(
|
||||
${content}
|
||||
#pragma clang diagnostic pop
|
||||
)__migraphx__";
|
||||
)__ck__";
|
||||
|
||||
template <class P>
|
||||
inline std::string ck_disable_warnings(P p)
|
||||
inline std::string ck_content_wrapper(P p)
|
||||
{
|
||||
return ck::host::InterpolateString(disable_warning_pragma,
|
||||
return ck::host::InterpolateString(content_wrapper,
|
||||
{{"content", std::string{p.data(), p.size()}}});
|
||||
}
|
||||
|
||||
inline std::vector<rtc::src_file> create_headers_for_hiprtc_test()
|
||||
inline std::vector<rtc::src_file> create_headers_for_test()
|
||||
{
|
||||
auto ck_headers = ck::host::GetHeaders();
|
||||
std::vector<rtc::src_file> result;
|
||||
|
||||
std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) {
|
||||
return rtc::src_file{p.first, ck_disable_warnings(p.second)};
|
||||
return rtc::src_file{p.first, ck_content_wrapper(p.second)};
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_hiprtc_test()
|
||||
{
|
||||
static const std::vector<rtc::src_file> headers = create_headers_for_hiprtc_test();
|
||||
return headers;
|
||||
}
|
||||
|
||||
inline std::vector<rtc::src_file> create_headers_for_clang_test()
|
||||
{
|
||||
std::vector<rtc::src_file> result;
|
||||
auto hs = ck::host::GetHeaders();
|
||||
std::transform(
|
||||
hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file {
|
||||
return {p.first, {p.second.begin(), p.second.end()}};
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_clang_test()
|
||||
{
|
||||
static const std::vector<rtc::src_file> headers = create_headers_for_clang_test();
|
||||
return headers;
|
||||
}
|
||||
|
||||
inline const std::vector<rtc::src_file>& get_headers_for_test()
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
|
||||
{
|
||||
return get_headers_for_hiprtc_test();
|
||||
}
|
||||
else
|
||||
{
|
||||
return get_headers_for_clang_test();
|
||||
}
|
||||
static const std::vector<rtc::src_file> headers = create_headers_for_test();
|
||||
return headers;
|
||||
}
|
||||
|
||||
template <typename V>
|
||||
|
||||
@@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel)
|
||||
{"m", std::to_string(prob.M)},
|
||||
{"n", std::to_string(prob.N)},
|
||||
{"k", std::to_string(prob.K)}});
|
||||
|
||||
auto srcs = get_headers_for_test();
|
||||
srcs.push_back({"main.cpp", src});
|
||||
rtc::compile_options options;
|
||||
options.kernel_name = "f";
|
||||
options.additional_src_files = get_headers_for_test();
|
||||
auto k = rtc::compile_kernel(src, options);
|
||||
auto k = rtc::compile_kernel(srcs, options);
|
||||
auto block_size = solution.GetTemplateParameter<std::size_t>("BlockSize");
|
||||
auto m_per_block = solution.GetTemplateParameter<std::size_t>("MPerBlock");
|
||||
auto n_per_block = solution.GetTemplateParameter<std::size_t>("NPerBlock");
|
||||
|
||||
@@ -19,40 +19,11 @@ struct compile_options
|
||||
{
|
||||
std::string flags = "";
|
||||
std::string kernel_name = "main";
|
||||
std::vector<src_file> additional_src_files = {};
|
||||
std::string params = "";
|
||||
};
|
||||
|
||||
struct hip_compile_options
|
||||
{
|
||||
std::size_t global;
|
||||
std::size_t local;
|
||||
std::string kernel_name = "kernel";
|
||||
std::string params = "";
|
||||
std::vector<src_file> additional_src_files = {};
|
||||
|
||||
/**
|
||||
* @brief Set the launch parameters but allow v to override the values
|
||||
*
|
||||
* @param v A value class which can have a "global" and/or "local" keys to override the default
|
||||
* global and local
|
||||
* @param compute_global A function used to compute the global based on the local
|
||||
* @param default_local The default local to use if it's missing from the v parameter
|
||||
*/
|
||||
void set_launch_params(const std::function<std::size_t(std::size_t local)>& compute_global,
|
||||
std::size_t default_local = 1024);
|
||||
|
||||
void set_launch_params(std::size_t default_global, std::size_t default_local = 1024)
|
||||
{
|
||||
set_launch_params([=](auto) { return default_global; }, default_local);
|
||||
}
|
||||
};
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& src,
|
||||
kernel compile_kernel(const std::vector<src_file>& srcs,
|
||||
compile_options options = compile_options{});
|
||||
|
||||
kernel compile_kernel(const std::string& content, compile_options options = compile_options{});
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
#endif
|
||||
|
||||
@@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
#define MIGRAPHX_HIPRTC(...) \
|
||||
hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
|
||||
#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet")
|
||||
|
||||
#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
|
||||
#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg))
|
||||
|
||||
template <class F, F f> // NOLINT
|
||||
struct manage_deleter
|
||||
struct hiprtc_program_destroy
|
||||
{
|
||||
template <class T>
|
||||
void operator()(T* x) const
|
||||
{
|
||||
if(x != nullptr)
|
||||
{
|
||||
(void)f(x);
|
||||
}
|
||||
}
|
||||
void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); }
|
||||
};
|
||||
|
||||
template <class T, class F, F f> // NOLINT
|
||||
using manage_ptr = std::unique_ptr<T, manage_deleter<F, f>>;
|
||||
|
||||
#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr<std::remove_pointer_t<T>, decltype(&F), &F> // NOLINT
|
||||
|
||||
// Workaround hiprtc's broken API
|
||||
void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); }
|
||||
using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy);
|
||||
using hiprtc_program_ptr =
|
||||
std::unique_ptr<std::remove_pointer_t<hiprtcProgram>, hiprtc_program_destroy>;
|
||||
|
||||
template <class... Ts>
|
||||
hiprtc_program_ptr hiprtc_program_create(Ts... xs)
|
||||
@@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs)
|
||||
auto result = hiprtcCreateProgram(&prog, xs...);
|
||||
hiprtc_program_ptr p{prog};
|
||||
if(result != HIPRTC_SUCCESS)
|
||||
MIGRAPHX_HIPRTC_THROW(result, "Create program failed.");
|
||||
RTC_HIPRTC_THROW(result, "Create program failed.");
|
||||
return p;
|
||||
}
|
||||
|
||||
@@ -252,11 +237,11 @@ struct hiprtc_program
|
||||
std::string log() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
|
||||
RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n));
|
||||
if(n == 0)
|
||||
return {};
|
||||
std::string buffer(n, '\0');
|
||||
MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
|
||||
RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data()));
|
||||
assert(buffer.back() != 0);
|
||||
return buffer;
|
||||
}
|
||||
@@ -264,108 +249,28 @@ struct hiprtc_program
|
||||
std::vector<char> get_code_obj() const
|
||||
{
|
||||
std::size_t n = 0;
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
|
||||
RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n));
|
||||
std::vector<char> buffer(n);
|
||||
MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
|
||||
RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data()));
|
||||
return buffer;
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(std::vector<src_file> srcs,
|
||||
const std::string& params,
|
||||
const std::string& arch)
|
||||
std::vector<std::vector<char>> compile_hip_src_with_hiprtc(const std::vector<src_file>& srcs,
|
||||
const compile_options& options)
|
||||
{
|
||||
hiprtc_program prog(std::move(srcs));
|
||||
auto options = ck::host::SplitString(params, ' ');
|
||||
options.push_back("-DMIGRAPHX_USE_HIPRTC=1");
|
||||
if(true)
|
||||
{
|
||||
options.push_back("-DMIGRAPHX_HAS_DPP=0");
|
||||
options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1");
|
||||
options.push_back("-Wno-reserved-identifier");
|
||||
options.push_back("-Wno-unused-parameter");
|
||||
options.push_back("-Wno-gnu-line-marker");
|
||||
options.push_back("-Wno-old-style-cast");
|
||||
}
|
||||
if(true)
|
||||
options.push_back("-DMIGRAPHX_DEBUG");
|
||||
if(std::none_of(options.begin(), options.end(), [](const std::string& s) {
|
||||
return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std=");
|
||||
}))
|
||||
options.push_back("-std=c++17");
|
||||
options.push_back("-fno-gpu-rdc");
|
||||
options.push_back("-O3");
|
||||
options.push_back("-Wno-cuda-compat");
|
||||
options.push_back("--offload-arch=" + arch);
|
||||
prog.compile(options);
|
||||
hiprtc_program prog(srcs);
|
||||
auto flags = ck::host::SplitString(options.flags, ' ');
|
||||
prog.compile(flags);
|
||||
return {prog.get_code_obj()};
|
||||
}
|
||||
|
||||
bool hip_has_flags(const std::vector<std::string>& flags)
|
||||
static kernel hiprtc_compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
hiprtc_program prog{" "};
|
||||
try
|
||||
{
|
||||
prog.compile(flags, true);
|
||||
return true;
|
||||
}
|
||||
catch(...)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool hip_accept_non_uniform_wg()
|
||||
{
|
||||
static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"});
|
||||
return non_uniform_wg;
|
||||
}
|
||||
|
||||
static std::vector<std::string> get_compiler_warnings()
|
||||
{
|
||||
std::vector<std::string> warnings = {
|
||||
"-Weverything",
|
||||
"-Wno-c++98-compat",
|
||||
"-Wno-c++98-compat-pedantic",
|
||||
"-Wno-conversion",
|
||||
"-Wno-double-promotion",
|
||||
"-Wno-exit-time-destructors",
|
||||
"-Wno-extra-semi",
|
||||
"-Wno-extra-semi-stmt",
|
||||
"-Wno-float-conversion",
|
||||
"-Wno-gnu-anonymous-struct",
|
||||
"-Wno-gnu-zero-variadic-macro-arguments",
|
||||
"-Wno-missing-prototypes",
|
||||
"-Wno-nested-anon-types",
|
||||
"-Wno-padded",
|
||||
"-Wno-shorten-64-to-32",
|
||||
"-Wno-sign-conversion",
|
||||
"-Wno-sign-compare",
|
||||
"-Wno-unused-command-line-argument",
|
||||
"-Wno-weak-vtables",
|
||||
"-Wno-c99-extensions",
|
||||
};
|
||||
|
||||
if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"}))
|
||||
warnings.push_back("-Wno-unsafe-buffer-usage");
|
||||
return warnings;
|
||||
}
|
||||
|
||||
const std::vector<std::string>& compiler_warnings()
|
||||
{
|
||||
static std::vector<std::string> warnings = get_compiler_warnings();
|
||||
return warnings;
|
||||
}
|
||||
|
||||
static kernel hiprtc_compile_kernel(const std::string& content, compile_options options)
|
||||
{
|
||||
std::vector<src_file> srcs = options.additional_src_files;
|
||||
srcs.push_back(src_file{std::string("main.cpp"), content});
|
||||
|
||||
options.params += " " + ck::host::JoinStrings(compiler_warnings(), " ");
|
||||
options.params += " -ftemplate-backtrace-limit=0";
|
||||
options.params += " -Werror";
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name());
|
||||
options.flags += " -I. -O3";
|
||||
options.flags += " -std=c++17";
|
||||
options.flags += " --offload-arch=" + get_device_name();
|
||||
auto cos = compile_hip_src_with_hiprtc(srcs, options);
|
||||
if(cos.size() != 1)
|
||||
std::runtime_error("No code object");
|
||||
auto& obj = cos.front();
|
||||
@@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options
|
||||
return kernel{obj.data(), options.kernel_name};
|
||||
}
|
||||
|
||||
kernel compile_kernel(const std::string& content, compile_options options)
|
||||
kernel compile_kernel(const std::vector<src_file>& srcs, compile_options options)
|
||||
{
|
||||
if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC)))
|
||||
{
|
||||
return hiprtc_compile_kernel(content, options);
|
||||
return hiprtc_compile_kernel(srcs, options);
|
||||
}
|
||||
else
|
||||
{
|
||||
options.additional_src_files.push_back({"main.cpp", content});
|
||||
return clang_compile_kernel(options.additional_src_files, options);
|
||||
return clang_compile_kernel(srcs, options);
|
||||
}
|
||||
}
|
||||
|
||||
kernel compile_kernel(const std::vector<src_file>& src, compile_options options)
|
||||
{
|
||||
return clang_compile_kernel(src, options);
|
||||
}
|
||||
|
||||
} // namespace rtc
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/config.h"
|
||||
|
||||
#ifndef __HIPCC_RTC__
|
||||
#include "ck/utility/env.hpp"
|
||||
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
Reference in New Issue
Block a user