diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 312ea26e11..04d9cc88d7 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -66,11 +66,13 @@ def get_mask_check_map(mask: str): QSCALE_MAP = { "no": "ck_tile::BlockAttentionQuantScaleEnum::NO_SCALE", "pertensor": "ck_tile::BlockAttentionQuantScaleEnum::PERTENSOR", + "blockscale": "ck_tile::BlockAttentionQuantScaleEnum::BLOCKSCALE", } QSCALE_CHECK_MAP = { "no": "quant_scale_enum::no_scale", "pertensor": "quant_scale_enum::pertensor", + "blockscale": "quant_scale_enum::blockscale", } BIAS_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 902bae20e1..009dd4b9c8 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -343,7 +343,7 @@ class FmhaFwdPipeline: F_bias: str # true/false F_lse: str # F_dropout: str # - F_qscale: str # no/pertensor + F_qscale: str # no/pertensor/blockscale F_mask: str # value from MASK_MAP F_skip: str # true/false F_trload: str # true/false @@ -739,7 +739,7 @@ class KernelComponentFactoryGfx9: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( ["f"], - ["no", "pertensor"], + ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"], ): @@ -830,7 +830,7 @@ class KernelComponentFactoryGfx12: elif dtype in ["fp8", "fp8bf16", "fp8fp32"]: # no need lse/dropout kernels for logits, qscale, mask, bias in itertools.product( - ["f"], ["no", "pertensor"], get_mask_map(mask_impl).keys(), ["no"] + ["f"], ["no", "pertensor", "blockscale"], get_mask_map(mask_impl).keys(), ["no"] ): pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, "f", "f", qscale, mask, "f", "f")) # fmt: skip diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index b628fa1d87..16cb1a0566 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -256,6 +256,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -263,6 +266,9 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; @@ -274,6 +280,9 @@ struct fmha_fwd_args std::variant, std::pair> drop_seed_offset; + + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; }; struct fmha_fwd_pagedkv_args @@ -604,6 +613,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.stride_bias, args.stride_randval, args.stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.nhead_stride_q, args.nhead_stride_k, args.nhead_stride_v, @@ -618,6 +630,8 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_m, + args.block_scale_n, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } @@ -654,6 +668,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_descale, + args.nhead_stride_k_descale, + args.nhead_stride_v_descale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -661,12 +678,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_descale, + args.batch_stride_k_descale, + args.batch_stride_v_descale, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_m, + args.block_scale_n, args.cu_seqlen_q_ptr, args.cu_seqlen_k_ptr); } diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 6af535a70f..c1e8ac6f4e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -187,6 +187,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::stream_config& stream_config, std::optional json = std::nullopt) { + constexpr ck_tile::index_t block_scale_m_ = 128; + constexpr ck_tile::index_t block_scale_n_ = 128; + const std::string data_type = []() { if constexpr(std::is_same_v) return "fp32"; @@ -448,7 +451,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - auto max_seqlen_k = std::numeric_limits::min(); + size_t num_block_scale_q = 0; + size_t num_block_scale_k = 0; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -464,6 +469,8 @@ fwd_result fmha_fwd_run(mode_enum mode, { max_seqlen_k = real_seqlen_k; } + num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_); + num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -525,6 +532,13 @@ fwd_result fmha_fwd_run(mode_enum mode, ? seqstart_k_with_padding_host.back() : seqstart_k_host.back())); + const ck_tile::index_t num_block_scale_m = + (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_) + : num_block_scale_q; + const ck_tile::index_t num_block_scale_n = + (mode == mode_enum::batch) ? ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_) + : num_block_scale_k; + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); ck_tile::HostTensor k_host( @@ -575,9 +589,18 @@ fwd_result fmha_fwd_run(mode_enum mode, : std::array{1, 1, 1, 1, 1}); // TODO - change the tensor length for different quant scale - ck_tile::HostTensor q_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor k_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); - ck_tile::HostTensor v_descale_host(get_lengths(i_perm, 1, 1, 1, 1)); + ck_tile::HostTensor q_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_m} + : std::array{1, 1, 1}); + ck_tile::HostTensor k_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_n} + : std::array{1, 1, 1}); + ck_tile::HostTensor v_descale_host( + qscale.type == quant_scale_enum::blockscale + ? std::array{shape_batch, nhead, num_block_scale_n} + : std::array{1, 1, 1}); // batch mode of lse data layout is [batch, nhead, seqlen_q] // group mode of lse data layout is [nhead, total_seqlen_q] @@ -692,6 +715,12 @@ fwd_result fmha_fwd_run(mode_enum mode, k_descale_host(0) = qkv_max / k_dtype_max; v_descale_host(0) = qkv_max / v_dtype_max; } + else if(qscale.type == quant_scale_enum::blockscale) + { + ck_tile::FillUniformDistribution{0.015f, 0.02f, next_seed()}(q_descale_host); + ck_tile::FillUniformDistribution{0.015f, 0.02f, next_seed()}(k_descale_host); + ck_tile::FillUniformDistribution{0.015f, 0.02f, next_seed()}(v_descale_host); + } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); @@ -941,11 +970,14 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); const ck_tile::index_t nhead_stride_bias = (i_perm ? 0 * shape_seqlen_q * max_seqlen_k : 0 * max_seqlen_k); - const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); - const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; - const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); - const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); - const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k); + const ck_tile::index_t nhead_stride_lse = shape_seqlen_q; + const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q); + const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_descale = num_block_scale_m; + const ck_tile::index_t nhead_stride_k_descale = num_block_scale_n; + const ck_tile::index_t nhead_stride_v_descale = num_block_scale_n; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -963,6 +995,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_descale = num_block_scale_m * nhead; + const ck_tile::index_t batch_stride_k_descale = num_block_scale_n * nhead_k; + const ck_tile::index_t batch_stride_v_descale = num_block_scale_n * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1046,9 +1081,32 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { - args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); - args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); - args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + if(qscale.type == quant_scale_enum::blockscale) + { + args.q_descale_ptr = + reinterpret_cast(q_descale_buf.GetDeviceBuffer()); + args.k_descale_ptr = + reinterpret_cast(k_descale_buf.GetDeviceBuffer()); + args.v_descale_ptr = + reinterpret_cast(v_descale_buf.GetDeviceBuffer()); + + args.nhead_stride_q_descale = nhead_stride_q_descale; + args.nhead_stride_k_descale = nhead_stride_k_descale; + args.nhead_stride_v_descale = nhead_stride_v_descale; + + args.batch_stride_q_descale = batch_stride_q_descale; + args.batch_stride_k_descale = batch_stride_k_descale; + args.batch_stride_v_descale = batch_stride_v_descale; + + args.block_scale_m = block_scale_m_; + args.block_scale_n = block_scale_n_; + } + else + { + args.q_descale_ptr = q_descale_buf.GetDeviceBuffer(); + args.k_descale_ptr = k_descale_buf.GetDeviceBuffer(); + args.v_descale_ptr = v_descale_buf.GetDeviceBuffer(); + } args.rand_val_ptr = randval_buf.GetDeviceBuffer(); @@ -1788,31 +1846,33 @@ fwd_result fmha_fwd_run(mode_enum mode, if(json) { - dump_fmha_fwd_json_results(*json, - data_type, - mode == mode_enum::batch ? "batch" : "group", - io_layout(i_perm, o_perm), - batch, - nhead, - nhead_k, - seqlen_qs[0], - seqlen_ks[0], - seqlen_kpads[0], - hdim_q, - hdim_v, - scale_s, - p_drop, - lse, - qscale.type == quant_scale_enum::no_scale ? "no_scale" - : "pertensor", - bias.type == bias_enum::elementwise_bias - ? "elementwise_bias" - : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), - is_v_rowmajor ? "r" : "c", - pass, - ave_time, - tflops, - gb_per_sec); + dump_fmha_fwd_json_results( + *json, + data_type, + mode == mode_enum::batch ? "batch" : "group", + io_layout(i_perm, o_perm), + batch, + nhead, + nhead_k, + seqlen_qs[0], + seqlen_ks[0], + seqlen_kpads[0], + hdim_q, + hdim_v, + scale_s, + p_drop, + lse, + qscale.type == quant_scale_enum::no_scale + ? "no_scale" + : (qscale.type == quant_scale_enum::pertensor ? "pertensor" : "blockscale"), + bias.type == bias_enum::elementwise_bias + ? "elementwise_bias" + : (bias.type == bias_enum::alibi ? "alibi" : "no_bias"), + is_v_rowmajor ? "r" : "c", + pass, + ave_time, + tflops, + gb_per_sec); } return pass ? fwd_result::success : fwd_result::failure; diff --git a/example/ck_tile/01_fmha/quant.hpp b/example/ck_tile/01_fmha/quant.hpp index 35461cc53d..667972e67d 100644 --- a/example/ck_tile/01_fmha/quant.hpp +++ b/example/ck_tile/01_fmha/quant.hpp @@ -13,6 +13,7 @@ enum class quant_scale_enum { no_scale = 0, pertensor = 1, + blockscale, }; struct quant_scale_info @@ -25,6 +26,8 @@ struct quant_scale_info os << "n"; else if(type == quant_scale_enum::pertensor) os << "pt"; + else if(type == quant_scale_enum::blockscale) + os << "bs"; } static quant_scale_info decode(std::string str) @@ -38,6 +41,10 @@ struct quant_scale_info { info.type = quant_scale_enum::pertensor; } + else if(str == "bs" || str == "2") + { + info.type = quant_scale_enum::blockscale; + } else { throw std::invalid_argument("invalid quant scale value: " + str); diff --git a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp index 4d80443f35..733fbc4c92 100644 --- a/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp +++ b/include/ck_tile/ops/fmha/block/block_attention_quant_scale_enum.hpp @@ -12,6 +12,7 @@ enum class BlockAttentionQuantScaleEnum { NO_SCALE = 0, PERTENSOR = 1, + BLOCKSCALE, }; template @@ -27,5 +28,11 @@ struct BlockAttentionQuantScaleEnumToStr +struct BlockAttentionQuantScaleEnumToStr +{ + static constexpr const char* name = "blockscale"; +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 85a91ad19e..fb8dd4302f 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -214,6 +214,23 @@ struct FmhaFwdKernel const void* v_descale_ptr = nullptr; }; + struct FmhaFwdCommonBlockScaleKargs : public FmhaFwdCommonQScaleKargs + { + ck_tile::index_t nhead_stride_q_descale; + ck_tile::index_t nhead_stride_k_descale; + ck_tile::index_t nhead_stride_v_descale; + + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; + }; + + struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_descale; + ck_tile::index_t batch_stride_k_descale; + ck_tile::index_t batch_stride_v_descale; + }; + struct FmhaFwdCommonLSEKargs { void* lse_ptr = nullptr; @@ -289,9 +306,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t> { @@ -315,9 +335,12 @@ struct FmhaFwdKernel FmhaFwdEmptyKargs<0>>>, std::conditional_t>, std::conditional_t>, - std::conditional_t>, + std::conditional_t< + QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR, + FmhaFwdCommonQScaleKargs, + std::conditional_t>>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -374,6 +397,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -381,6 +407,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -388,6 +417,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -455,6 +486,23 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.batch_stride_q_descale = batch_stride_q_descale; + kargs.batch_stride_k_descale = batch_stride_k_descale; + kargs.batch_stride_v_descale = batch_stride_v_descale; + + kargs.block_scale_m = block_scale_m; + kargs.block_scale_n = block_scale_n; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -520,6 +568,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -527,12 +578,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -568,6 +624,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -575,12 +634,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -619,6 +683,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -626,12 +693,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_descale, + ck_tile::index_t batch_stride_k_descale, + ck_tile::index_t batch_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -667,6 +739,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -674,12 +749,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -719,6 +799,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -727,6 +810,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -793,6 +878,19 @@ struct FmhaFwdKernel kargs.k_descale_ptr = k_descale_ptr; kargs.v_descale_ptr = v_descale_ptr; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + kargs.q_descale_ptr = q_descale_ptr; + kargs.k_descale_ptr = k_descale_ptr; + kargs.v_descale_ptr = v_descale_ptr; + + kargs.nhead_stride_q_descale = nhead_stride_q_descale; + kargs.nhead_stride_k_descale = nhead_stride_k_descale; + kargs.nhead_stride_v_descale = nhead_stride_v_descale; + + kargs.block_scale_m = block_scale_m; + kargs.block_scale_n = block_scale_n; + } if constexpr(kHasDropout) { if(drop_seed_offset.index() == 0) // seed & offset come from host @@ -863,6 +961,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -870,6 +971,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -907,6 +1010,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, mask_type, @@ -914,6 +1020,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); } @@ -954,6 +1062,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_descale, + ck_tile::index_t nhead_stride_k_descale, + ck_tile::index_t nhead_stride_v_descale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -961,6 +1072,8 @@ struct FmhaFwdKernel float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const void* cu_seqlen_q_ptr = nullptr, const void* cu_seqlen_k_ptr = nullptr) { @@ -998,6 +1111,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, window_size_left, window_size_right, mask_type, @@ -1005,6 +1121,8 @@ struct FmhaFwdKernel p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_k_ptr); }