From a84009f83b2da43f13c38f5e64168f865165a66c Mon Sep 17 00:00:00 2001 From: danyao12 Date: Mon, 13 May 2024 10:39:44 +0800 Subject: [PATCH] bwd alibi --- example/ck_tile/01_fmha/fmha_bwd.cpp | 112 ++++++++++++++---- example/ck_tile/01_fmha/fmha_bwd.hpp | 11 +- example/ck_tile/01_fmha/generate.py | 14 +-- .../ck_tile/01_fmha/script/smoke_test_bwd.sh | 2 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 82 +++++++++++-- ...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp | 41 +++++-- ...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp | 41 +++++-- ...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp | 41 +++++-- .../block_fmha_bwd_pipeline_problem.hpp | 2 +- 9 files changed, 278 insertions(+), 68 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 37a21b25ef..3215c95f4d 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -41,23 +41,27 @@ auto create_args(int argc, char* argv[]) .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") .insert("h_k", - "0", - "num of head, for k/v, 0 means equal to h\n" + "-1", + "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") .insert("s", "3328", "seqlen_q. if group-mode, means the average value of seqlen_q\n" "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") - .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") .insert("d", "128", "head dim for q, k") - .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)") .insert("iperm", "1", "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") .insert("prec", "fp16", "data type. fp16 or bf16") .insert("mask", @@ -106,7 +110,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - if(nhead_k == 0) + if(nhead_k < 0) nhead_k = nhead; if(nhead % nhead_k != 0) @@ -117,11 +121,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t seqlen_q = arg_parser.get_int("s"); ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k == 0) + if(seqlen_k < 0) seqlen_k = seqlen_q; ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v == 0) + if(hdim_v < 0) hdim_v = hdim_q; if(hdim_q % 2 != 0 || hdim_v % 2 != 0) { @@ -136,14 +140,14 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale == .0f) scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - bool use_bias = arg_parser.get_bool("bias"); + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); bool use_dbias = arg_parser.get_bool("dbias"); float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); - if(use_dbias && !use_bias) + if(use_dbias && bias.type != bias_enum::elementwise_bias) { - std::cerr << "dbias only exists when there is a bias" << std::endl; + std::cerr << "dbias only exists when bias type is elementwise" << std::endl; return false; } @@ -263,12 +267,15 @@ bool run(const ck_tile::ArgParser& arg_parser) get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); ck_tile::HostTensor v_host( get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)); - // use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host - // will not be used for verification at all (but will be copied to device anyway). ck_tile::HostTensor bias_host( - use_bias + bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); ck_tile::HostTensor o_host( get_lengths(o_perm, shape_batch, nhead, shape_seqlen_q, hdim_v)); ck_tile::HostTensor lse_host( @@ -315,6 +322,24 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillTrigValue{}(bias_host); ck_tile::FillTrigValue{}(do_host); } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -331,6 +356,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem dbias_buf(dbias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -354,7 +380,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << std::flush; @@ -363,7 +389,7 @@ bool run(const ck_tile::ArgParser& arg_parser) data_type, mode == mode_enum::group, mask.type, - use_bias, + bias.type, use_dbias, p_drop > 0.0f}; auto fmha_args = [&]() { @@ -409,7 +435,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_bwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), do_buf.GetDeviceBuffer(), @@ -435,7 +462,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_q, stride_k, stride_v, - stride_bias, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead) + : stride_bias, stride_o, stride_randval, stride_do, @@ -556,10 +584,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k - if(use_bias) + if(bias.type == bias_enum::elementwise_bias) { - // clang-format off + // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); + // clang-format off if(i_perm) bias_host_ref.ForEach([&](auto& self, auto i) { self(i) = bias_host(0, 0, i[1] + query_offset, i[2]); }); else @@ -572,6 +601,49 @@ bool run(const ck_tile::ArgParser& arg_parser) reference_batched_elementwise( s_host_ref, bias_host_ref, s_host_ref); } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + AccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + AccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile:: + reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } if(mask.type == mask_enum::no_mask) { diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 67817bbcd3..9aaa3e3f23 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include "bias.hpp" #include template @@ -66,7 +67,7 @@ struct fmha_bwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* bias_ptr; + const void* bias_ptr; // bias or alibi_slope pointer const void* o_ptr; const void* lse_ptr; const void* do_ptr; @@ -92,7 +93,7 @@ struct fmha_bwd_args ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_o; ck_tile::index_t stride_randval; ck_tile::index_t stride_do; @@ -291,7 +292,7 @@ template ; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kPadS = kPadS_; @@ -338,7 +339,7 @@ struct fmha_bwd_traits std::string data_type; bool is_group_mode; mask_enum mask_type; - bool has_bias; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bool has_dbias; bool has_dropout; // TODO: padding check is inside this api diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 1be8620089..1ce50e0547 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -665,7 +665,7 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; @@ -687,7 +687,7 @@ class FmhaBwdDQDKDVApiTrait: bhdq : int # q head_dim bhdv : int # v head_dim mask : str - bias : str # true/false + bias : str dbias : str dropout : str spad : str @@ -756,7 +756,7 @@ class FmhaBwdApiPool: if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): continue inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) @@ -852,7 +852,7 @@ class FmhaBwdDQDKDVKernel: F_skpad = BOOL_MAP[self.F_skpad], F_dpad = BOOL_MAP[self.F_dpad], F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BOOL_MAP[self.F_bias], + F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = BOOL_MAP[self.F_dropout], F_occupancy = self.F_tile.F_occupancy, @@ -874,7 +874,7 @@ class FmhaBwdDQDKDVKernel: mn = mask_name() n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\ f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ - f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" + f"_b{BIAS_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" if mn != '' : n += f'{mn}' return n @@ -928,13 +928,13 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[Fm d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) if d == None: continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): tile = d[hdim_str][0] ppl = d[hdim_str][1] hdim = int(hdim_str) if (mode == "group") and (spad == "f" or skpad == "f"): continue - if (bias == "f" and dbias == "t"): + if ((bias == "no" or bias == "alibi") and dbias == "t"): continue k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index 0a26260df0..e4a38dfce8 100644 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -13,7 +13,7 @@ for prec in "fp16" "bf16" ; do for perm in 0 1 ; do for hdim in 32 64 128 ; do for mode in 0 1 ; do -for bias in 0 1 ; do +for bias in "n" "e" "a"; do for dbias in 0 1 ; do for p_drop in 0.0 0.2; do diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b657646ebe..7773a91b32 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include #include @@ -56,7 +57,7 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; using FmhaMask = ck_tile::remove_cvref_t; @@ -91,7 +92,8 @@ struct FmhaBwdDQDKDVKernel _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + - ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (kHasBias ? "_bias" : "") + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); #undef _SS_ #undef _TS_ @@ -161,6 +163,13 @@ struct FmhaBwdDQDKDVKernel ck_tile::index_t batch_stride_bias = 0; }; + struct FmhaBwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + struct FmhaBwdCommonBiasGradKargs { void* dbias_ptr = nullptr; @@ -212,7 +221,11 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdBatchModeKargs : FmhaBwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -227,7 +240,11 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdGroupModeKargs : FmhaBwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -336,13 +353,18 @@ struct FmhaBwdDQDKDVKernel batch_stride_dk, batch_stride_dv}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasBiasGrad) { @@ -458,12 +480,17 @@ struct FmhaBwdDQDKDVKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasBiasGrad) { kargs.dbias_ptr = dbias_ptr; @@ -537,14 +564,10 @@ struct FmhaBwdDQDKDVKernel batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; batch_offset_dk = key_start * kargs.stride_dk; batch_offset_dv = key_start * kargs.stride_dv; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; } - else - { - batch_offset_bias = key_start; - } if constexpr(kHasBiasGrad) { batch_offset_dbias = query_start * kargs.stride_dbias; @@ -587,7 +610,7 @@ struct FmhaBwdDQDKDVKernel batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } @@ -919,7 +942,7 @@ struct FmhaBwdDQDKDVKernel constexpr auto bias_dram_window_lengths = make_tuple(number{}, number{}); const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { const BiasDataType* bias_ptr = reinterpret_cast(kargs.bias_ptr) + @@ -977,6 +1000,38 @@ struct FmhaBwdDQDKDVKernel } }(); + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? + AccDataType slope = *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + // dropout float rp_undrop = 1; float scale_rp_undrop = 1; @@ -1061,6 +1116,7 @@ struct FmhaBwdDQDKDVKernel dq_dram_window, dbias_dram_window, mask, + position_encoding, kargs.raw_scale, #if CK_TILE_FMHA_FWD_FAST_EXP2 kargs.scale, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp index 4b2c469ca9..4e20c1377c 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, - typename BiasGradDramBlockWindowTmp> + typename BiasGradDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, const QTDramBlockWindowTmp& qt_dram_block_window_tmp, @@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, FmhaMask mask, + PositionEncoding position_encoding, float raw_scale, #if CK_TILE_FMHA_FWD_FAST_EXP2 float scale, @@ -433,13 +436,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR q_block_tile = load_tile(q_dram_window); // global read 1 } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -484,7 +487,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -505,6 +508,28 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR biast_tile); move_tile_window(bias_dram_window, {kM0, 0}); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } else { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -532,7 +557,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -554,7 +580,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp index ce81b3bfd6..dda894da0f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, - typename BiasGradDramBlockWindowTmp> + typename BiasGradDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, const QTDramBlockWindowTmp& qt_dram_block_window_tmp, @@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, FmhaMask mask, + PositionEncoding position_encoding, float raw_scale, #if CK_TILE_FMHA_FWD_FAST_EXP2 float scale, @@ -406,13 +409,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR q_block_tile = load_tile(q_dram_window); // global read 1 } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -457,7 +460,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -478,6 +481,28 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR biast_tile); move_tile_window(bias_dram_window, {kM0, 0}); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } else { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -505,7 +530,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -527,7 +553,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp index 5ffa7f8d50..6a20567753 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -58,7 +59,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -102,7 +103,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS typename LSEDramBlockWindowTmp, typename DDramBlockWindowTmp, typename QGradDramBlockWindowTmp, - typename BiasGradDramBlockWindowTmp> + typename BiasGradDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/, @@ -118,6 +120,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp, FmhaMask mask, + PositionEncoding position_encoding, float raw_scale, #if CK_TILE_FMHA_FWD_FAST_EXP2 float scale, @@ -372,13 +375,13 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS clear_tile(st_acc); // Initialize S^T store_tile(q_lds_window, q_block_tile); // LDS write - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -413,7 +416,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -434,6 +437,28 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS biast_tile); move_tile_window(bias_dram_window, {kM0, 0}); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto q_origin = q_dram_block_window.get_window_origin(); + constexpr auto st_spans = decltype(st_acc)::get_distributed_spans(); + sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + st_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + +#if !CK_TILE_FMHA_FWD_FAST_EXP2 + st_acc(i_j_idx) *= raw_scale; +#else + st_acc(i_j_idx) *= scale; +#endif + position_encoding.update(st_acc(i_j_idx), row, col); + }); + }); + } else { #if !CK_TILE_FMHA_FWD_FAST_EXP2 @@ -461,7 +486,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -483,7 +509,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 5ed41d6264..7b787e9f36 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -55,7 +55,7 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasBias = Traits::kHasBias; + static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;