mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_TILE] Add logits soft-capping & customization support to the FMHA forward kernel/pipelines (#2163)
* hack for cap logits * fix bug * Re-format files * Allow specifying logits_soft_cap through APIs * Support turn on/off logits_soft_cap in async pipeline * Do not generate non-verified kernels * Align receipt used in Aiter * Sync logits soft-capping across pipelines * Re-enable some hdim pipelines * fix perf * Add attention variant for logits_soft_cap * Add newline at end-of-file * Fix performance * Add comment to explain logits_soft_cap pre-processing * Unify code * Unify floating-point literal style * Use class data member to slience the compilation error * [CK_TILE] Update attention customizaton interface: add LogitsMask() (#2133) * Send 'mask' along with variant params to the LogitsMask() * Send block indices to the variant * Add indices parameters in variant interface * Fix fmha bwd codegen error * Allow switch logits_soft_cap impl * Eliminate register spills * Fix compilation errors * Fix wrong LSE * Fix LSE for splitkv kernel * Sync splitkv pipeline changes * Add batch_prefill kernel/pipeline * Fix codegen error * Undo changes in CMakeLists.txt * Merge pipeline filtering check * Use different code path if kHasLogitsSoftCap=false * Remove [[maybe_unused]] attribute * Use pre-existing compile-time flag to instantiate templates * Sync pipeline changes * Update CHANGELOG.md --------- Co-authored-by: Bernard <bernaliu@amd.com> Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <functional>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
@@ -72,6 +73,7 @@ auto create_args(int argc, char* argv[])
|
||||
"0",
|
||||
"scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
|
||||
"note when squant=1, this value will be modified by range_q/k")
|
||||
.insert("logits_soft_cap", "0", "attention logits soft capping value.")
|
||||
.insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
|
||||
.insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
|
||||
.insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
|
||||
@@ -416,6 +418,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale_s == .0f)
|
||||
scale_s = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
||||
|
||||
const float logits_soft_cap = arg_parser.get_float("logits_soft_cap");
|
||||
|
||||
std::string squant_str = arg_parser.get_str("squant");
|
||||
bool squant = [&]() {
|
||||
if(squant_str == "auto")
|
||||
@@ -850,6 +854,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else // fmha_fwd_traits or fmha_splitkv_traits
|
||||
{
|
||||
traits.is_group_mode = (mode == mode_enum::group);
|
||||
traits.has_logits_soft_cap = 0.f < logits_soft_cap;
|
||||
traits.mask_type = mask.type;
|
||||
traits.bias_type = bias.type;
|
||||
traits.has_lse = lse;
|
||||
@@ -1007,6 +1012,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
args.scale_p = scale_p;
|
||||
args.scale_o = scale_o;
|
||||
|
||||
args.logits_soft_cap = logits_soft_cap;
|
||||
|
||||
args.stride_bias =
|
||||
(bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) : stride_bias);
|
||||
args.stride_o = stride_o;
|
||||
@@ -1375,6 +1382,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale_s));
|
||||
|
||||
if(0.f < logits_soft_cap)
|
||||
{
|
||||
ck_tile::reference_unary_elementwise<SaccDataType, SaccDataType, SaccDataType>(
|
||||
s_host_ref, s_host_ref, [logits_soft_cap](SaccDataType logits) {
|
||||
return ck_tile::type_convert<SaccDataType>(
|
||||
logits_soft_cap *
|
||||
std::tanhf(ck_tile::type_convert<float>(logits / logits_soft_cap)));
|
||||
});
|
||||
}
|
||||
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// elementwise bias
|
||||
|
||||
Reference in New Issue
Block a user