[CK_TILE] support alibi (#1269)

* add alibi support

* fix code

* update code based on comment

* Support more hdim

* fix fp8 bias

* support seqlen_k=0 case

* remove unused printf

* fix format

---------

Co-authored-by: rocking <ChunYu.Lai@amd.com>
Author: carlushuang
Date:   2024-05-07 22:32:54 +08:00 (committed by GitHub)
Commit: 851c3ed157 (parent 6d073d31bb)
24 changed files with 948 additions and 115 deletions


@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
@@ -33,6 +34,7 @@ struct FmhaFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
@@ -41,7 +43,7 @@ struct FmhaFwdKernel
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
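
The new BiasEnum flag replaces the old boolean kHasBias, so a pipeline can distinguish an elementwise bias tensor from ALiBi slopes at compile time. The enum header added above is not shown in this diff; a minimal sketch of what it presumably declares, reconstructed from how this file uses it (the value names come from the usage below, the string suffixes are assumptions):

namespace ck_tile {

// Hypothetical sketch of block_attention_bias_enum.hpp; the real header may differ.
enum class BlockAttentionBiasEnum
{
    NO_BIAS          = 0, // no additive bias on the attention scores
    ELEMENTWISE_BIAS = 1, // full bias tensor added element-wise to S = Q*K^T
    ALIBI            = 2, // per-head ALiBi slope, bias derived from positions
};

// Maps an enum value to the suffix appended to the generated kernel name.
template <BlockAttentionBiasEnum>
struct BlockAttentionBiasEnumToStr;

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ELEMENTWISE_BIAS>
{
    static constexpr const char* name = "bias"; // assumed suffix
};

template <>
struct BlockAttentionBiasEnumToStr<BlockAttentionBiasEnum::ALIBI>
{
    static constexpr const char* name = "alibi"; // assumed suffix
};

} // namespace ck_tile
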
@@ -81,7 +83,8 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_
#undef _TS_
// clang-format on
@@ -136,6 +139,13 @@ struct FmhaFwdKernel
ck_tile::index_t batch_stride_bias = 0;
};
struct FmhaFwdAlibiKargs
{
// alibi slopes are laid out as batch*nhead*1; the layout is the same in batch and group mode
const void* alibi_slope_ptr;
ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
};
struct FmhaFwdMaskKargs
{
// ck_tile::index_t window_size_left, window_size_right;
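
FmhaFwdAlibiKargs carries only a slope pointer and a batch stride; producing the slopes is a host-side concern. A minimal host-side sketch, assuming the common power-of-two ALiBi schedule slope_h = 2^(-8*(h+1)/nhead) (the kernel itself accepts arbitrary slopes) and assuming SaccDataType is float:

#include <cmath>
#include <cstdint>
#include <vector>

// Common ALiBi slope schedule; any per-head values are accepted by the kernel.
std::vector<float> make_alibi_slopes(int64_t nhead)
{
    std::vector<float> slopes(nhead);
    for(int64_t h = 0; h < nhead; ++h)
        slopes[h] = std::exp2f(-8.0f * static_cast<float>(h + 1) / static_cast<float>(nhead));
    return slopes;
}

// alibi_slope_stride, per the comment in FmhaFwdAlibiKargs:
//   all batches share one [nhead] slope table               -> alibi_slope_stride = 0
//   each batch has its own slopes in a [batch, nhead] table -> alibi_slope_stride = nhead
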
@@ -162,7 +172,11 @@ struct FmhaFwdKernel
struct FmhaFwdBatchModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdBatchModeBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
@@ -175,7 +189,11 @@ struct FmhaFwdKernel
struct FmhaFwdGroupModeKargs
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
FmhaFwdCommonBiasKargs,
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
FmhaFwdAlibiKargs,
FmhaFwdEmptyKargs<0>>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
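
Both karg structs select their bias-related base at compile time through nested std::conditional_t, so kernels built without bias (or with ALiBi) carry only the fields they actually use. A standalone sketch of the same selection pattern, with illustrative names rather than the kernel's own:

#include <type_traits>

enum class BiasKind { None, Elementwise, Alibi };

struct ElementwiseBiasArgs { const void* bias_ptr; long stride_bias; };
struct AlibiArgs           { const void* alibi_slope_ptr; long alibi_slope_stride; };
template <int> struct EmptyArgs {}; // distinct empty bases keep multiple inheritance unambiguous

// Exactly one base is chosen by the compile-time bias kind; with BiasKind::None
// the empty base adds no storage (empty base optimization).
template <BiasKind kind>
struct Kargs
    : std::conditional_t<kind == BiasKind::Elementwise,
                         ElementwiseBiasArgs,
                         std::conditional_t<kind == BiasKind::Alibi, AlibiArgs, EmptyArgs<0>>>
{
    long seqlen_q;
    long seqlen_k;
};

static_assert(sizeof(Kargs<BiasKind::None>) <= sizeof(Kargs<BiasKind::Alibi>),
              "no-bias kargs stay smaller");
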
@@ -255,13 +273,18 @@ struct FmhaFwdKernel
batch_stride_v,
batch_stride_o};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
kargs.batch_stride_bias = batch_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
@@ -345,12 +368,17 @@ struct FmhaFwdKernel
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
kargs.bias_ptr = bias_ptr;
kargs.stride_bias = stride_bias;
kargs.nhead_stride_bias = nhead_stride_bias;
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
kargs.alibi_slope_ptr = bias_ptr;
kargs.alibi_slope_stride = stride_bias;
}
if constexpr(kHasMask)
{
kargs.window_size_left = window_size_left;
@@ -421,14 +449,10 @@ struct FmhaFwdKernel
{
batch_offset_v = key_start;
}
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;
}
else
{
batch_offset_bias = key_start;
}
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
@@ -461,7 +485,7 @@ struct FmhaFwdKernel
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
}
@@ -585,7 +609,7 @@ struct FmhaFwdKernel
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto bias_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasBias)
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const BiasDataType* bias_ptr =
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
@@ -654,6 +678,39 @@ struct FmhaFwdKernel
return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
}();
// workaround: i_batch is a structure binding and cannot be captured by a lambda before C++20
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
// data loading, shared by the entire workgroup
// TODO: how to use s_read?
SaccDataType slope =
*(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask)
{
return make_alibi_from_lr_mask<SaccDataType, true>(slope,
kargs.window_size_left,
kargs.window_size_right,
kargs.seqlen_q,
kargs.seqlen_k,
kargs.mask_type);
}
else
{
return Alibi<SaccDataType, true>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
}
}
else
{
return EmptyPositionEncoding<SaccDataType>{};
}
}();
auto o_acc_tile = [&]() {
if constexpr(kDoFp8StaticQuant)
{
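
The position_encoding lambda above reads the per-head slope once per workgroup and, on the exp2-based softmax path, pre-multiplies it by log2(e) so the bias can flow through exp2 together with the scaled logits. As a scalar reference of the canonical ALiBi bias (the tile-wise Alibi and make_alibi_from_lr_mask helpers, and the exact distance convention of AlibiMode::VERTICAL, are defined outside this diff):

// Scalar sketch of a canonical ALiBi-biased logit; the kernel applies the
// equivalent bias per tile through the position_encoding functor.
float alibi_logit(float s, float scale_s, float slope, int row, int col)
{
    int   dist = (row > col) ? (row - col) : (col - row); // |row - col|
    float bias = -slope * static_cast<float>(dist);
    return scale_s * s + bias;
}

// With CK_TILE_FMHA_FWD_FAST_EXP2 the softmax uses exp2 instead of exp, and
//   exp(scale_s*s + slope*d) == exp2(log2e*scale_s*s + (log2e*slope)*d),
// which is why the slope is multiplied by ck_tile::log2e_v<> up front.
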
@@ -672,6 +729,7 @@ struct FmhaFwdKernel
scales{kargs.scale_p}, // p_compute_element_func
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}
@@ -683,6 +741,7 @@ struct FmhaFwdKernel
bias_dram_window,
lse_dram_window,
mask,
position_encoding,
kargs.scale_s,
smem_ptr);
}