mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits * fix bug * Re-format files * Allow specifying logits_soft_cap through APIs * Support turn on/off logits_soft_cap in async pipeline * Do not generate non-verified kernels * Align receipt used in Aiter * Sync logits soft-capping across pipelines * Re-enable some hdim pipelines * fix perf * Add attention variant for logits_soft_cap * Add newline at end-of-file * Fix performance * Add comment to explain logits_soft_cap pre-processing * Unify code * Unify floating-point literal style * Use class data member to silence the compilation error * [CK_TILE] Update attention customization interface: add LogitsMask() (#2133) * Send 'mask' along with variant params to the LogitsMask() * Send block indices to the variant * Add indices parameters in variant interface * Fix fmha bwd codegen error * Allow switch logits_soft_cap impl * Eliminate register spills * Fix compilation errors * Fix wrong LSE * Fix LSE for splitkv kernel * Sync splitkv pipeline changes * Add batch_prefill kernel/pipeline * Fix codegen error * Undo changes in CMakeLists.txt * Merge pipeline filtering check * Use different code path if kHasLogitsSoftCap=false * Remove [[maybe_unused]] attribute * Use pre-existing compile-time flag to instantiate templates * Sync pipeline changes * Update CHANGELOG.md --------- Co-authored-by: Bernard <bernaliu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -487,6 +487,9 @@ struct log2e<float>
|
||||
template <typename T = double>
|
||||
constexpr T log2e_v = log2e<T>::value;
|
||||
|
||||
template <typename T = double>
|
||||
constexpr T log2e_rcp_v = 1. / log2e<T>::value;
|
||||
|
||||
// Base-2 exponential for float arguments; forwards directly to exp2f.
CK_TILE_DEVICE float exp2(float x)
{
    return exp2f(x);
}
|
||||
|
||||
@@ -1380,6 +1383,44 @@ CK_TILE_DEVICE double exp<double>(double x)
|
||||
return exp(x);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T tanh_fast(T x)
|
||||
{
|
||||
return type_convert<T>((exp<T>(2.0 * type_convert<float>(x)) - 1.0) /
|
||||
(exp<T>(2.0 * type_convert<float>(x)) + 1.0));
|
||||
};
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE float tanh_fast<float>(float x)
|
||||
{
|
||||
// float a = __builtin_amdgcn_sinh(x);
|
||||
// float b = __builtin_amdgcn_cosh(x);
|
||||
// float e = a * __builtin_amdgcn_rcpf(b);
|
||||
// return e;
|
||||
|
||||
float a = 2.0f * log2e_v<float> * x;
|
||||
a = __builtin_amdgcn_exp2f(a);
|
||||
a = __builtin_amdgcn_rcpf(a + 1.0f);
|
||||
a = 2 * a;
|
||||
a = 1 - a;
|
||||
return a;
|
||||
|
||||
// float e, r, s, t, d;
|
||||
// float a = x;
|
||||
// s = abs(a);
|
||||
// t = -log2e_v<float> * 2.0f * s;
|
||||
// e = __builtin_amdgcn_exp2f(t);
|
||||
// d = e + 1.0f;
|
||||
// r = __builtin_amdgcn_rcpf(d);
|
||||
// r = e * (-r) + r;
|
||||
// if (s < 4.997253418e-3f) r = a;
|
||||
// union fipnr {float f; unsigned int i;};
|
||||
// fipnr r_; r_.f = r;
|
||||
// fipnr a_; a_.f = a;
|
||||
// { r_.i = (r_.i|(a_.i&0x80000000)); r = r_.f; }
|
||||
// return r;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T log(T x)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user