mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
bwd alibi
This commit is contained in:
@@ -41,23 +41,27 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("b", "2", "batch size")
|
||||
.insert("h", "8", "num of head, for q")
|
||||
.insert("h_k",
|
||||
"0",
|
||||
"num of head, for k/v, 0 means equal to h\n"
|
||||
"-1",
|
||||
"num of head, for k/v, -1 means equal to h\n"
|
||||
"if not equal to h, then this is GQA/MQA case")
|
||||
.insert("s",
|
||||
"3328",
|
||||
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
|
||||
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary")
|
||||
.insert("s_k", "0", "seqlen_k, 0 means equal to s")
|
||||
.insert("s_k", "-1", "seqlen_k, -1 means equal to s")
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "0", "head dim for v, 0 means equal to d")
|
||||
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "add bias or not")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("dbias", "0", "output bias gradient or not")
|
||||
.insert("prec", "fp16", "data type. fp16 or bf16")
|
||||
.insert("mask",
|
||||
@@ -106,7 +110,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::index_t batch = arg_parser.get_int("b");
|
||||
ck_tile::index_t nhead = arg_parser.get_int("h");
|
||||
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
|
||||
if(nhead_k == 0)
|
||||
if(nhead_k < 0)
|
||||
nhead_k = nhead;
|
||||
|
||||
if(nhead % nhead_k != 0)
|
||||
@@ -117,11 +121,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::index_t seqlen_q = arg_parser.get_int("s");
|
||||
ck_tile::index_t seqlen_k = arg_parser.get_int("s_k");
|
||||
if(seqlen_k == 0)
|
||||
if(seqlen_k < 0)
|
||||
seqlen_k = seqlen_q;
|
||||
ck_tile::index_t hdim_q = arg_parser.get_int("d");
|
||||
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
|
||||
if(hdim_v == 0)
|
||||
if(hdim_v < 0)
|
||||
hdim_v = hdim_q;
|
||||
if(hdim_q % 2 != 0 || hdim_v % 2 != 0)
|
||||
{
|
||||
@@ -136,14 +140,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale == .0f)
|
||||
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
bool use_dbias = arg_parser.get_bool("dbias");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
if(use_dbias && !use_bias)
|
||||
if(use_dbias && bias.type != bias_enum::elementwise_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when there is a bias" << std::endl;
|
||||
std::cerr << "dbias only exists when bias type is elementwise" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -263,12 +267,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
ck_tile::HostTensor<VDataType> v_host(
|
||||
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v));
|
||||
// use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
use_bias
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<AccDataType> alibi_slope_host(
|
||||
bias.type == bias_enum::alibi
|
||||
? (bias.rank_info == 0 ? std::array<ck_tile::index_t, 2>{1, nhead}
|
||||
: std::array<ck_tile::index_t, 2>{batch, nhead})
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v));
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
@@ -315,6 +322,24 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillTrigValue<BiasDataType>{}(bias_host);
|
||||
ck_tile::FillTrigValue<OGradDataType>{}(do_host);
|
||||
}
|
||||
if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
|
||||
assert(slopes.size() == nhead);
|
||||
if(bias.rank_info == 0)
|
||||
{
|
||||
// alibi in 1*h
|
||||
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin());
|
||||
}
|
||||
else
|
||||
{
|
||||
// alibi in b*h
|
||||
for(auto i_b = 0; i_b < batch; i_b++)
|
||||
{
|
||||
std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
@@ -331,6 +356,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
|
||||
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q_host.data());
|
||||
k_buf.ToDevice(k_host.data());
|
||||
@@ -354,7 +380,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
|
||||
<< std::flush;
|
||||
|
||||
@@ -363,7 +389,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
use_bias,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f};
|
||||
auto fmha_args = [&]() {
|
||||
@@ -409,7 +435,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
return fmha_bwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
v_buf.GetDeviceBuffer(),
|
||||
bias_buf.GetDeviceBuffer(),
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
do_buf.GetDeviceBuffer(),
|
||||
@@ -435,7 +462,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
stride_q,
|
||||
stride_k,
|
||||
stride_v,
|
||||
stride_bias,
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
stride_o,
|
||||
stride_randval,
|
||||
stride_do,
|
||||
@@ -556,10 +584,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
if(use_bias)
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// clang-format off
|
||||
// elementwise bias
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
// clang-format off
|
||||
if(i_perm)
|
||||
bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); });
|
||||
else
|
||||
@@ -572,6 +601,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
reference_batched_elementwise<AccDataType, BiasDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, bias_host_ref, s_host_ref);
|
||||
}
|
||||
else if(bias.type == bias_enum::alibi)
|
||||
{
|
||||
// alibi construct elementwise bias to verify
|
||||
auto alibi_host = [&]() {
|
||||
if(mask.type != mask_enum::no_mask)
|
||||
{
|
||||
return ck_tile::make_alibi_from_lr_mask<AccDataType, true>(
|
||||
0,
|
||||
mask.left,
|
||||
mask.right,
|
||||
real_seqlen_q,
|
||||
real_seqlen_k,
|
||||
static_cast<ck_tile::GenericAttentionMaskEnum>(mask.type));
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::Alibi<AccDataType, true>{
|
||||
0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL};
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::HostTensor<AccDataType> alibi_bias_host_ref(
|
||||
{nhead, real_seqlen_q, real_seqlen_k});
|
||||
auto i_b_slope = bias.rank_info == 0 ? 0 : wb;
|
||||
for(auto i_h = 0; i_h < nhead; i_h++)
|
||||
{
|
||||
AccDataType current_slope = alibi_slope_host(i_b_slope, i_h);
|
||||
alibi_host.slope = current_slope;
|
||||
for(auto i_r = 0; i_r < real_seqlen_q; i_r++)
|
||||
{
|
||||
for(auto i_c = 0; i_c < real_seqlen_k; i_c++)
|
||||
{
|
||||
AccDataType pixel = 0;
|
||||
alibi_host.update(pixel, i_r, i_c);
|
||||
alibi_bias_host_ref(i_h, i_r, i_c) = pixel;
|
||||
}
|
||||
}
|
||||
}
|
||||
// [nhead, real_seqlen_q, real_seqlen_k]
|
||||
ck_tile::
|
||||
reference_batched_elementwise<AccDataType, AccDataType, AccDataType, AccDataType>(
|
||||
s_host_ref, alibi_bias_host_ref, s_host_ref);
|
||||
}
|
||||
|
||||
if(mask.type == mask_enum::no_mask)
|
||||
{
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
@@ -66,7 +67,7 @@ struct fmha_bwd_args
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* o_ptr;
|
||||
const void* lse_ptr;
|
||||
const void* do_ptr;
|
||||
@@ -92,7 +93,7 @@ struct fmha_bwd_args
|
||||
ck_tile::index_t stride_q;
|
||||
ck_tile::index_t stride_k;
|
||||
ck_tile::index_t stride_v;
|
||||
ck_tile::index_t stride_bias;
|
||||
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
|
||||
ck_tile::index_t stride_o;
|
||||
ck_tile::index_t stride_randval;
|
||||
ck_tile::index_t stride_do;
|
||||
@@ -291,7 +292,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
bool kHasBias_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kHasDropout_,
|
||||
bool kPadS_,
|
||||
@@ -305,7 +306,7 @@ struct fmha_bwd_dq_dk_dv_traits_
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
@@ -338,7 +339,7 @@ struct fmha_bwd_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bool has_bias;
|
||||
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
// TODO: padding check is inside this api
|
||||
|
||||
@@ -665,7 +665,7 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
|
||||
@@ -687,7 +687,7 @@ class FmhaBwdDQDKDVApiTrait:
|
||||
bhdq : int # q head_dim
|
||||
bhdv : int # v head_dim
|
||||
mask : str
|
||||
bias : str # true/false
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad : str
|
||||
@@ -756,7 +756,7 @@ class FmhaBwdApiPool:
|
||||
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
|
||||
|
||||
@@ -852,7 +852,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_bias],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = BOOL_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
@@ -874,7 +874,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
mn = mask_name()
|
||||
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\
|
||||
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\
|
||||
f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
|
||||
f"_b{BIAS_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
|
||||
if mn != '' : n += f'{mn}'
|
||||
return n
|
||||
|
||||
@@ -928,13 +928,13 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[Fm
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
if (bias == "f" and dbias == "t"):
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
|
||||
@@ -13,7 +13,7 @@ for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in 0 1 ; do
|
||||
for bias in "n" "e" "a"; do
|
||||
for dbias in 0 1 ; do
|
||||
for p_drop in 0.0 0.2; do
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -56,7 +57,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = FmhaPipeline::kHasBias;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
@@ -91,7 +92,8 @@ struct FmhaBwdDQDKDVKernel
|
||||
_TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (kHasBias ? "_bias" : "") +
|
||||
("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
|
||||
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
|
||||
(kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" );
|
||||
#undef _SS_
|
||||
#undef _TS_
|
||||
@@ -161,6 +163,13 @@ struct FmhaBwdDQDKDVKernel
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
};
|
||||
|
||||
struct FmhaBwdAlibiKargs
|
||||
{
|
||||
// alibi is batch*nhead*1, no matter in batch/group mode, they are the same
|
||||
const void* alibi_slope_ptr;
|
||||
ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
|
||||
};
|
||||
|
||||
struct FmhaBwdCommonBiasGradKargs
|
||||
{
|
||||
void* dbias_ptr = nullptr;
|
||||
@@ -212,7 +221,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
struct FmhaBwdBatchModeKargs
|
||||
: FmhaBwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaBwdBatchModeBiasKargs, FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaBwdBatchModeBiasKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
|
||||
FmhaBwdAlibiKargs,
|
||||
FmhaBwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
|
||||
@@ -227,7 +240,11 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
struct FmhaBwdGroupModeKargs
|
||||
: FmhaBwdCommonKargs,
|
||||
std::conditional_t<kHasBias, FmhaBwdCommonBiasKargs, FmhaBwdEmptyKargs<0>>,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
|
||||
FmhaBwdCommonBiasKargs,
|
||||
std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
|
||||
FmhaBwdAlibiKargs,
|
||||
FmhaBwdEmptyKargs<0>>>,
|
||||
std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
|
||||
std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
|
||||
@@ -336,13 +353,18 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_stride_dk,
|
||||
batch_stride_dv};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
kargs.batch_stride_bias = batch_stride_bias;
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
kargs.alibi_slope_ptr = bias_ptr;
|
||||
kargs.alibi_slope_stride = stride_bias;
|
||||
}
|
||||
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
@@ -458,12 +480,17 @@ struct FmhaBwdDQDKDVKernel
|
||||
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.stride_bias = stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
kargs.alibi_slope_ptr = bias_ptr;
|
||||
kargs.alibi_slope_stride = stride_bias;
|
||||
}
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
kargs.dbias_ptr = dbias_ptr;
|
||||
@@ -537,14 +564,10 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
batch_offset_dk = key_start * kargs.stride_dk;
|
||||
batch_offset_dv = key_start * kargs.stride_dv;
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_bias = key_start;
|
||||
}
|
||||
if constexpr(kHasBiasGrad)
|
||||
{
|
||||
batch_offset_dbias = query_start * kargs.stride_dbias;
|
||||
@@ -587,7 +610,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
|
||||
batch_offset_dk = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
|
||||
batch_offset_dv = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
@@ -919,7 +942,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
constexpr auto bias_dram_window_lengths =
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
const BiasDataType* bias_ptr =
|
||||
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
|
||||
@@ -977,6 +1000,38 @@ struct FmhaBwdDQDKDVKernel
|
||||
}
|
||||
}();
|
||||
|
||||
// WA i_batch capture structure binding before c++20
|
||||
auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
// data loading, shared by entire wg
|
||||
// TODO: how to use s_read?
|
||||
AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
|
||||
i_batch_ * kargs.alibi_slope_stride + i_nhead_);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
slope *= ck_tile::log2e_v<>;
|
||||
#endif
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
return make_alibi_from_lr_mask<AccDataType, true>(slope,
|
||||
kargs.window_size_left,
|
||||
kargs.window_size_right,
|
||||
kargs.seqlen_q,
|
||||
kargs.seqlen_k,
|
||||
kargs.mask_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
return Alibi<AccDataType, true>{
|
||||
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return EmptyPositionEncoding<AccDataType>{};
|
||||
}
|
||||
}();
|
||||
|
||||
// dropout
|
||||
float rp_undrop = 1;
|
||||
float scale_rp_undrop = 1;
|
||||
@@ -1061,6 +1116,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
dq_dram_window,
|
||||
dbias_dram_window,
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
kargs.scale,
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
@@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
@@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
@@ -433,13 +436,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -484,7 +487,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -505,6 +508,28 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -532,7 +557,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -554,7 +580,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
@@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
|
||||
@@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
@@ -406,13 +409,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
q_block_tile = load_tile(q_dram_window); // global read 1
|
||||
}
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -457,7 +460,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -478,6 +481,28 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -505,7 +530,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -527,7 +553,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
@@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Problem::kHasBias;
|
||||
static constexpr auto BiasEnum = Problem::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
@@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
typename LSEDramBlockWindowTmp,
|
||||
typename DDramBlockWindowTmp,
|
||||
typename QGradDramBlockWindowTmp,
|
||||
typename BiasGradDramBlockWindowTmp>
|
||||
typename BiasGradDramBlockWindowTmp,
|
||||
typename PositionEncoding>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
|
||||
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
|
||||
@@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
|
||||
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
|
||||
FmhaMask mask,
|
||||
PositionEncoding position_encoding,
|
||||
float raw_scale,
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
float scale,
|
||||
@@ -372,13 +375,13 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
clear_tile(st_acc); // Initialize S^T
|
||||
store_tile(q_lds_window, q_block_tile); // LDS write
|
||||
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
}
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(
|
||||
0); // prevent from messing up the order of global loads
|
||||
@@ -413,7 +416,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
}
|
||||
|
||||
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
block_sync_lds();
|
||||
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
|
||||
@@ -434,6 +437,28 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
biast_tile);
|
||||
move_tile_window(bias_dram_window, {kM0, 0});
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
const auto q_origin = q_dram_block_window.get_window_origin();
|
||||
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
|
||||
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
st_acc(i_j_idx) *= raw_scale;
|
||||
#else
|
||||
st_acc(i_j_idx) *= scale;
|
||||
#endif
|
||||
position_encoding.update(st_acc(i_j_idx), row, col);
|
||||
});
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
#if !CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
@@ -461,7 +486,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
const auto lse = load_tile(lse_dram_window);
|
||||
|
||||
static const auto get_validated_lse = [](LSEDataType raw_lse) {
|
||||
if constexpr(kHasBias || FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return raw_lse == -numeric<LSEDataType>::infinity()
|
||||
? type_convert<LSEDataType>(0.f)
|
||||
@@ -483,7 +509,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
|
||||
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
#if CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
if constexpr(kHasBias)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
|
||||
}
|
||||
|
||||
@@ -55,7 +55,7 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kHasBias = Traits::kHasBias;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr bool kHasDropout = Traits::kHasDropout;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
|
||||
Reference in New Issue
Block a user