diff --git a/codegen/include/ck/host/stringutils.hpp b/codegen/include/ck/host/stringutils.hpp index 37d1144ae1..89c1884d2e 100644 --- a/codegen/include/ck/host/stringutils.hpp +++ b/codegen/include/ck/host/stringutils.hpp @@ -100,33 +100,5 @@ inline auto Transform(const Range1& r1, const Range2& r2, F f) return result; } -inline bool StartsWith(const std::string& value, const std::string& prefix) -{ - if(prefix.size() > value.size()) - return false; - else - return std::equal(prefix.begin(), prefix.end(), value.begin()); -} - -inline bool EndsWith(const std::string& value, const std::string& suffix) -{ - if(suffix.size() > value.size()) - return false; - else - return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin()); -} - -inline std::vector SplitString(const std::string& s, char delim) -{ - std::vector elems; - std::stringstream ss(s + delim); - std::string item; - while(std::getline(ss, item, delim)) - { - elems.push_back(item); - } - return elems; -} - } // namespace host } // namespace ck diff --git a/codegen/test/CMakeLists.txt b/codegen/test/CMakeLists.txt index 94b1b99409..6dd130bc3f 100644 --- a/codegen/test/CMakeLists.txt +++ b/codegen/test/CMakeLists.txt @@ -1,10 +1,3 @@ -option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) - -if(USE_HIPRTC_FOR_CODEGEN_TESTS) - add_compile_definitions(HIPRTC_FOR_CODEGEN_TESTS) - message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") -endif() - list(APPEND CMAKE_PREFIX_PATH /opt/rocm) add_subdirectory(rtc) file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp) diff --git a/codegen/test/rtc/CMakeLists.txt b/codegen/test/rtc/CMakeLists.txt index 39497f1a21..a83574947d 100644 --- a/codegen/test/rtc/CMakeLists.txt +++ b/codegen/test/rtc/CMakeLists.txt @@ -2,3 +2,9 @@ file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp) add_library(ck_rtc ${RTC_SOURCES}) target_include_directories(ck_rtc PUBLIC include) target_link_libraries(ck_rtc PUBLIC hip::host) + +option(USE_HIPRTC_FOR_CODEGEN_TESTS "Whether to enable hipRTC for codegen tests." ON) +if(USE_HIPRTC_FOR_CODEGEN_TESTS) + target_compile_definitions(ck_rtc PUBLIC HIPRTC_FOR_CODEGEN_TESTS) + message("CK compiled with USE_HIPRTC_FOR_CODEGEN_TESTS set to ${USE_HIPRTC_FOR_CODEGEN_TESTS}") +endif() diff --git a/codegen/test/rtc/src/compile_kernel.cpp b/codegen/test/rtc/src/compile_kernel.cpp index 445f8725d2..a3377d6853 100644 --- a/codegen/test/rtc/src/compile_kernel.cpp +++ b/codegen/test/rtc/src/compile_kernel.cpp @@ -1,8 +1,8 @@ -#include #include #include #ifdef HIPRTC_FOR_CODEGEN_TESTS #include +#include #endif #include #include @@ -14,6 +14,26 @@ namespace rtc { +bool EndsWith(const std::string& value, const std::string& suffix) +{ + if(suffix.size() > value.size()) + return false; + else + return std::equal(suffix.rbegin(), suffix.rend(), value.rbegin()); +} + +std::vector SplitString(const std::string& s, char delim) +{ + std::vector elems; + std::stringstream ss(s + delim); + std::string item; + while(std::getline(ss, item, delim)) + { + elems.push_back(item); + } + return elems; +} + template T generic_read_file(const std::string& filename, size_t offset = 0, size_t nbytes = 0) { @@ -108,42 +128,27 @@ kernel clang_compile_kernel(const std::vector& srcs, compile_options o #ifdef HIPRTC_FOR_CODEGEN_TESTS +std::string hiprtc_error(hiprtcResult err, const std::string& msg) +{ + return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); +} + +void hiprtc_check_error(hiprtcResult err, const std::string& msg = "") +{ + if(err != HIPRTC_SUCCESS) + throw std::runtime_error(hiprtc_error(err, msg)); +} + struct hiprtc_src_file { hiprtc_src_file() = default; hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {} std::string path; std::string content; - template - static auto reflect(Self& self, F f) - { - return pack(f(self.path, "path"), f(self.content, "content")); - } }; -std::string hiprtc_error(hiprtcResult err, const std::string& msg) -{ - return "hiprtc: " + (hiprtcGetErrorString(err) + (": " + msg)); -} - -void hiprtc_check_error(hiprtcResult err, const std::string& msg, const std::string& ctx) -{ - if(err != HIPRTC_SUCCESS) - throw std::runtime_error(hiprtc_error(err, msg)); -} - -// NOLINTNEXTLINE -#define RTC_HIPRTC(...) hiprtc_check_error(__VA_ARGS__, #__VA_ARGS__, "Lorem ipsum dolor sit amet") - -#define RTC_HIPRTC_THROW(error, msg) throw std::runtime_error(hiprtc_error(error, msg)) - -struct hiprtc_program_destroy -{ - void operator()(hiprtcProgram prog) const { hiprtcDestroyProgram(&prog); } -}; - -using hiprtc_program_ptr = - std::unique_ptr, hiprtc_program_destroy>; +void hiprtc_program_destroy(hiprtcProgram prog) { hiprtcDestroyProgram(&prog); } +using hiprtc_program_ptr = RTC_MANAGE_PTR(hiprtcProgram, hiprtc_program_destroy); template hiprtc_program_ptr hiprtc_program_create(Ts... xs) @@ -151,8 +156,7 @@ hiprtc_program_ptr hiprtc_program_create(Ts... xs) hiprtcProgram prog = nullptr; auto result = hiprtcCreateProgram(&prog, xs...); hiprtc_program_ptr p{prog}; - if(result != HIPRTC_SUCCESS) - RTC_HIPRTC_THROW(result, "Create program failed."); + hiprtc_check_error(result, "Create program failed."); return p; } @@ -193,7 +197,7 @@ struct hiprtc_program { for(auto&& src : srcs) { - if(ck::host::EndsWith(src.path, ".cpp")) + if(EndsWith(src.path, ".cpp")) { cpp_src = std::move(src.content); cpp_name = std::move(src.path); @@ -239,11 +243,11 @@ struct hiprtc_program std::string log() const { std::size_t n = 0; - RTC_HIPRTC(hiprtcGetProgramLogSize(prog.get(), &n)); + hiprtc_check_error(hiprtcGetProgramLogSize(prog.get(), &n)); if(n == 0) return {}; std::string buffer(n, '\0'); - RTC_HIPRTC(hiprtcGetProgramLog(prog.get(), buffer.data())); + hiprtc_check_error(hiprtcGetProgramLog(prog.get(), buffer.data())); assert(buffer.back() != 0); return buffer; } @@ -251,9 +255,9 @@ struct hiprtc_program std::vector get_code_obj() const { std::size_t n = 0; - RTC_HIPRTC(hiprtcGetCodeSize(prog.get(), &n)); + hiprtc_check_error(hiprtcGetCodeSize(prog.get(), &n)); std::vector buffer(n); - RTC_HIPRTC(hiprtcGetCode(prog.get(), buffer.data())); + hiprtc_check_error(hiprtcGetCode(prog.get(), buffer.data())); return buffer; } }; @@ -262,7 +266,7 @@ std::vector> compile_hip_src_with_hiprtc(const std::vector using bool_constant = integral_constant; @@ -113,51 +114,75 @@ constexpr T&& forward(typename remove_reference::type&& t_) noexcept return static_cast(t_); } -template struct is_const : false_type {}; -template struct is_const : true_type {}; -template< class T > +template +struct is_const : false_type +{ +}; +template +struct is_const : true_type +{ +}; +template inline constexpr bool is_const_v = is_const::value; -template< class T > +template inline constexpr bool is_reference_v = is_reference::value; -template struct remove_const { typedef T type; }; -template struct remove_const { typedef T type; }; -template< class T > +template +struct remove_const +{ + typedef T type; +}; +template +struct remove_const +{ + typedef T type; +}; +template using remove_const_t = typename remove_const::type; -template< class T > +template inline constexpr bool is_class_v = is_class::value; -template< class T > +template inline constexpr bool is_trivially_copyable_v = is_trivially_copyable::value; -template< class... > +template using void_t = void; -using __hip::declval; +template +U private_declval(int); + +template +T private_declval(long); + +template +auto declval() noexcept -> decltype(private_declval(0)); + #else + #include #include +using std::declval; +using std::false_type; using std::forward; using std::is_base_of; using std::is_class; +using std::is_class_v; +using std::is_const_v; using std::is_pointer; using std::is_reference; +using std::is_reference_v; using std::is_trivially_copyable; +using std::is_trivially_copyable_v; using std::is_unsigned; +using std::remove_const_t; using std::remove_cv; using std::remove_pointer; using std::remove_reference; -using std::is_const_v; -using std::is_reference_v; -using std::remove_const_t; -using std::is_class_v; -using std::is_trivially_copyable_v; -using std::void_t; -using std::false_type; using std::true_type; -using std::declval; +using std::void_t; + #endif template