diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index e31c96caaa..85d25c63d9 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -17,7 +17,7 @@ add_custom_command( set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" -message("adding tile_example ${EXAMPLE_NAME}") +message("adding example ${EXAMPLE_FMHA_FWD}") add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 5a428e4d41..fd5690a795 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -30,27 +30,29 @@ args: -mode kernel mode. 0:batch, 1:group (default:0) -b batch size (default:2) -h num of head, for q (default:8) - -h_k num of head, for k/v, 0 means equal to h (default:0) + -h_k num of head, for k/v, -1 means equal to h (default:-1) if not equal to h, then this is GQA/MQA case -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary - -s_k seqlen_k, 0 means equal to s (default:0) + -s_k seqlen_k, -1 means equal to s (default:-1) -d head dim for q, k (default:128) - -d_v head dim for v, 0 means equal to d (default:0) + -d_v head dim for v, -1 means equal to d (default:-1) -scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0) note when squant=1, this value will be modified by range_q/k - -range_q per-tensor quantization range of q. used if squant=1. (default:2) - -range_k per-tensor quantization range of k. used if squant=1. (default:2) - -range_v per-tensor quantization range of v. used if squant=1. 
(default:2) + -range_q per-tensor quantization range of q. used if squant=1. (default:16) + -range_k per-tensor quantization range of k. used if squant=1. (default:16) + -range_v per-tensor quantization range of v. used if squant=1. (default:16) -range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1) - -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:2) + -range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16) -squant if using static quantization fusion or not. 0: original flow(not prefered) (default:0) 1: apply scale_p and scale_o with respect to P and O. calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o -iperm permute input (default:1) if true, will be b*h*s*d, else b*s*h*d -operm permute output (default:1) - -bias add bias or not (default:0) + -bias n or 0, no bias (default:n) + e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s + a(libi) or 2, alibi with 1*h. a:1, b*h -prec data type. fp16/bf16/fp8/bf8 (default:fp16) -mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0) 't', top-left causal mask, 'b', bottom-r causal mask @@ -59,11 +61,11 @@ args: 'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa 'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa 'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now) - -vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r) -lse 0 not store lse, 1 store lse (default:0) -kname if set to 1 will print kernel name (default:0) -init init method. 0:random int, 1:random float, 2:trig float, 3:quantization (default:1) + -seed random seed used for initializing input tensors. 
0 for non-deterministic seed (default:11939) ``` Example: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case. @@ -85,6 +87,9 @@ If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support prov ### attention bias Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number. +### alibi +alibi is supported + ### lse For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1` diff --git a/example/ck_tile/01_fmha/bias.hpp b/example/ck_tile/01_fmha/bias.hpp new file mode 100644 index 0000000000..f9dc656f63 --- /dev/null +++ b/example/ck_tile/01_fmha/bias.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +// keep sync with BlockAttentionBiasEnum +enum class bias_enum +{ + no_bias = 0, + elementwise_bias = 1, + alibi = 2, +}; + +struct bias_info +{ + bias_enum type; + /* + * simple dispatch logic + * + * if type == elementwise_bias: + * if rank_info == 0: + * bias is 1*1*s*s + * elif rank_info == 1: + * bias is 1*h*s*s + * elif rank_info == 2: + * bias is b*h*s*s + * + * elif type == alibi: + * if rank_info == 0: + * alibi in 1*h + * elif rank_info == 1: + * alibi in b*h + */ + int rank_info; + + void serialize(std::ostream& os) const + { + if(type == bias_enum::no_bias) + os << "n"; + else if(type == bias_enum::elementwise_bias) + { + os << "e"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + else if(type == bias_enum::alibi) + { + os << "alibi"; + if(rank_info != 0) + { + os << "[" << rank_info << "]"; + } + } + } + + static bias_info decode(std::string str) + { + bias_info info{bias_enum::no_bias, 0}; + if(str == "0" || str == "n") + { + info.type = bias_enum::no_bias; + } + else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 || + str.compare(0, 11, "elementwise") == 0) + { + info.type = bias_enum::elementwise_bias; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 || + str.compare(0, 5, "alibi") == 0) + { + info.type = bias_enum::alibi; + auto found_0 = str.find(':'); + if(found_0 != std::string::npos) + { + std::string e = str.substr(found_0 + 1); + info.rank_info = atoi(e.c_str()); + } + } + return info; + } + + friend std::ostream& operator<<(std::ostream& os, const bias_info& bi) + { + bi.serialize(os); + return os; + } +}; diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8ca4ff9337..686633bb2d 100644 --- 
a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -41,16 +41,16 @@ auto create_args(int argc, char* argv[]) .insert("b", "2", "batch size") .insert("h", "8", "num of head, for q") .insert("h_k", - "0", - "num of head, for k/v, 0 means equal to h\n" + "-1", + "num of head, for k/v, -1 means equal to h\n" "if not equal to h, then this is GQA/MQA case") .insert("s", "3328", "seqlen_q. if group-mode, means the average value of seqlen_q\n" "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary") - .insert("s_k", "0", "seqlen_k, 0 means equal to s") + .insert("s_k", "-1", "seqlen_k, -1 means equal to s") .insert("d", "128", "head dim for q, k") - .insert("d_v", "0", "head dim for v, 0 means equal to d") + .insert("d_v", "-1", "head dim for v, -1 means equal to d") .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" @@ -71,7 +71,11 @@ auto create_args(int argc, char* argv[]) "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") .insert("mask", "0", @@ -153,7 +157,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t batch = arg_parser.get_int("b"); ck_tile::index_t nhead = arg_parser.get_int("h"); ck_tile::index_t nhead_k = arg_parser.get_int("h_k"); - if(nhead_k == 0) + if(nhead_k < 0) nhead_k = nhead; if(nhead % nhead_k != 0) @@ -164,11 +168,11 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::index_t seqlen_q = arg_parser.get_int("s"); ck_tile::index_t seqlen_k = arg_parser.get_int("s_k"); - if(seqlen_k == 0) + if(seqlen_k < 0) seqlen_k = seqlen_q; ck_tile::index_t hdim_q = arg_parser.get_int("d"); ck_tile::index_t hdim_v = arg_parser.get_int("d_v"); - if(hdim_v == 0) + if(hdim_v < 0) hdim_v = hdim_q; bool i_perm = arg_parser.get_bool("iperm"); // if true, will be batch * nhead * seqlen * hdim @@ -208,9 +212,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } std::string vlayout = arg_parser.get_str("vlayout"); - bool use_bias = arg_parser.get_bool("bias"); bool lse = arg_parser.get_bool("lse"); + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); mask_info mask = mask_info::decode(arg_parser.get_str("mask"), seqlen_q, seqlen_k); int init_method = arg_parser.get_int("init"); @@ -295,12 +299,18 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::HostTensor v_host( is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)); - // use bias shape = [1, 1, shape_seqlen_q, shape_seqlen_k]. if use_bias=false, the bias_host - // will not be used for verification at all (but will be copied to device anyway). + ck_tile::HostTensor bias_host( - use_bias + bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, shape_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); + + ck_tile::HostTensor alibi_slope_host( + bias.type == bias_enum::alibi + ? (bias.rank_info == 0 ? 
std::array{1, nhead} + : std::array{batch, nhead}) + : std::array{1, 1}); + // self define lse data layout as [shape_batch, nhead, shape_seqlen_q] ck_tile::HostTensor lse_host( lse ? std::array{shape_batch, nhead, shape_seqlen_q} @@ -341,6 +351,24 @@ bool run(const ck_tile::ArgParser& arg_parser) // Assume bias is in [-1.f, 1.f] in original fp32 ck_tile::FillUniformDistribution{-qscale_bias, qscale_bias, seed}(bias_host); } + if(bias.type == bias_enum::alibi) + { + auto slopes = ck_tile::get_alibi_slopes(nhead); + assert(slopes.size() == nhead); + if(bias.rank_info == 0) + { + // alibi in 1*h + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin()); + } + else + { + // alibi in b*h + for(auto i_b = 0; i_b < batch; i_b++) + { + std::copy(slopes.begin(), slopes.end(), alibi_slope_host.begin() + i_b * nhead); + } + } + } ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); @@ -350,6 +378,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); @@ -357,6 +386,7 @@ bool run(const ck_tile::ArgParser& arg_parser) bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); seqstart_k.ToDevice(seqstart_k_host.data()); + alibi_slope_buf.ToDevice(alibi_slope_host.data()); // clang-format off auto layout_str = [&](bool permute){ @@ -372,9 +402,9 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", 
scale_s:" << scale_s - << ", bias:" << use_bias << ", lse:" << lse << ", squant:" << squant - << ", mask:" << mask << ", v:" << vlayout << std::flush; + << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias + << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout + << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, @@ -382,7 +412,7 @@ bool run(const ck_tile::ArgParser& arg_parser) mode == mode_enum::group, is_v_rowmajor, mask.type, - use_bias, + bias.type, lse, squant}; @@ -441,7 +471,8 @@ bool run(const ck_tile::ArgParser& arg_parser) return fmha_fwd_args{q_buf.GetDeviceBuffer(), k_buf.GetDeviceBuffer(), v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), + bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer() + : bias_buf.GetDeviceBuffer(), lse_buf.GetDeviceBuffer(), o_buf.GetDeviceBuffer(), seqstart_q.GetDeviceBuffer(), @@ -461,7 +492,8 @@ bool run(const ck_tile::ArgParser& arg_parser) stride_q, stride_k, stride_v, - stride_bias, + bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 
0 : nhead) + : stride_bias, stride_o, nhead_stride_q, nhead_stride_k, @@ -564,8 +596,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale_s)); - if(use_bias) + if(bias.type == bias_enum::elementwise_bias) { + // elementwise bias ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); // clang-format off if(i_perm) @@ -582,6 +615,51 @@ bool run(const ck_tile::ArgParser& arg_parser) SMPLComputeDataType>( s_host_ref, bias_host_ref, s_host_ref); } + else if(bias.type == bias_enum::alibi) + { + // alibi construct elementwise bias to verify + auto alibi_host = [&]() { + if(mask.type != mask_enum::no_mask) + { + return ck_tile::make_alibi_from_lr_mask( + 0, + mask.left, + mask.right, + real_seqlen_q, + real_seqlen_k, + static_cast(mask.type)); + } + else + { + return ck_tile::Alibi{ + 0, real_seqlen_q, real_seqlen_k, ck_tile::AlibiMode::VERTICAL}; + } + }(); + + ck_tile::HostTensor alibi_bias_host_ref( + {nhead, real_seqlen_q, real_seqlen_k}); + auto i_b_slope = bias.rank_info == 0 ? 
0 : wb; + for(auto i_h = 0; i_h < nhead; i_h++) + { + SaccDataType current_slope = alibi_slope_host(i_b_slope, i_h); + alibi_host.slope = current_slope; + for(auto i_r = 0; i_r < real_seqlen_q; i_r++) + { + for(auto i_c = 0; i_c < real_seqlen_k; i_c++) + { + SaccDataType pixel = 0; + alibi_host.update(pixel, i_r, i_c); + alibi_bias_host_ref(i_h, i_r, i_c) = pixel; + } + } + } + // [nhead, real_seqlen_q, real_seqlen_k] + ck_tile::reference_batched_elementwise( + s_host_ref, alibi_bias_host_ref, s_host_ref); + } if(mask.type == mask_enum::no_mask) { diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 9a82ab6b79..fb3907fec1 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include "bias.hpp" #include template @@ -86,7 +87,7 @@ struct fmha_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* bias_ptr; + const void* bias_ptr; // bias or alibi_slope pointer void* lse_ptr; void* o_ptr; const void* seqstart_q_ptr; @@ -106,7 +107,7 @@ struct fmha_fwd_args ck_tile::index_t stride_q; ck_tile::index_t stride_k; ck_tile::index_t stride_v; - ck_tile::index_t stride_bias; + ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_o; ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_k; @@ -219,7 +220,7 @@ template ; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kPadS = kPadS_; @@ -261,7 +262,7 @@ struct fmha_fwd_traits bool is_group_mode; bool is_v_rowmajor; mask_enum mask_type; - bool has_bias; + bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. 
sync with BlockAttentionBiasEnum bool has_lse; bool do_fp8_static_quant; // TODO: padding check is inside this api diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 56d699e5fe..51fecd07b5 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -40,6 +40,19 @@ MASK_MAP = { "generic" : "FmhaMasks::GenericMask" } +BIAS_MAP = { + "no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS", + "bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS", + "alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI" +} + +# TODO: this is ugly +BIAS_CHECK_MAP = { + "no" : "bias_enum::no_bias", + "bias" : "bias_enum::elementwise_bias", + "alibi" : "bias_enum::alibi" +} + MODE_MAP = { "batch" : "false", "group" : "true" @@ -173,7 +186,7 @@ MASK_SIMPLIFIED_CHECK_MAP = { "s_mask" : "t.mask_type != mask_enum::no_mask", } -FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && +FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}, {F_vlayout}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; return fmha_fwd_(s, a); @@ -213,7 +226,7 @@ class FmhaFwdApiTrait: bk0blen : int vlayout : str mask : str - bias : str # true/false + bias : str # lse : str # squant : str # spad : str @@ -241,8 +254,8 @@ class FmhaFwdApiTrait: def skcheck(self) -> str: if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate 
spad/skpad == true if self.pipeline_tag == 'qr_async': - if self.skpad == 't' : return f'a.seqlen_k % {self.bn0} != 0' - else : return f'a.seqlen_k % {self.bn0} == 0' + if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0' + else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0' elif self.pipeline_tag in ['qr', 'qr_fp8']: if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly) else : return f'a.seqlen_k % {self.bn0} == 0' @@ -297,7 +310,7 @@ class FmhaFwdPipeline: pn = pad_name() n = f'{self.tag}_v{self.F_vlayout[0]}' if pn != '' : n += f'_{pn}' - if self.F_bias == 't' : n += '_bias' + if self.F_bias != 'no' : n += f'_{self.F_bias}' if self.F_mask[0:2] == 's_': if self.F_mask == 's_mask': n += f'_mask' else: @@ -332,7 +345,8 @@ class FmhaFwdApiPool: if_k = 'if' if k == 0 else 'else if' inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_mask=get_mask_map(self.mask_impl)[trait.mask], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_lse=BOOL_MAP[trait.lse], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], + F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad], F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0blen=trait.bk0blen, @@ -400,7 +414,7 @@ class FmhaFwdKernel: F_skpad = BOOL_MAP[self.F_pipeline.F_skpad], F_dpad = BOOL_MAP[self.F_pipeline.F_dpad], F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad], - F_bias = BOOL_MAP[self.F_pipeline.F_bias], + F_bias 
= BIAS_MAP[self.F_pipeline.F_bias], F_lse = BOOL_MAP[self.F_pipeline.F_lse], F_squant = BOOL_MAP[self.F_pipeline.F_squant], F_occupancy = self.F_tile.F_occupancy, @@ -454,7 +468,9 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ } elif dtype == 'fp8' or dtype == 'bf8': return { - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1) + '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1), + '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1), + '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1) } else: return None @@ -472,7 +488,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"]): + for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): if hdim == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask)) @@ -490,7 +506,7 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim elif dtype in ['fp8', 'bf8']: # no need lse kernels - for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), ["t", "f"]): + for mask, bias in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', squant, mask)) else: assert False diff --git a/example/ck_tile/01_fmha/mask.hpp b/example/ck_tile/01_fmha/mask.hpp index 56fc8b8b1d..c77b700b16 100644 --- a/example/ck_tile/01_fmha/mask.hpp +++ b/example/ck_tile/01_fmha/mask.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced 
Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -149,11 +149,9 @@ struct mask_info return tmp; } - friend std::ostream& operator<<(std::ostream& os, const mask_info& mi); + friend std::ostream& operator<<(std::ostream& os, const mask_info& mi) + { + mi.serialize(os); + return os; + } }; - -inline std::ostream& operator<<(std::ostream& os, const mask_info& mi) -{ - mi.serialize(os); - return os; -} diff --git a/example/ck_tile/01_fmha/script/smoke_test.sh b/example/ck_tile/01_fmha/script/smoke_test.sh index 4dd5c2ae12..2c4bb562a3 100755 --- a/example/ck_tile/01_fmha/script/smoke_test.sh +++ b/example/ck_tile/01_fmha/script/smoke_test.sh @@ -17,7 +17,7 @@ for perm in 0 1 ; do for vlayout in "r" "c" ; do for hdim in 32 64 128 256 ; do for lse in 0 1 ; do -for bias in 0 1 ; do +for bias in "n" "e" "a"; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS @@ -27,6 +27,7 @@ $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$b $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -kname=$KNAME $COMMON_ARGS +$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -kname=$KNAME 
$COMMON_ARGS done done @@ -37,9 +38,11 @@ done done for perm in 0 1 ; do -for bias in 0 1 ; do +for bias in "n" "e" "a" ; do for b in 1 2 ; do +for hdim in 64 128 256 ; do $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS done done done +done diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index d915df6e4c..82b6953b5e 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -154,3 +154,8 @@ #ifndef CK_TILE_USE_SUBDWORD_TILE_CAST #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif + +// TODO: better solve this inside compiler +#ifndef CK_TILE_FMHA_FWD_FAST_EXP2 +#define CK_TILE_FMHA_FWD_FAST_EXP2 0 +#endif diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 72ec607b42..d4984363da 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -536,4 +536,15 @@ float log(float x) { return __logf(x); }; CK_TILE_HOST float log(float x) { return std::logf(x); }; +CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + // TODO: this is hacky, we use u16 + return __builtin_amdgcn_sad_u16(x, y, acc); +} + +CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) +{ + return (x > y ? 
(x - y) : (y - x)) + acc; +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index c567e63ddf..1122bf87b7 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -3,7 +3,9 @@ #pragma once +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_masking.hpp" +#include "ck_tile/ops/fmha/block/block_position_encoding.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" diff --git a/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp new file mode 100644 index 0000000000..e5be21e048 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_attention_bias_enum.hpp @@ -0,0 +1,37 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +namespace ck_tile { + +// This class is used for codegen pattern matching +enum class BlockAttentionBiasEnum +{ + NO_BIAS = 0, + ELEMENTWISE_BIAS = 1, // attention bias, each elements add to the result of Q*K(after scale) + ALIBI = 2, // bias computed with position encoding, applied after scale +}; + +template +struct BlockAttentionBiasEnumToStr; + +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = ""; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "bias"; +}; +template <> +struct BlockAttentionBiasEnumToStr +{ + static constexpr const char* name = "alibi"; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/block/block_position_encoding.hpp b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp new file mode 100644 index 0000000000..9c6c353908 --- /dev/null +++ b/include/ck_tile/ops/fmha/block/block_position_encoding.hpp @@ -0,0 +1,189 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" +#include +#include + +namespace ck_tile { + +enum struct PositionEncodingEnum +{ + NO = 0, + ALIBI = 1, +}; + +/* +VERTICAL: + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + [0] 1 2 3 4 5 + +TOP_LEFT: + [0] 1 2 3 4 5 + 1 [0] 1 2 3 4 + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + +FROM_BOTTOM_RIGHT: + 2 1 [0] 1 2 3 + 3 2 1 [0] 1 2 + 4 3 2 1 [0] 1 + 5 4 3 2 1 [0] +*/ + +enum struct AlibiMode +{ + VERTICAL = 0, + FROM_TOP_LEFT = 1, // keep sync with mask enum + FROM_BOTTOM_RIGHT = 2, +}; + +template +struct Alibi +{ + // RowMajor here means if pixel within the same thread are along the row, or col + // this may impact the performance of update(), while the result are the same. + // e.g. 
fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false + CK_TILE_HOST_DEVICE Alibi(DataType slope_, + index_t y_total_, + index_t x_total_, + AlibiMode mode_ = AlibiMode::VERTICAL) + { + slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_; + + shift_left_up = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + }(); + shift_right_down = [&]() { + if(RowMajor) + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0; + } + else + { + return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0; + } + }(); + mode = mode_; + } + + CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx) + { + if constexpr(RowMajor) + { + // at least 3 instructions per row + index_t current_zero_point = + mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down; + + // for every threads, most of the pixels are along the row, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(col_idx + shift_left_up), + 0)); + pixel += slope * position; + } + else + { + // at least 3 instructions per col; + index_t current_zero_point = mode == AlibiMode::VERTICAL + ? row_idx + col_idx + shift_right_down + : col_idx + shift_right_down; + + // for every threads, most of the pixels are along the col, below operation should be + // the main hot spot. + auto position = type_convert(sad(bit_cast(current_zero_point), + bit_cast(row_idx + shift_left_up), + 0)); + pixel += slope * position; + } + } + + DataType slope; // float?
+ index_t shift_left_up; // always positive + index_t shift_right_down; // always positive + AlibiMode mode; +}; + +template +struct EmptyPositionEncoding +{ + CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/) + { + } +}; + +// +// can convert from the FA style left/right to our generic coordinate +// if left_size < 0 && right_size == 0, it is normal causal mask +// local is left_size >=0 or right_size >=0 +template +CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, + index_t window_left_size, + index_t window_right_size, + index_t y_total, + index_t x_total, + GenericAttentionMaskEnum mask_enum) +{ + // assume mask_enum will never be NO_MASK, since if we do not have mask, it's + // totally OK to use constexpr + bool is_causal = window_left_size < 0 && window_right_size == 0; + AlibiMode alibi_mode = + is_causal ? AlibiMode::VERTICAL + : static_cast(mask_enum) /*either top-left or bottom-right*/; + return Alibi{slope, y_total, x_total, alibi_mode}; +} + +// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742 +// Do we need a device version?
+template +CK_TILE_HOST std::vector get_alibi_slopes(ck_tile::index_t nheads) +{ + auto get_slopes_power_of_2 = [](ck_tile::index_t n) { + float start = std::powf( + static_cast(2), + -std::powf(static_cast(2), -static_cast((integer_log2_floor(n) - 3)))); + + std::vector rtn; + for(auto i = 0; i < n; i++) + { + rtn.push_back(static_cast(start * std::powf(start, i))); + } + return rtn; + }; + if(is_power_of_two_integer(nheads)) + { + // power of 2 calculation + return get_slopes_power_of_2(nheads); + } + else + { + ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads); + auto v0 = get_slopes_power_of_2(closest_power_of_2); + auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2); + auto v1_sliced = [&](auto vec, ck_tile::index_t rem) { + std::vector sliced; + for(ck_tile::index_t i = 0; i < static_cast(vec.size()); i++) + { + if(i % 2 == 0) + sliced.push_back(vec[i]); + } + std::vector sliced_2(sliced.begin(), sliced.begin() + rem); + return sliced_2; + }(v1, nheads - closest_power_of_2); + v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end()); + return v0; + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0732fd2ce2..10ce7395ad 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include #include @@ -33,6 +34,7 @@ struct FmhaFwdKernel using BiasDataType = ck_tile::remove_cvref_t; using LSEDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; + using SaccDataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; @@ -41,7 +43,7 @@ struct FmhaFwdKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static 
constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE; static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant; using FmhaMask = ck_tile::remove_cvref_t; @@ -81,7 +83,8 @@ struct FmhaFwdKernel "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + "v" + (std::is_same_v ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + - (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kDoFp8StaticQuant ? 
"_squant" : "" ); #undef _SS_ #undef _TS_ // clang-format on @@ -136,6 +139,13 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_bias = 0; }; + struct FmhaFwdAlibiKargs + { + // alibi is batch*nhead*1, no matter in batch/group mode, they are the same + const void* alibi_slope_ptr; + ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope + }; + struct FmhaFwdMaskKargs { // ck_tile::index_t window_size_left, window_size_right; @@ -162,7 +172,11 @@ struct FmhaFwdKernel struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -175,7 +189,11 @@ struct FmhaFwdKernel struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs, - std::conditional_t>, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -255,13 +273,18 @@ struct FmhaFwdKernel batch_stride_v, batch_stride_o}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; kargs.batch_stride_bias = batch_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -345,12 +368,17 @@ struct FmhaFwdKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; kargs.nhead_stride_bias = nhead_stride_bias; } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + kargs.alibi_slope_ptr = bias_ptr; + kargs.alibi_slope_stride = stride_bias; + } if constexpr(kHasMask) { kargs.window_size_left = window_size_left; @@ -421,14 +449,10 @@ struct FmhaFwdKernel { 
batch_offset_v = key_start; } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; } - else - { - batch_offset_bias = key_start; - } if constexpr(kStoreLSE) { batch_offset_lse = query_start; @@ -461,7 +485,7 @@ struct FmhaFwdKernel batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } @@ -585,7 +609,7 @@ struct FmhaFwdKernel const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { constexpr auto bias_dram_window_lengths = make_tuple(number{}, number{}); - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { const BiasDataType* bias_ptr = reinterpret_cast(kargs.bias_ptr) + @@ -654,6 +678,39 @@ struct FmhaFwdKernel return FmhaMask{kargs.seqlen_q, kargs.seqlen_k}; }(); + // WA i_batch capture structure binding before c++20 + auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() { + if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + // data loading, shared by entire wg + // TODO: how to use s_read? 
+ SaccDataType slope = + *(reinterpret_cast(kargs.alibi_slope_ptr) + + i_batch_ * kargs.alibi_slope_stride + i_nhead_); +#if CK_TILE_FMHA_FWD_FAST_EXP2 + slope *= ck_tile::log2e_v<>; +#endif + if constexpr(kHasMask) + { + return make_alibi_from_lr_mask(slope, + kargs.window_size_left, + kargs.window_size_right, + kargs.seqlen_q, + kargs.seqlen_k, + kargs.mask_type); + } + else + { + return Alibi{ + slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL}; + } + } + else + { + return EmptyPositionEncoding{}; + } + }(); + auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { @@ -672,6 +729,7 @@ struct FmhaFwdKernel scales{kargs.scale_p}, // p_compute_element_func composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func mask, + position_encoding, kargs.scale_s, smem_ptr); } @@ -683,6 +741,7 @@ struct FmhaFwdKernel bias_dram_window, lse_dram_window, mask, + position_encoding, kargs.scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp index 5500174086..cf70dff63f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp @@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum QSKSVS, }; +template +struct BlockFmhaPipelineEnumToStr; + +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qr_async"; +}; +template <> +struct BlockFmhaPipelineEnumToStr +{ + static constexpr const char* name = "qs"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp index 9d27b2df68..159fb40743 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp @@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasBias = Traits::kHasBias; + static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kStoreLSE = Traits::kStoreLSE; static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e239bb916..60650761d8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS typename LSEElementFunction, typename SAccElementFunction, typename 
PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + 
tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename 
BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 0573b50d04..8a19deb02a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x) static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x) - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync { if constexpr(kK0BlockLength <= 32) { - if 
constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS && + FmhaMask::IsMasking) return 1; else return 2; } else if constexpr(kK0BlockLength <= 64) { - if constexpr(kPadSeqLenK && kHasBias) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 2; else return 3; } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kPadSeqLenK && kHasBias) + if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); - // check early exit if masked and no work to do. 
- if constexpr(FmhaMask::IsMasking) + // check early exit + if constexpr(FmhaMask::IsMasking || kPadSeqLenK) { if(num_total_loop <= 0) { @@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync __builtin_amdgcn_sched_barrier(1); // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + /// consideration. alibi does not have this problem + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]); } @@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } 
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp index 0e59ee6fe0..80f40f8154 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp" @@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; // last dimension vector length used to create tensor view(and decide buffer_load vector length) @@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/, // not supported FmhaMask mask, + 
PositionEncoding /*position_encoding*/, float scale_s, float descale_qk, float descale_sv, @@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { tile_elementwise_inout( [&](auto& x, const auto& y) { @@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? 
type_convert(0.f) @@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8 constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp index 677c05769c..e12e767069 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" namespace ck_tile { @@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kStoreLSE = Problem::kStoreLSE; static constexpr index_t kBlockPerCu = []() { @@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS } else if constexpr(kK0BlockLength <= 128) { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return 1; else return 2; @@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS typename LSEElementFunction, typename 
SAccElementFunction, typename PComputeElementFunction, - typename OAccElementFunction> + typename OAccElementFunction, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const QElementFunction& q_element_func, @@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS const PComputeElementFunction& p_compute_element_func, const OAccElementFunction& o_acc_element_func, FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS k_block_tile = load_tile(k_dram_window); } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS } // STAGE 2, scale_s, add bias, mask, softmax - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); @@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS s_acc, bias_tile); } + else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { + sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) { + const auto tile_idx = get_x_indices_from_distributed_indices( + s_acc.get_tile_distribution(), make_tuple(idx0, idx1)); + + const auto row = 
q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + constexpr auto i_j_idx = make_tuple(idx0, idx1); + + s_acc(i_j_idx) *= scale_s; + position_encoding.update(s_acc(i_j_idx), row, col); + }); + }); + } else { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); @@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS static const auto get_validated_m = [](SMPLComputeDataType raw_m) { /// NOTICE: bias might be materialized mask including -inf values, need /// consideration - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_m == -numeric::infinity() ? type_convert(0.f) @@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx])); } @@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 const auto tmp = [&]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { return exp2(m_old[i_idx] - get_validated_m(m[i_idx])); } @@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) { constexpr auto i_idx = make_tuple(idx0); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + BiasEnum == BlockAttentionBiasEnum::ALIBI) { lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]); } @@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, 
typename BiasDramBlockWindowTmp, - typename LSEDramBlockWindowTmp> + typename LSEDramBlockWindowTmp, + typename PositionEncoding> CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile @@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile FmhaMask mask, + PositionEncoding position_encoding, float scale_s, void* smem_ptr) const { @@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS identity{}, identity{}, mask, + position_encoding, scale_s, smem_ptr); } diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index 137f4ddd81..6cb6449f16 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" namespace ck_tile { @@ -11,7 +12,7 @@ template @@ -21,7 +22,7 @@ struct TileFmhaTraits static constexpr bool kPadSeqLenK = kPadSeqLenK_; static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; static constexpr bool kPadHeadDimV = kPadHeadDimV_; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kStoreLSE = kStoreLSE_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr index_t kBlockPerCu = kBlockPerCu_; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 33aa10df72..25c63ac7fe 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -181,3 +181,4 @@ add_subdirectory(wrapper) if(GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() +add_subdirectory(position_embedding) diff --git a/test/position_embedding/CMakeLists.txt b/test/position_embedding/CMakeLists.txt new file mode 
100644 index 0000000000..e7a939bebb --- /dev/null +++ b/test/position_embedding/CMakeLists.txt @@ -0,0 +1 @@ +add_test_executable(test_position_embedding position_embedding.cpp) diff --git a/test/position_embedding/position_embedding.cpp b/test/position_embedding/position_embedding.cpp new file mode 100644 index 0000000000..e295ec454a --- /dev/null +++ b/test/position_embedding/position_embedding.cpp @@ -0,0 +1,215 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha.hpp" + +#ifndef TEST_ALIBI_VERBOSE +#define TEST_ALIBI_VERBOSE 0 +#endif + +template +struct attention_score +{ + ck_tile::index_t rows, cols; + std::vector pixels; + + attention_score(ck_tile::index_t rows_, + ck_tile::index_t cols_, + DataType init_v_ = static_cast(0)) + : rows(rows_), cols(cols_), pixels(rows_ * cols_, init_v_) + { + } + + auto& operator()(ck_tile::index_t i_row, ck_tile::index_t i_col) + { + return pixels[i_row * cols + i_col]; + } + + void print() + { + for(auto i_row = 0; i_row < rows; i_row++) + { + for(auto i_col = 0; i_col < cols; i_col++) + { + std::cout << pixels[i_row * cols + i_col] << " "; + } + std::cout << std::endl; + } + } +}; + +template +void alibi_traverse_with_slope(attention_score& score, + DataType slope, + ck_tile::AlibiMode mode = ck_tile::AlibiMode::VERTICAL) +{ + using Alibi = ck_tile::Alibi; + auto alibi = Alibi{slope, score.rows, score.cols, mode}; + + for(ck_tile::index_t i_row = 0; i_row < score.rows; i_row++) + { + for(ck_tile::index_t i_col = 0; i_col < score.cols; i_col++) + { + alibi.update(score(i_row, i_col), i_row, i_col); + } + } +} + +std::string alibi_mode_to_str(ck_tile::AlibiMode mode) +{ + if(mode == ck_tile::AlibiMode::VERTICAL) + return std::string("alibi_verti"); + else if(mode == ck_tile::AlibiMode::FROM_TOP_LEFT) + return std::string("alibi_top-l"); + else if(mode == 
ck_tile::AlibiMode::FROM_BOTTOM_RIGHT) + return std::string("alibi_bot-r"); + return ""; +} + +template +bool test_alibi_traverse_with_slope(ck_tile::index_t rows, + ck_tile::index_t cols, + DataType slope, + ck_tile::AlibiMode mode, + const std::vector& expected) +{ + attention_score score{rows, cols}; + alibi_traverse_with_slope(score, slope, mode); + + bool is_match = std::equal(score.pixels.begin(), score.pixels.end(), expected.begin()); +#if TEST_ALIBI_VERBOSE + std::cout << "---------" << alibi_mode_to_str(mode) << ", " << rows << "x" << cols << "(" + << (RowMajor ? "row_major" : "col_major") << ")" + << (is_match ? ", valie:y" : ", valid:n") << std::endl; + score.print(); +#endif + return is_match; +} + +template +bool test_alibi_slope_generation(ck_tile::index_t nheads, const std::vector& expected) +{ + auto slopes = ck_tile::get_alibi_slopes(nheads); + + bool is_match = std::equal(slopes.begin(), + slopes.end(), + expected.begin(), + expected.end(), + [](const DataType& lhs, const DataType& rhs) { + constexpr float rtol = 1e-6; + auto error = std::abs(lhs - rhs); + return error < rtol * std::abs(rhs); + }); +#if TEST_ALIBI_VERBOSE + std::cout << "-------------------- slopes " << nheads << ", " << (is_match ? 
"y" : "n") + << std::endl; + for(ck_tile::index_t i = 0; i < nheads; i++) + { + std::cout << slopes[i] << " "; + } + std::cout << std::endl; +#endif + return is_match; +} + +int main() +{ + using dtype = int32_t; + dtype slope = static_cast(1); + + bool rtn = true; + + // clang-format off + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, + 1, 0, 1, 2, 3, 4, + 2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0, + 4, 3, 2, 1, + 5, 4, 3, 2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2, + 4, 3, 2, 1, 0, 1, + 5, 4, 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5, + 0, 1, 2, 3, 4, 5}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5, + 1, 0, 1, 2, 3, 4, + 2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0, + 4, 3, 2, 1, + 5, 4, 3, 2}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, 
ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3, + 3, 2, 1, 0, 1, 2, + 4, 3, 2, 1, 0, 1, + 5, 4, 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5, + 1, 2, 3, 4, + 0, 1, 2, 3, + 1, 0, 1, 2, + 2, 1, 0, 1, + 3, 2, 1, 0}); + + rtn &= test_alibi_traverse_with_slope(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2, + 1, 0, 1, + 2, 1, 0}); + + rtn &= test_alibi_slope_generation(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625}); + rtn &= test_alibi_slope_generation(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692, + 0.12500000000000006, 0.08838834764831849, 0.06250000000000004, 0.044194173824159244, + 0.03125000000000002, 0.022097086912079626, 0.01562500000000001, 0.011048543456039816, + 0.007812500000000007, 0.005524271728019908, 0.003906250000000004}); + rtn &= test_alibi_slope_generation(1, {0.00390625}); + rtn &= test_alibi_slope_generation(5, {0.25, 0.0625, 0.015625, 0.00390625, 0.5}); + rtn &= test_alibi_slope_generation(6, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125}); + rtn &= test_alibi_slope_generation(7, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125, 0.03125}); + rtn &= test_alibi_slope_generation(9, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.7071067811865476}); + // clang-format on + return rtn ? 0 : -1; +}