[CK_TILE] Refine FP32 => FP16/BF16 Conversion (#3215)

* [CK_TILE] Refine FP32 => FP16/BF16 Conversion

* Thank you Copilot

* Rename fix

* Fix example

* Fix accuracy checking

* Fix

* Fix
This commit is contained in:
Yi DING
2025-11-21 02:50:26 +08:00
committed by GitHub
parent 07314ac543
commit 8b284a63a4
7 changed files with 61 additions and 14 deletions

View File

@@ -704,12 +704,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
// impl::cast_tile_pkrtz_fp16_fp32 would cause a precision issue,
// since it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero.
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#else
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(

View File

@@ -657,12 +657,12 @@ struct BlockFmhaPipelineQRKSVSAsync
const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
// For fp32 to fp16,
// impl::cast_tile_pk_fp16_fp32 would cause precision issue,
// impl::cast_tile_pkrtz_fp16_fp32 would cause a precision issue,
// since it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero.
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#else
if constexpr(std::is_same_v<PDataType, fp16_t>)
return impl::cast_tile_pk_fp16_fp32<PDataType>(
return impl::cast_tile_pkrtz_fp16_fp32<PDataType>(
tile_elementwise_in(p_compute_element_func, p_compute));
else
return cast_tile<PDataType>(