mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits
* fix bug
* Re-format files
* Allow specifying logits_soft_cap through APIs
* Support turn on/off logits_soft_cap in async pipeline
* Do not generate non-verified kernels
* Align receipt used in Aiter
* Sync logits soft-capping across pipelines
* Re-enable some hdim pipelines
* fix perf
* Add attention variant for logits_soft_cap
* Add newline at end-of-file
* Fix performance
* Add comment to explain logits_soft_cap pre-processing
* Unify code
* Unify floating-point literal style
* Use class data member to slience the compilation error
* [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133)
* Send 'mask' along with variant params to the LogitsMask()
* Send block indices to the variant
* Add indices parameters in variant interface
* Fix fmha bwd codegen error
* Allow switch logits_soft_cap impl
* Eliminate register spills
* Fix compilation errors
* Fix wrong LSE
* Fix LSE for splitkv kernel
* Sync splitkv pipeline changes
* Add batch_prefill kernel/pipeline
* Fix codegen error
* Undo changes in CMakeLists.txt
* Merge pipeline filtering check
* Use different code path if kHasLogitsSoftCap=false
* Remove [[maybe_unused]] attribute
* Use pre-existing compile-time flag to instantiate templates
* Sync pipeline changes
* Update CHANGELOG.md
---------
Co-authored-by: Bernard <bernaliu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
[ROCm/composable_kernel commit: 2920604786]
This commit is contained in:
@@ -487,6 +487,9 @@ struct log2e<float>
|
||||
template <typename T = double>
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
template <typename T = double>
|
||||
constexpr T log2e_rcp_v = 1. / log2e<T>::value;
|
||||
|
||||
CK_TILE_DEVICE
|
||||
float exp2(float x) { return exp2f(x); };
|
||||
|
||||
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh_fast(T x)
|
||||
{
|
||||
return type_convert<T>((exp<T>(2.0 * type_convert<float>(x)) - 1.0) /
|
||||
(exp<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh_fast<float>(float x)
|
||||
{
|
||||
// float a = __builtin_amdgcn_sinh(x);
|
||||
// float b = __builtin_amdgcn_cosh(x);
|
||||
// float e = a * __builtin_amdgcn_rcpf(b);
|
||||
// return e;
|
||||
|
||||
float a = 2.0f * log2e_v<float> * x;
|
||||
a = __builtin_amdgcn_exp2f(a);
|
||||
a = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
a = 2 * a;
|
||||
a = 1 - a;
|
||||
return a;
|
||||
|
||||
// float e, r, s, t, d;
|
||||
// float a = x;
|
||||
// s = abs(a);
|
||||
// t = -log2e_v<float> * 2.0f * s;
|
||||
// e = __builtin_amdgcn_exp2f(t);
|
||||
// d = e + 1.0f;
|
||||
// r = __builtin_amdgcn_rcpf(d);
|
||||
// r = e * (-r) + r;
|
||||
// if (s < 4.997253418e-3f) r = a;
|
||||
// union fipnr {float f; unsigned int i;};
|
||||
// fipnr r_; r_.f = r;
|
||||
// fipnr a_; a_.f = a;
|
||||
// { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
|
||||
// return r;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user