mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE]Moe update index (#1672)
* update MOCK_ID for moe-sorting * add moe-smoothquant * update a comment * fix format * hot fix * update topk in overflow case * update comments * update bf16 cvt --------- Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
|
||||
truncate_with_nan,
|
||||
truncate,
|
||||
standard_asm,
|
||||
rta_asm, // round to nearest away
|
||||
};
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
|
||||
return uint16_t(u.int32);
|
||||
}
|
||||
|
||||
// TODO: do we need this on host?
|
||||
CK_TILE_HOST
|
||||
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
|
||||
|
||||
CK_TILE_DEVICE
|
||||
uint16_t float_to_bf16_rta_asm(float f)
|
||||
{
|
||||
union
|
||||
{
|
||||
float fp32;
|
||||
struct
|
||||
{
|
||||
uint16_t lo;
|
||||
uint16_t hi;
|
||||
};
|
||||
} u = {f};
|
||||
|
||||
const uint32_t low_nan = 0x7fff;
|
||||
const uint32_t hi_nan = 0x7fff0000;
|
||||
|
||||
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
|
||||
uint32x2_t check_nan;
|
||||
|
||||
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
|
||||
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
|
||||
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
|
||||
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
|
||||
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
|
||||
|
||||
// Note: in above code snipet, we use hi 16 bit
|
||||
return u.hi;
|
||||
}
|
||||
|
||||
// Truncate instead of rounding, preserving SNaN
|
||||
CK_TILE_HOST_DEVICE
|
||||
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
|
||||
@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
|
||||
return float_to_bf16_rtn_asm(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
|
||||
return float_to_bf16_truc_nan_raw(f);
|
||||
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
|
||||
return float_to_bf16_rta_asm(f);
|
||||
else
|
||||
return float_to_bf16_truc_raw(f);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user