mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
update scale for mxfp4
This commit is contained in:
@@ -227,11 +227,11 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale)
|
||||
return convert_to_type<pk_fp4_t>(x, scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale = 1.0f)
|
||||
{
|
||||
return float_to_e2m1(x, scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -239,7 +239,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float sca
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -247,7 +247,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca
|
||||
return float_to_e2m1(type_convert<float>(x), scale);
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -256,7 +256,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -265,7 +265,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float
|
||||
float_to_e2m1(type_convert<float>(x[1]), scale));
|
||||
#endif
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale = 1.0f)
|
||||
{
|
||||
#if CK_TILE_FP4_CVT_DEVICE
|
||||
return impl::_to_f4(x, scale);
|
||||
@@ -274,27 +274,27 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp32x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16x2(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr float pk_fp4_to_float(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_float(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_to_fp16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_fp16(scale);
|
||||
}
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale)
|
||||
CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_to_bf16(const pk_fp4_t& x, float scale = 1.0f)
|
||||
{
|
||||
return x.to_bf16(scale);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user