Add half->int8 saturate conversion to promise valid range (#1983)

* Add half->int8 saturate conversion to promise valid range

* add gpu only macro

---------

Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
ZZK
2025-01-08 22:01:07 +08:00
committed by GitHub
parent c506e16788
commit 7de6a59784

View File

@@ -267,6 +267,44 @@ struct NumericConverter<uint8_t, float, FloatRoundStyle::round_toward_zero> {
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for cutlass::half_t => int8_t
//
/////////////////////////////////////////////////////////////////////////////////////////////////
template <>
struct NumericConverter<int8_t, cutlass::half_t, FloatRoundStyle::round_to_nearest> {
using result_type = int8_t;
using source_type = cutlass::half_t;
static FloatRoundStyle const round_style = FloatRoundStyle::round_to_nearest;
CUTLASS_HOST_DEVICE
static result_type convert(source_type const & s) {
#if defined(__CUDA_ARCH__)
union { int8_t int8[2]; int16_t int16; };
union { cutlass::half_t fp16; int16_t int16_in; };
fp16 = s;
asm volatile ("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in));
return int8[0];
#elif !defined(__CUDACC_RTC__)
std::fesetround(FE_TONEAREST);
int32_t intermediate = (int32_t)std::nearbyint(static_cast<float>(s));
// Low-end saturation
intermediate = std::max(intermediate, (int32_t)std::numeric_limits<int8_t>::lowest());
// High-end saturation
intermediate = std::min(intermediate, (int32_t)std::numeric_limits<int8_t>::max());
return static_cast<result_type>(intermediate);
#endif
}
CUTLASS_HOST_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
//
// Partial specializations for float => integer_subbyte