[CK_TILE] Refine FP32 => FP16/BF16 Conversion (#3215)

* [CK_TILE] Refine FP32 => FP16/BF16 Conversion

* Thank you Copilot

* Rename fix

* Fix example

* Fix accu checking

* Fix

* Fix
This commit is contained in:
Yi DING
2025-11-21 02:50:26 +08:00
committed by GitHub
parent 07314ac543
commit 8b284a63a4
7 changed files with 61 additions and 14 deletions

View File

@@ -283,7 +283,10 @@ template <bf16_rounding_mode rounding =
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
{
#if CK_TILE_USE_LLVM_BUILTIN_BF16
// Use builtin bfloat16 conversion only on gfx950 as its predecessors do not support bf16 cvt
// instructions, resulting in suboptimal performance; add a host-side macro check for consistency
// during accuracy tests.
#if CK_TILE_USE_LLVM_BUILTIN_BF16 && (defined(__gfx950__) || defined(CK_GFX950_SUPPORT))
return static_cast<bfloat16_t>(f);
#else
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
@@ -427,4 +430,14 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast
CK_TILE_DEVICE
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
using fp32x2_t = float __attribute__((ext_vector_type(2)));
template <bf16_rounding_mode rounding =
              static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x)
{
    // Convert each lane of the packed fp32 pair independently, applying the
    // selected rounding mode to both elements.
    const bfloat16_t lo = float_to_bf16<rounding>(x.x);
    const bfloat16_t hi = float_to_bf16<rounding>(x.y);
    return bf16x2_t{lo, hi};
}
} // namespace ck_tile

View File

@@ -383,6 +383,7 @@ half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x)))
#endif
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
using fp32x2_t = float __attribute__((ext_vector_type(2)));
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
{
@@ -401,4 +402,9 @@ CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
return c;
}
CK_TILE_HOST_DEVICE
constexpr fp16x2_t fp32x2_to_fp16x2(const fp32x2_t& x)
{
    // Per-lane fp32 -> fp16 conversion of a packed two-element vector.
    const auto lo = float_to_fp16(x.x);
    const auto hi = float_to_fp16(x.y);
    return fp16x2_t{lo, hi};
}
} // namespace ck_tile

View File

@@ -64,6 +64,9 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
#undef CK_TILE_TYPE_CONVERT
} // namespace ck_tile

View File

@@ -228,7 +228,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
}
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
CK_TILE_DEVICE auto cast_tile_pkrtz_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
// This API is designed to use the _pk_ series of functions
@@ -258,6 +258,30 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
#endif
}
// Cast a distributed f32 tile to fp16 or bf16 two lanes at a time, so the
// compiler can recognize each pair of scalar conversions and emit a packed
// cvt_pk instruction where the target ISA provides one.
// OutDataType selects the destination element type (fp16_t -> fp16x2_t,
// otherwise bf16x2_t); InTensor must hold float with an even-sized thread
// buffer (enforced by the static_assert below).
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors)
{
// This API is designed to help the compiler identify pairs of f32 -> fp16/bf16
// casts and use a cvt_pk instruction when possible.
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
// Pairing requires an even element count per thread.
static_assert(thread_buffer_size % 2 == 0);
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
// Packed 2-element vector type matching the requested output element type.
using f16x2_t = std::conditional_t<std::is_same_v<OutDataType, fp16_t>, fp16x2_t, bf16x2_t>;
for(index_t i = 0; i < thread_buffer_size / 2; i++)
{
// Convert two adjacent f32 values as a single packed operation.
auto o = type_convert<f16x2_t>(fp32x2_t{
in_dstr_tensors.get_thread_buffer()[2 * i + 0],
in_dstr_tensors.get_thread_buffer()[2 * i + 1],
});
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
}
return out_dstr_tensor;
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either src or dst (or both) data type is under 1 dword
// we pack the subdword value into 1 dword to avoid the compiler's default subdword behavior (which is buggy)
@@ -329,22 +353,20 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0))
{
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
}
#if CK_TILE_USE_PK_FP16_TILE_CAST
else if constexpr(std::is_same_v<DstType, fp16_t> &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 2 == 0))
{
return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
}
return impl::cast_tile_pkrtz_fp16_fp32<DstType, SrcTensor>(src_tensor);
#endif
else if constexpr((std::is_same_v<DstType, fp16_t> || std::is_same_v<DstType, bf16_t>) &&
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 2 == 0))
return impl::cast_tile_pk_fp16bf16_fp32<DstType, SrcTensor>(src_tensor);
#if CK_TILE_USE_SUBDWORD_TILE_CAST
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
{
return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
}
#endif
else
return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);