mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
slightly better
This commit is contained in:
@@ -1383,27 +1383,33 @@ CK_TILE_DEVICE double exp<double>(double x)
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh_fast(T x)
|
||||
{
|
||||
return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
return x;
|
||||
// return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh_fast<float>(float x)
|
||||
{
|
||||
// return (exp_fast_exp2<float>(2.0f * x) - 1.0f) / (exp_fast_exp2<float>(2.0f * x) + 1.0f);
|
||||
float e, r, s, t, d;
|
||||
float a = x;
|
||||
s = abs(a);
|
||||
t = -log2e_v<float> * 2.0f * s;
|
||||
e = __builtin_amdgcn_exp2f(t);
|
||||
d = e + 1.0f;
|
||||
r = __builtin_amdgcn_rcpf(d);
|
||||
r = e * (-r) + r;
|
||||
if (s < 4.997253418e-3f) r = a;
|
||||
union fipnr {float f; unsigned int i;};
|
||||
fipnr r_; r_.f = r;
|
||||
fipnr a_; a_.f = a;
|
||||
{ r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
|
||||
return r;
|
||||
// float a = __builtin_amdgcn_exp2f(2.88539f * x);
|
||||
// float b = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
// return 1.0f - 2.0f * b;
|
||||
// float x2 = x * x;
|
||||
// float x5 = x * 0.13333f - 0.33333f;
|
||||
// float x3 = x2 * x5 + 1.0f;
|
||||
// return x3 * x;
|
||||
// 1.442695
|
||||
// float a = 2.0f * log2e_v<float> * x;
|
||||
// a = __builtin_amdgcn_exp2f(a);
|
||||
// a = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
// a = 2 * a;
|
||||
// a = 1 - a;
|
||||
// return a;
|
||||
float a = __builtin_amdgcn_exp2f(2.88539f * x);
|
||||
float b = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
return 1.0f - 2.0f * b;
|
||||
|
||||
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
"wrong!");
|
||||
|
||||
const float logits_cap = 30.0f;
|
||||
const float logits_cap_scale = scale_s / (logits_cap * log2e_v<float>);
|
||||
const float logits_cap_rev = __builtin_amdgcn_rcpf(logits_cap);
|
||||
|
||||
// K tile in LDS
|
||||
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
|
||||
@@ -427,11 +427,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
float scale_lo = scale_s * 0.6931472f;
|
||||
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = (x * scale_s) * __builtin_amdgcn_rcpf(log2e_v<>); }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&logits_cap_scale, &logits_cap](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * __builtin_amdgcn_rcpf(logits_cap));
|
||||
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
|
||||
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
|
||||
},
|
||||
s_acc
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user