[CK_TILE]Moe update index (#1672)

* update MOCK_ID for moe-sorting

* add moe-smoothquant

* update a comment

* fix format

* hot fix

* update topk in overflow case

* update comments

* update bf16 cvt

---------

Co-authored-by: valarLip <340077269@qq.com>
This commit is contained in:
carlushuang
2024-11-25 13:12:35 +08:00
committed by GitHub
parent ce2bdf42a9
commit 36c7ce4e0e
36 changed files with 1321 additions and 11 deletions

View File

@@ -18,6 +18,7 @@ enum class bf16_rounding_mode
truncate_with_nan,
truncate,
standard_asm,
rta_asm, // round to nearest away
};
template <bf16_rounding_mode rounding =
@@ -180,6 +181,39 @@ uint16_t float_to_bf16_rtn_asm(float f)
return uint16_t(u.int32);
}
// TODO: do we need this on host?
CK_TILE_HOST
uint16_t float_to_bf16_rta_asm(float f) { return float_to_bf16_rtn_raw(f); }
CK_TILE_DEVICE
uint16_t float_to_bf16_rta_asm(float f)
{
union
{
float fp32;
struct
{
uint16_t lo;
uint16_t hi;
};
} u = {f};
const uint32_t low_nan = 0x7fff;
const uint32_t hi_nan = 0x7fff0000;
using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
uint32x2_t check_nan;
asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
"v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
"v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
: [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
: [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
// Note: in above code snipet, we use hi 16 bit
return u.hi;
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
@@ -213,6 +247,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
return float_to_bf16_rtn_asm(f);
else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
return float_to_bf16_truc_nan_raw(f);
else if constexpr(rounding == bf16_rounding_mode::rta_asm)
return float_to_bf16_rta_asm(f);
else
return float_to_bf16_truc_raw(f);
}