mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
[CK_TILE] float -> bf16 inline asm rtn (#1482)
* asm rtn
* add asm rtn macro
* reorder macro
---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
[ROCm/composable_kernel commit: b8addae293]
This commit is contained in:
@@ -46,6 +46,7 @@
|
||||
// Integer selectors for the float -> bf16 conversion strategy. One of them is
// used as the fallback value of CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT below.
// The numeric values follow the same order as the bf16_rounding_mode
// enumerators (standard = 0, truncate_with_nan, truncate, standard_asm).
// STANDARD: round-to-nearest ("rtn") conversion.
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
|
||||
// TRUNCATE_WITH_NAN: truncating conversion, NaN-aware variant.
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
|
||||
// TRUNCATE: plain truncating conversion.
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
|
||||
// STANDARD_ASM: round-to-nearest conversion via the inline-asm implementation.
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
|
||||
|
||||
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
|
||||
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
|
||||
|
||||
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
|
||||
standard = 0, // rtn
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
|
||||
return uint16_t(u.int32 >> 16);
|
||||
}
|
||||
|
||||
// CK_TILE_HOST overload of float_to_bf16_rtn_asm.
// No inline assembly is used here; the conversion is forwarded to
// float_to_bf16_rtn_raw, so both functions return the same bits for the same input.
CK_TILE_HOST
constexpr uint16_t float_to_bf16_rtn_asm(float f)
{
    return float_to_bf16_rtn_raw(f);
}
|
||||
|
||||
// CK_TILE_DEVICE overload: converts an fp32 value to raw bf16 bits using a
// hand-written AMDGPU instruction sequence (round-to-nearest-even, NaN mapped
// to a fixed NaN pattern).
//
// f       : value to convert.
// returns : the upper 16 bits of the rounded fp32 bit pattern, or 0x7fff when
//           f is NaN.
CK_TILE_DEVICE
uint16_t float_to_bf16_rtn_asm(float f)
{
    // View the same 32 bits as float (for the asm operand) and as integer
    // (for the final narrowing).
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    // Bit pattern substituted for NaN inputs; after the final >> 16 it is 0x7fff.
    static constexpr uint32_t FP32_NAN = 0x7fff0000;
    // Bias added (together with the bf16 LSB) before dropping the low 16 bits.
    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;

    // Two-dword scalar operand that receives the per-lane compare mask.
    // NOTE(review): a 64-bit mask presumably assumes wave64 execution - confirm
    // before using this path on wave32 targets.
    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
    uint32x2_t check_nan;

    // Scratch VGPR. It is only ever written by the asm (never read before the
    // first write), and it is written before the inputs %3/%4 are last read,
    // so it is bound as an early-clobber output ("=&v") rather than "+v":
    // "+v" would make the statement read this uninitialized variable.
    uint32_t tmp;

    // %0 = check_nan, %1 = tmp, %2 = u.fp32, %3 = ROUND_BIAS_FOR_BF16, %4 = FP32_NAN
    //   v_cmp_u_f32   : %0 = (f unordered with itself), i.e. f is NaN
    //   v_bfe_u32     : %1 = bit 16 of f (the LSB of the resulting bf16)
    //   v_add3_u32    : %1 = f_bits + %1 + 0x7fff (round-to-nearest-even bias)
    //   v_cndmask_b32 : %2 = NaN ? FP32_NAN : %1
    //   v_lshrrev_b32 : %2 = %2 >> 16
    asm volatile("\n"
                 "v_cmp_u_f32 %0, %2, %2\n"
                 "v_bfe_u32 %1, %2, 16, 1\n"
                 "v_add3_u32 %1, %2, %1, %3\n"
                 "v_cndmask_b32 %2, %1, %4, %0\n"
                 "v_lshrrev_b32 %2, 16, %2\n"
                 : "=s"(check_nan), "=&v"(tmp), "+v"(u.fp32)
                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));

    // The asm left the bf16 bits in the low half of the register.
    return uint16_t(u.int32);
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
|
||||
{
|
||||
if constexpr(rounding == bf16_rounding_mode::standard)
|
||||
return float_to_bf16_rtn_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::standard_asm)
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else
|
||||
|
||||
Reference in New Issue
Block a user