slightly better

This commit is contained in:
coderfeli
2025-04-16 14:38:14 +00:00
parent b9204670c9
commit 238331dbb5
2 changed files with 25 additions and 19 deletions

View File

@@ -1383,27 +1383,33 @@ CK_TILE_DEVICE double exp<double>(double x)
template <typename T>
CK_TILE_DEVICE T tanh_fast(T x)
{
return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
return x;
// return type_convert<T>((exp_fast_exp2<T>(2.0 * type_convert<float>(x)) - 1.0) / (exp_fast_exp2<T>(2.0 * type_convert<float>(x)) + 1.0));
};
template <>
CK_TILE_DEVICE float tanh_fast<float>(float x)
{
// return (exp_fast_exp2<float>(2.0f * x) - 1.0f) / (exp_fast_exp2<float>(2.0f * x) + 1.0f);
float e, r, s, t, d;
float a = x;
s = abs(a);
t = -log2e_v<float> * 2.0f * s;
e = __builtin_amdgcn_exp2f(t);
d = e + 1.0f;
r = __builtin_amdgcn_rcpf(d);
r = e * (-r) + r;
if (s < 4.997253418e-3f) r = a;
union fipnr {float f; unsigned int i;};
fipnr r_; r_.f = r;
fipnr a_; a_.f = a;
{ r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
return r;
// float a = __builtin_amdgcn_exp2f(2.88539f * x);
// float b = __builtin_amdgcn_rcpf(a + 1.0f);
// return 1.0f - 2.0f * b;
// float x2 = x * x;
// float x5 = x * 0.13333f - 0.33333f;
// float x3 = x2 * x5 + 1.0f;
// return x3 * x;
// 1.442695
// float a = 2.0f * log2e_v<float> * x;
// a = __builtin_amdgcn_exp2f(a);
// a = __builtin_amdgcn_rcpf(a + 1.0f);
// a = 2 * a;
// a = 1 - a;
// return a;
float a = __builtin_amdgcn_exp2f(2.88539f * x);
float b = __builtin_amdgcn_rcpf(a + 1.0f);
return 1.0f - 2.0f * b;
};
template <typename T>

View File

@@ -175,7 +175,7 @@ struct BlockFmhaPipelineQRKSVS
"wrong!");
const float logits_cap = 30.0f;
const float logits_cap_scale = scale_s / (logits_cap * log2e_v<float>);
const float logits_cap_rev = __builtin_amdgcn_rcpf(logits_cap);
// K tile in LDS
KDataType* k_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
@@ -427,11 +427,11 @@ struct BlockFmhaPipelineQRKSVS
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
float scale_lo = scale_s * 0.6931472f;
// #if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = (x * scale_s) * __builtin_amdgcn_rcpf(log2e_v<>); }, s_acc);
tile_elementwise_inout(
[&logits_cap_scale, &logits_cap](auto& x) {
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * __builtin_amdgcn_rcpf(logits_cap));
[&scale_lo, &logits_cap, &logits_cap_rev](auto& x) {
x = log2e_v<SaccDataType> * logits_cap * tanh_fast<SaccDataType>(x * scale_lo * logits_cap_rev);
},
s_acc
);