diff --git a/include/ck/utility/amd_wave_read_first_lane.hpp b/include/ck/utility/amd_wave_read_first_lane.hpp index 4652ce7a74..741b2975af 100644 --- a/include/ck/utility/amd_wave_read_first_lane.hpp +++ b/include/ck/utility/amd_wave_read_first_lane.hpp @@ -7,6 +7,7 @@ #include "ck/utility/functional2.hpp" #include "ck/utility/math.hpp" +#include #include #include #include @@ -14,29 +15,83 @@ namespace ck { namespace detail { -template -struct get_unsigned_int; +template +struct get_carrier; template <> -struct get_unsigned_int<1> +struct get_carrier<1> { using type = uint8_t; }; template <> -struct get_unsigned_int<2> +struct get_carrier<2> { using type = uint16_t; }; template <> -struct get_unsigned_int<4> +struct get_carrier<3> +{ + using type = class carrier + { + using value_type = uint32_t; + + std::array bytes; + static_assert(sizeof(bytes) <= sizeof(value_type)); + + // replacement of host std::copy_n() + template + __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + { + if(0 < size) + { + *to = *from; + ++to; + for(Size count = 1; count < size; ++count) + { + *to = *++from; + ++to; + } + } + + return to; + } + + // method to trigger template substitution failure + __device__ carrier(const carrier& other) noexcept + { + copy_n(other.bytes.begin(), bytes.size(), bytes.begin()); + } + + public: + __device__ carrier& operator=(value_type value) noexcept + { + copy_n(reinterpret_cast(&value), bytes.size(), bytes.begin()); + + return *this; + } + + __device__ operator value_type() const noexcept + { + std::byte result[sizeof(value_type)]; + + copy_n(bytes.begin(), bytes.size(), result); + + return *reinterpret_cast(result); + } + }; +}; +static_assert(sizeof(get_carrier<3>::type) == 3); + +template <> +struct get_carrier<4> { using type = uint32_t; }; -template -using get_unsigned_int_t = typename get_unsigned_int::type; +template +using get_carrier_t = typename get_carrier::type; } // namespace detail @@ -61,7 +116,7 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) { - using Sgpr = detail::get_unsigned_int_t; + using Sgpr = detail::get_carrier_t; *reinterpret_cast(to_obj + offset) = amd_wave_read_first_lane(*reinterpret_cast(from_obj + offset)); @@ -69,9 +124,9 @@ __device__ auto amd_wave_read_first_lane(const Object& obj) if constexpr(0 < RemainedSize) { - using Carrier = detail::get_unsigned_int_t; + using Carrier = detail::get_carrier_t; - *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( + *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( *reinterpret_cast(from_obj + CompleteSgprCopyBoundary)); }