[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)

* Hack for capping logits

* Fix bug

* Re-format files

* Allow specifying logits_soft_cap through APIs

* Support turning logits_soft_cap on/off in the async pipeline

* Do not generate non-verified kernels

* Align receipt used in Aiter

* Sync logits soft-capping across pipelines

* Re-enable some hdim pipelines

* Fix performance

* Add attention variant for logits_soft_cap

* Add newline at end-of-file

* Fix performance

* Add comment to explain logits_soft_cap pre-processing

* Unify code

* Unify floating-point literal style

* Use class data member to silence the compilation error

* [CK_TILE] Update attention customization interface: add LogitsMask() (#2133)

* Send 'mask' along with variant params to the LogitsMask()

* Send block indices to the variant

* Add indices parameters to the variant interface

* Fix fmha bwd codegen error

* Allow switching the logits_soft_cap implementation

* Eliminate register spills

* Fix compilation errors

* Fix wrong LSE

* Fix LSE for splitkv kernel

* Sync splitkv pipeline changes

* Add batch_prefill kernel/pipeline

* Fix codegen error

* Undo changes in CMakeLists.txt

* Merge pipeline filtering check

* Use different code path if kHasLogitsSoftCap=false

* Remove [[maybe_unused]] attribute

* Use pre-existing compile-time flag to instantiate templates

* Sync pipeline changes

* Update CHANGELOG.md

---------

Co-authored-by: Bernard <bernaliu@amd.com>
Co-authored-by: coderfeli <coderfeli@163.com>
Author: Po Yen Chen (committed by GitHub)
Date: 2025-05-13 12:19:25 +08:00
Commit: 2920604786 (parent: f05e45ba59)
29 changed files with 4621 additions and 226 deletions
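
For reference, the transform these commits add is logits' = logits_soft_cap * tanh(logits / logits_soft_cap), which smoothly bounds every attention logit to (-logits_soft_cap, +logits_soft_cap) before the softmax. A minimal host-side sketch of the math (an illustrative helper, not the kernel API):

#include <cmath>

// Soft-cap a raw attention logit so it stays within (-cap, +cap).
// A non-positive cap means "capping disabled", matching the
// `0.f < logits_soft_cap` checks in the diff below.
inline float soft_cap_logit(float logit, float cap)
{
    if(!(0.f < cap))
        return logit; // capping disabled
    return cap * std::tanh(logit / cap);
}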


@@ -11,6 +11,7 @@
#include <array>
#include <cstring>
#include <functional>
#include <cmath>
#include <numeric>
#include <ostream>
#include <string>
@@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[])
"0",
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
"note when squant=1, this value will be modified by range_q/k")
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
@@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale_s == .0f)
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
const float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
std::string squant_str = arg_parser.get_str("squant");
bool squant = [&]() {
if(squant_str == "auto")
@@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
else // fmha_fwd_traits or fmha_splitkv_traits
{
traits.is_group_mode = (mode == mode_enum::group);
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
traits.mask_type = mask.type;
traits.bias_type = bias.type;
traits.has_lse = lse;
@@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.scale_p = scale_p;
args.scale_o = scale_o;
args.logits_soft_cap = logits_soft_cap;
args.stride_bias =
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
args.stride_o = stride_o;
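
Taken together, the hunks above show the host-side contract: the boolean trait selects a kernel instantiation with capping compiled in, while the argument struct carries the runtime value. A condensed sketch (struct names mirror the diff; the surrounding setup is elided):

// Compile-time selection via the trait, runtime value via the args.
fmha_fwd_traits traits{};
traits.has_logits_soft_cap = 0.f < logits_soft_cap;

fmha_fwd_args args{};
args.logits_soft_cap = logits_soft_cap;
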
@@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::identity{},
ck_tile::scales(scale_s));
if(0.f < logits_soft_cap)
{
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
return ck_tile::type_convert<SaccDataType>(
logits_soft_cap *
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
});
}
if(bias.type == bias_enum::elementwise_bias)
{
// elementwise bias
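
The hunk above is the host reference path, which applies the cap directly with tanh. One commit message mentions a comment explaining logits_soft_cap pre-processing; a common form of that pre-processing (an assumption here, not the pipeline's verbatim code) folds the softmax scale and the division by the cap into a single multiplier computed once on the host:

// Assumed pre-processing sketch, not the pipeline's actual code:
// fold scale_s and 1/logits_soft_cap into one factor so the inner
// loop needs one multiply plus one tanh per logit, with no divide.
const float pre_scale = scale_s / logits_soft_cap; // host side, once
// Per element, with raw = dot(q, k) (unscaled):
//   capped = logits_soft_cap * tanhf(raw * pre_scale);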