diff --git a/codegen/test/common.hpp b/codegen/test/common.hpp index 48dfc66511..f97b4436c1 100644 --- a/codegen/test/common.hpp +++ b/codegen/test/common.hpp @@ -14,65 +14,33 @@ #include "ck/host/stringutils.hpp" // NOLINTNEXTLINE -const char* const disable_warning_pragma = R"__migraphx__( -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Weverything" +const char* const content_wrapper = R"__ck__( ${content} -#pragma clang diagnostic pop -)__migraphx__"; +)__ck__"; template -inline std::string ck_disable_warnings(P p) +inline std::string ck_content_wrapper(P p) { - return ck::host::InterpolateString(disable_warning_pragma, + return ck::host::InterpolateString(content_wrapper, {{"content", std::string{p.data(), p.size()}}}); } -inline std::vector create_headers_for_hiprtc_test() +inline std::vector create_headers_for_test() { auto ck_headers = ck::host::GetHeaders(); std::vector result; std::transform(ck_headers.begin(), ck_headers.end(), std::back_inserter(result), [&](auto& p) { - return rtc::src_file{p.first, ck_disable_warnings(p.second)}; + return rtc::src_file{p.first, ck_content_wrapper(p.second)}; }); return result; } -inline const std::vector& get_headers_for_hiprtc_test() -{ - static const std::vector headers = create_headers_for_hiprtc_test(); - return headers; -} - -inline std::vector create_headers_for_clang_test() -{ - std::vector result; - auto hs = ck::host::GetHeaders(); - std::transform( - hs.begin(), hs.end(), std::back_inserter(result), [&](const auto& p) -> rtc::src_file { - return {p.first, {p.second.begin(), p.second.end()}}; - }); - return result; -} - -inline const std::vector& get_headers_for_clang_test() -{ - static const std::vector headers = create_headers_for_clang_test(); - return headers; -} - inline const std::vector& get_headers_for_test() { - if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC))) - { - return get_headers_for_hiprtc_test(); - } - else - { - return get_headers_for_clang_test(); - } + static const std::vector headers = create_headers_for_test(); + return headers; } template diff --git a/codegen/test/gemm_multiple_d.cpp b/codegen/test/gemm_multiple_d.cpp index e5b3b5d6b9..7a20e54766 100644 --- a/codegen/test/gemm_multiple_d.cpp +++ b/codegen/test/gemm_multiple_d.cpp @@ -71,11 +71,11 @@ TEST_CASE(test_problem_kernel) {"m", std::to_string(prob.M)}, {"n", std::to_string(prob.N)}, {"k", std::to_string(prob.K)}}); - + auto srcs = get_headers_for_test(); + srcs.push_back({"main.cpp", src}); rtc::compile_options options; options.kernel_name = "f"; - options.additional_src_files = get_headers_for_test(); - auto k = rtc::compile_kernel(src, options); + auto k = rtc::compile_kernel(srcs, options); auto block_size = solution.GetTemplateParameter("BlockSize"); auto m_per_block = solution.GetTemplateParameter("MPerBlock"); auto n_per_block = solution.GetTemplateParameter("NPerBlock"); diff --git a/codegen/test/rtc/include/rtc/compile_kernel.hpp b/codegen/test/rtc/include/rtc/compile_kernel.hpp index 1c7262157e..6f3d107afa 100644 --- a/codegen/test/rtc/include/rtc/compile_kernel.hpp +++ b/codegen/test/rtc/include/rtc/compile_kernel.hpp @@ -19,40 +19,11 @@ struct compile_options { std::string flags = ""; std::string kernel_name = "main"; - std::vector additional_src_files = {}; - std::string params = ""; }; -struct hip_compile_options -{ - std::size_t global; - std::size_t local; - std::string kernel_name = "kernel"; - std::string params = ""; - std::vector additional_src_files = {}; - - /** - * @brief Set the launch parameters but allow v to override the values - * - * @param v A value class which can have a "global" and/or "local" keys to override the default - * global and local - * @param compute_global A function used to compute the global based on the local - * @param default_local The defaul local to use if its missing from the v parameter - */ - void set_launch_params(const std::function& compute_global, - std::size_t default_local = 1024); - - void set_launch_params(std::size_t default_global, std::size_t default_local = 1024) - { - set_launch_params([=](auto) { return default_global; }, default_local); - } -}; - -kernel compile_kernel(const std::vector& src, +kernel compile_kernel(const std::vector& srcs, compile_options options = compile_options{}); -kernel compile_kernel(const std::string& content, compile_options options = compile_options{}); - } // namespace rtc #endif diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 0cb9a627d6..b15756d6e0 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -131,32 +131,17 @@ void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::str } // NOLINTNEXTLINE -#define MIGRAPHX_HIPRTC(...) \ - hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet") +#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet") -#define MIGRAPHX_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg)) +#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg)) -template // NOLINT -struct manage_deleter +struct hiprtc_program_destroy { - template - void operator()(T* x) const - { - if(x != nullptr) - { - (void)f(x); - } - } + void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); } }; -template // NOLINT -using manage_ptr = std::unique_ptr>; - -#define MIGRAPHX_MANAGE_PTR(T, F) manage_ptr, decltype(&F), &F> // NOLINT - -// Workaround hiprtc's broken API -void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } -using hiprtc_program_ptr = MIGRAPHX_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); +using hiprtc_program_ptr = + std::unique_ptr, hiprtc_program_destroy>; template hiprtc_program_ptr hiprtc_program_create(Ts... xs) @@ -165,7 +150,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs) auto result = hiprtcCreateProgram(&prog, xs...); hiprtc_program_ptr p{prog}; if(result != HIPRTC_SUCCESS) - MIGRAPHX_HIPRTC_THROW(result, "Create program failed."); + RTC_HIPRTC_THROW(result, "Create program failed."); return p; } @@ -252,11 +237,11 @@ struct hiprtc_program std::string log() const { std::size_t n = 0; - MIGRAPHX_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); + RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); if(n == 0) return {}; std::string buffer(n, '\0'); - MIGRAPHX_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); + RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); assert(buffer.back() != 0); return buffer; } @@ -264,108 +249,28 @@ struct hiprtc_program std::vector get_code_obj() const { std::size_t n = 0; - MIGRAPHX_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); + RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); std::vector buffer(n); - MIGRAPHX_HIPRTC(hiprtcGetCode(prog.get(), buffer.data())); + RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data())); return buffer; } }; -std::vector> compile_hip_src_with_hiprtc(std::vector srcs, - const std::string& params, - const std::string& arch) +std::vector> compile_hip_src_with_hiprtc(const std::vector& srcs, + const compile_options& options) { - hiprtc_program prog(std::move(srcs)); - auto options = ck::host::SplitString(params, ' '); - options.push_back("-DMIGRAPHX_USE_HIPRTC=1"); - if(true) - { - options.push_back("-DMIGRAPHX_HAS_DPP=0"); - options.push_back("-DMIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS=1"); - options.push_back("-Wno-reserved-identifier"); - options.push_back("-Wno-unused-parameter"); - options.push_back("-Wno-gnu-line-marker"); - options.push_back("-Wno-old-style-cast"); - } - if(true) - options.push_back("-DMIGRAPHX_DEBUG"); - if(std::none_of(options.begin(), options.end(), [](const std::string& s) { - return ck::host::StartsWith(s, "--std=") or ck::host::StartsWith(s, "-std="); - })) - options.push_back("-std=c++17"); - options.push_back("-fno-gpu-rdc"); - options.push_back("-O3"); - options.push_back("-Wno-cuda-compat"); - options.push_back("--offload-arch=" + arch); - prog.compile(options); + hiprtc_program prog(srcs); + auto flags = ck::host::SplitString(options.flags, ' '); + prog.compile(flags); return {prog.get_code_obj()}; } -bool hip_has_flags(const std::vector& flags) +static kernel hiprtc_compile_kernel(const std::vector& srcs, compile_options options) { - hiprtc_program prog{" "}; - try - { - prog.compile(flags, true); - return true; - } - catch(...) - { - return false; - } -} - -bool hip_accept_non_uniform_wg() -{ - static bool non_uniform_wg = hip_has_flags({"-fno-offload-uniform-block"}); - return non_uniform_wg; -} - -static std::vector get_compiler_warnings() -{ - std::vector warnings = { - "-Weverything", - "-Wno-c++98-compat", - "-Wno-c++98-compat-pedantic", - "-Wno-conversion", - "-Wno-double-promotion", - "-Wno-exit-time-destructors", - "-Wno-extra-semi", - "-Wno-extra-semi-stmt", - "-Wno-float-conversion", - "-Wno-gnu-anonymous-struct", - "-Wno-gnu-zero-variadic-macro-arguments", - "-Wno-missing-prototypes", - "-Wno-nested-anon-types", - "-Wno-padded", - "-Wno-shorten-64-to-32", - "-Wno-sign-conversion", - "-Wno-sign-compare", - "-Wno-unused-command-line-argument", - "-Wno-weak-vtables", - "-Wno-c99-extensions", - }; - - if(hip_has_flags({"-Werror", "-Wunsafe-buffer-usage"})) - warnings.push_back("-Wno-unsafe-buffer-usage"); - return warnings; -} - -const std::vector& compiler_warnings() -{ - static std::vector warnings = get_compiler_warnings(); - return warnings; -} - -static kernel hiprtc_compile_kernel(const std::string& content, compile_options options) -{ - std::vector srcs = options.additional_src_files; - srcs.push_back(src_file{std::string("main.cpp"), content}); - - options.params += " " + ck::host::JoinStrings(compiler_warnings(), " "); - options.params += " -ftemplate-backtrace-limit=0"; - options.params += " -Werror"; - auto cos = compile_hip_src_with_hiprtc(srcs, options.params, get_device_name()); + options.flags += " -I. -O3"; + options.flags += " -std=c++17"; + options.flags += " --offload-arch=" + get_device_name(); + auto cos = compile_hip_src_with_hiprtc(srcs, options); if(cos.size() != 1) std::runtime_error("No code object"); auto& obj = cos.front(); @@ -373,22 +278,16 @@ static kernel hiprtc_compile_kernel(const std::string& content, compile_options return kernel{obj.data(), options.kernel_name}; } -kernel compile_kernel(const std::string& content, compile_options options) +kernel compile_kernel(const std::vector& srcs, compile_options options) { if(ck::EnvIsEnabled(CK_ENV(CK_CODEGEN_TESTS_ENABLE_HIPRTC))) { - return hiprtc_compile_kernel(content, options); + return hiprtc_compile_kernel(srcs, options); } else { - options.additional_src_files.push_back({"main.cpp", content}); - return clang_compile_kernel(options.additional_src_files, options); + return clang_compile_kernel(srcs, options); } } -kernel compile_kernel(const std::vector& src, compile_options options) -{ - return clang_compile_kernel(src, options); -} - } // namespace rtc diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 1f16953243..10be374617 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -4,7 +4,9 @@ #pragma once #include "ck/config.h" + #ifndef __HIPCC_RTC__ +#include "ck/utility/env.hpp" #ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS #include "hip/hip_runtime.h" #include "hip/hip_fp16.h"