mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Refine FP32 => FP16/BF16 Conversion (#3215)
* [CK_TILE] Refine FP32 => FP16/BF16 Conversion * Thank you Copilot * Rename fix * Fix example * Fix accu checking * Fix * Fix
This commit is contained in:
@@ -283,7 +283,10 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
// Use builtin bfloat16 conversion only on gfx950 as its predecessors do not support bf16 cvt
|
||||
// instructions, resulting in suboptimal performance; add host-side macro check for consistency
|
||||
// during accuracy tests.
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16 && (defined(__gfx950__) || defined(CK_GFX950_SUPPORT))
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
@@ -427,4 +430,14 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast
|
||||
CK_TILE_DEVICE
|
||||
bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
|
||||
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x)
|
||||
{
|
||||
return bf16x2_t{float_to_bf16<rounding>(x.x), float_to_bf16<rounding>(x.y)};
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -383,6 +383,7 @@ half_t log(half_t x) { return static_cast<half_t>(__logf(static_cast<float>(x)))
|
||||
#endif
|
||||
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
{
|
||||
@@ -401,4 +402,9 @@ CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y)
|
||||
return c;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr fp16x2_t fp32x2_to_fp16x2(const fp32x2_t& x)
|
||||
{
|
||||
return fp16x2_t{float_to_fp16(x.x), float_to_fp16(x.y)};
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -64,6 +64,9 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(float, float, int8_t, int8)
|
||||
CK_TILE_TYPE_CONVERT(int8_t, int8, float, float)
|
||||
|
||||
CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2)
|
||||
CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2)
|
||||
#undef CK_TILE_TYPE_CONVERT
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user