Update host tensor utils

This commit is contained in:
Rostyslav Geyyer
2025-03-10 19:10:07 +00:00
parent 234cbcb7af
commit facbaab7b7

View File

@@ -50,7 +50,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
{
os << ck::type_convert<float>(v);
}
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t> ||
std::is_same_v<RangeType, ck::f4x2_pk_t>)
{
const auto packed_floats = ck::type_convert<ck::float2_t>(v);
const ck::vector_type<float, 2> vector_of_floats{packed_floats};
@@ -333,7 +334,8 @@ struct Tensor
std::size_t GetElementSpaceSize() const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return (mDesc.GetElementSpaceSize() + 1) / 2;
}
@@ -488,7 +490,8 @@ struct Tensor
template <typename... Is>
std::size_t GetOffsetFromMultiIndex(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mDesc.GetOffsetFromMultiIndex(is...) / 2;
}
@@ -501,7 +504,8 @@ struct Tensor
template <typename... Is>
T& operator()(Is... is)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -514,7 +518,8 @@ struct Tensor
template <typename... Is>
const T& operator()(Is... is) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
}
@@ -526,7 +531,8 @@ struct Tensor
T& operator()(std::vector<std::size_t> idx)
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}
@@ -538,7 +544,8 @@ struct Tensor
const T& operator()(std::vector<std::size_t> idx) const
{
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
if constexpr(ck::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t> ||
ck::is_same_v<ck::remove_cvref_t<T>, ck::f4x2_pk_t>)
{
return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
}