[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)

* hack for cap logits

* fix bug

* Re-format files

* Allow specifying logits_soft_cap through APIs

* Support turning logits_soft_cap on/off in async pipeline

* Do not generate non-verified kernels

* Align receipt used in Aiter

* Sync logits soft-capping across pipelines

* Re-enable some hdim pipelines

* fix perf

* Add attention variant for logits_soft_cap

* Add newline at end-of-file

* Fix performance

* Add comment to explain logits_soft_cap pre-processing

* Unify code

* Unify floating-point literal style

* Use class data member to silence the compilation error

* [CK_TILE] Update attention customization interface: add LogitsMask() (#2133)

* Send 'mask' along with variant params to the LogitsMask()

* Send block indices to the variant

* Add indices parameters in variant interface

* Fix fmha bwd codegen error

* Allow switching logits_soft_cap impl

* Eliminate register spills

* Fix compilation errors

* Fix wrong LSE

* Fix LSE for splitkv kernel

* Sync splitkv pipeline changes

* Add batch_prefill kernel/pipeline

* Fix codegen error

* Undo changes in CMakeLists.txt

* Merge pipeline filtering check

* Use different code path if kHasLogitsSoftCap=false

* Remove [[maybe_unused]] attribute

* Use pre-existing compile-time flag to instantiate templates

* Sync pipeline changes

* Update CHANGELOG.md

---------

Co-authored-by: Bernard <bernaliu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
Po Yen Chen
2025-05-13 12:19:25 +08:00
committed by GitHub
parent f05e45ba59
commit 2920604786
29 changed files with 4621 additions and 226 deletions

View File

@@ -487,6 +487,9 @@ struct log2e<float>
// Shorthand for log2e<T>::value; T defaults to double.
template <class T = double>
constexpr T log2e_v = log2e<T>::value;

// Reciprocal of log2e<T>::value. The division is written with a double
// literal and the quotient is then stored as T.
template <class T = double>
constexpr T log2e_rcp_v = 1.0 / log2e<T>::value;
// Single-precision base-2 exponential: forwards directly to exp2f.
CK_TILE_DEVICE float exp2(float x)
{
    return exp2f(x);
};
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
return exp(x);
};
// Generic tanh approximation built on exp<T>.
//
// Evaluates tanh(x) as 1 - 2 / (exp(2x) + 1), the same closed form used by the
// float specialization. The textbook form (exp(2x) - 1) / (exp(2x) + 1) turns
// into inf / inf = NaN as soon as exp(2x) overflows T; with this form an
// overflowing exponential makes the quotient vanish, so the result tends to 1
// for large positive x and to -1 for large negative x.
//
// x      : input value; it is converted to float before being scaled.
// return : the approximation converted back to T.
template <typename T>
CK_TILE_DEVICE T tanh_fast(T x)
{
    // Evaluate exp(2x) once and reuse it (the argument expression is unchanged).
    const float exp_2x = type_convert<float>(exp<T>(2.0 * type_convert<float>(x)));

    return type_convert<T>(1.0f - 2.0f / (exp_2x + 1.0f));
};
// float specialization of tanh_fast using the AMDGCN exp2 / rcp builtins.
//
// Uses the identity tanh(x) = 1 - 2 / (e^(2x) + 1), with e^(2x) computed as
// exp2(2 * log2e_v<float> * x) (assuming log2e_v<float> is log2(e)).
//
// Earlier drafts that were left disabled here:
//   - a ratio of the sinh and cosh builtins multiplied through rcpf;
//   - a variant that evaluates on |x|, returns x itself for
//     |x| < 4.997253418e-3f, and copies the sign bit of x onto the result.
template <>
CK_TILE_DEVICE float tanh_fast<float>(float x)
{
    const float scaled_x  = 2.0f * log2e_v<float> * x;
    const float exp_2x    = __builtin_amdgcn_exp2f(scaled_x);
    const float inv_denom = __builtin_amdgcn_rcpf(exp_2x + 1.0f);
    const float two_inv   = 2 * inv_denom;
    return 1 - two_inv;
};
template <typename T>
CK_TILE_DEVICE T log(T x)
{