From facbaab7b7df9fe839543e60afabf2a6a6efe047 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer Date: Mon, 10 Mar 2025 19:10:07 +0000 Subject: [PATCH] Update host tensor utils --- include/ck/library/utility/host_tensor.hpp | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index f1730de0e1..250ebd6721 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -50,7 +50,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) { os << ck::type_convert(v); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || + std::is_same_v) { const auto packed_floats = ck::type_convert(v); const ck::vector_type vector_of_floats{packed_floats}; @@ -333,7 +334,8 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return (mDesc.GetElementSpaceSize() + 1) / 2; } @@ -488,7 +490,8 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mDesc.GetOffsetFromMultiIndex(is...) / 2; } @@ -501,7 +504,8 @@ struct Tensor template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -514,7 +518,8 @@ struct Tensor template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -526,7 +531,8 @@ struct Tensor T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } @@ -538,7 +544,8 @@ struct Tensor const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; }