From 9b341c5d6f5102e3670e79e78c03d02ccec7023a Mon Sep 17 00:00:00 2001 From: ltqin Date: Thu, 6 Nov 2025 08:01:41 +0000 Subject: [PATCH] add batch block scale parameters to kernel --- example/ck_tile/01_fmha/fmha_fwd.hpp | 24 +++++ example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 102 +++++++++++------- .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 98 ++++++++++++++++- 3 files changed, 187 insertions(+), 37 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 761def6d6a..2029845cab 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -196,6 +196,10 @@ struct fmha_fwd_args const void* seqstart_padded_q_ptr = nullptr; // [batch+1] const void* seqstart_padded_k_ptr = nullptr; // [batch+1] + const float* q_scale_ptr = nullptr; + const float* k_scale_ptr = nullptr; + const float* v_scale_ptr = nullptr; + ck_tile::index_t seqlen_q; ck_tile::index_t seqlen_k; ck_tile::index_t batch; @@ -224,6 +228,9 @@ struct fmha_fwd_args ck_tile::index_t nhead_stride_randval; ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_o; + ck_tile::index_t nhead_stride_q_scale; + ck_tile::index_t nhead_stride_k_scale; + ck_tile::index_t nhead_stride_v_scale; ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_v; @@ -231,12 +238,18 @@ struct fmha_fwd_args ck_tile::index_t batch_stride_randval; ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_o; + ck_tile::index_t batch_stride_q_scale; + ck_tile::index_t batch_stride_k_scale; + ck_tile::index_t batch_stride_v_scale; ck_tile::index_t window_size_left; ck_tile::index_t window_size_right; ck_tile::index_t mask_type; ck_tile::index_t min_seqlen_q; + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; + float p_drop; bool s_randval; @@ -596,6 +609,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.rand_val_ptr, args.lse_ptr, args.o_ptr, + args.q_scale_ptr, + 
args.k_scale_ptr, + args.v_scale_ptr, args.seqlen_q, args.seqlen_k, args.hdim_q, @@ -619,6 +635,9 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.nhead_stride_randval, args.nhead_stride_lse, args.nhead_stride_o, + args.nhead_stride_q_scale, + args.nhead_stride_k_scale, + args.nhead_stride_v_scale, args.batch_stride_q, args.batch_stride_k, args.batch_stride_v, @@ -626,12 +645,17 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) args.batch_stride_randval, args.batch_stride_lse, args.batch_stride_o, + args.batch_stride_q_scale, + args.batch_stride_k_scale, + args.batch_stride_v_scale, args.window_size_left, args.window_size_right, args.mask_type, args.p_drop, args.s_randval, args.drop_seed_offset, + args.block_scale_m, + args.block_scale_n, args.cu_seqlen_q_ptr, args.cu_seqlen_kv_ptr); } diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index b8c86819d6..16f6df528e 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -159,27 +159,29 @@ class BlockQuantizer template void quantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_) { - using InDataType = typename std::remove_reference_t::DataType; - using OutDataType = typename std::remove_reference_t::DataType; - float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - size_t batch = in.get_length(0); - size_t head = in.get_length(i_perm ? 1 : 2); - size_t seq_len = in.get_length(i_perm ? 2 : 1); - size_t hdim = in.get_length(3); + using InDataType = typename std::remove_reference_t::DataType; + using OutDataType = typename std::remove_reference_t::DataType; + float dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + size_t batch = in.get_length(0); + size_t head = in.get_length(i_perm ? 1 : 2); + size_t seq_len = in.get_length(i_perm ? 
2 : 1); + size_t hdim = in.get_length(3); size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_; - std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len - << " hdim: " << hdim << " dtype_max: " << dtype_max - << " num_blocks_: " << num_blocks_ << std::endl; + // std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len + // << " hdim: " << hdim << " dtype_max: " << dtype_max + // << " num_blocks_: " << num_blocks_ << std::endl; std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution dis(0.5, 2.0f); - for(size_t b = 0; b < batch; ++b){ + for(size_t b = 0; b < batch; ++b) + { for(size_t h = 0; h < head; ++h) { for(size_t block = 0; block < num_blocks_; ++block) { // get block max value - float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + float max_value = + ck_tile::type_convert(ck_tile::numeric::min()); for(size_t s = block * block_size_; s < (block + 1) * block_size_ && s < seq_len; ++s) @@ -196,11 +198,12 @@ class BlockQuantizer } // calculate block scale max_value += dis(gen); - float scale = dtype_max / max_value; - block_scale(b,h,block) = scale; - std::cout << "block: " << block << " scale: " << scale << " max_value: " << max_value << " block_scale: " << block_scale << std::endl; - - //quant + float scale = dtype_max / max_value; + block_scale(b, h, block) = scale; + // std::cout << "block: " << block << " scale: " << scale << " max_value: " << + // max_value << " block_scale: " << block_scale << std::endl; + + // quant for(size_t s = block * block_size_; s < (block + 1) * block_size_ && s < seq_len; ++s) @@ -211,26 +214,28 @@ class BlockQuantizer if(!i_perm) idx = {b, s, h, d}; float val = ck_tile::type_convert(in(idx)); - out(idx) = ck_tile::type_convert(val * scale); + out(idx) = ck_tile::type_convert(val * scale); } } - } + } } } } template - void dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_) + void + 
dequantize(const SrcTensor& in, DstTensor& out, ScaleTensor& block_scale, size_t block_size_) { - using OutDataType = typename std::remove_reference_t::DataType; - size_t batch = in.get_length(0); - size_t head = in.get_length(i_perm ? 1 : 2); - size_t seq_len = in.get_length(i_perm ? 2 : 1); - size_t hdim = in.get_length(3); + using OutDataType = typename std::remove_reference_t::DataType; + size_t batch = in.get_length(0); + size_t head = in.get_length(i_perm ? 1 : 2); + size_t seq_len = in.get_length(i_perm ? 2 : 1); + size_t hdim = in.get_length(3); size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_; - //dequant - for(size_t b = 0; b < batch; ++b){ + // dequant + for(size_t b = 0; b < batch; ++b) + { for(size_t h = 0; h < head; ++h) { for(size_t block = 0; block < num_blocks_; ++block) @@ -245,14 +250,13 @@ class BlockQuantizer std::vector idx = {b, h, s, d}; if(!i_perm) idx = {b, s, h, d}; - float val = ck_tile::type_convert(in(idx)); - out(idx) = ck_tile::type_convert(val / scale); + float val = ck_tile::type_convert(in(idx)); + out(idx) = ck_tile::type_convert(val / scale); } } } } } - } }; @@ -651,8 +655,8 @@ fwd_result fmha_fwd_run(mode_enum mode, 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - ck_tile::HostTensor k_scale(std::array{ - shape_batch, nhead_k, num_block_scale_n}); + ck_tile::HostTensor k_scale( + std::array{shape_batch, nhead_k, num_block_scale_n}); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( 0 < seqlen_knew @@ -665,8 +669,8 @@ fwd_result fmha_fwd_run(mode_enum mode, : get_lengths(i_perm, max_num_page_blocks, nhead_k, hdim_v, page_block_size)) : (is_v_rowmajor ? 
get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); - ck_tile::HostTensor v_scale(std::array{ - shape_batch, nhead_k, num_block_scale_n}); + ck_tile::HostTensor v_scale( + std::array{shape_batch, nhead_k, num_block_scale_n}); ck_tile::HostTensor vnew_host( 0 < seqlen_knew ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) @@ -906,7 +910,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem k_scale_buf(k_scale.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_scale_buf(v_scale.get_element_space_size_in_bytes()); - q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); @@ -940,7 +943,7 @@ fwd_result fmha_fwd_run(mode_enum mode, if(quant == 2) { - //dequant data for host + // dequant data for host BlockQuantizer quantizer(i_perm); // q_host.savetxt("./q_quant.txt"); quantizer.dequantize(q_host, q_host, q_scale, block_scale_m_); @@ -1125,6 +1128,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q_lse); const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t nhead_stride_o = (o_perm ? 
shape_seqlen_q * hdim_v : hdim_v); + const ck_tile::index_t nhead_stride_q_scale = num_block_scale_m; + const ck_tile::index_t nhead_stride_k_scale = num_block_scale_n; + const ck_tile::index_t nhead_stride_v_scale = num_block_scale_n; // setup batch_stride_* arguments const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); const ck_tile::index_t batch_stride_k = @@ -1142,6 +1148,9 @@ fwd_result fmha_fwd_run(mode_enum mode, const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch); + const ck_tile::index_t batch_stride_q_scale = num_block_scale_m * nhead; + const ck_tile::index_t batch_stride_k_scale = num_block_scale_n * nhead_k; + const ck_tile::index_t batch_stride_v_scale = num_block_scale_n * nhead_k; // setup split_stride_* arguments (only used in split-kv kernel) const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q); const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v); @@ -1235,6 +1244,27 @@ fwd_result fmha_fwd_run(mode_enum mode, if constexpr(std::is_same_v>) { + if(quant == 2) + { + args.q_scale_ptr = + reinterpret_cast(q_scale_buf.GetDeviceBuffer()); + args.k_scale_ptr = + reinterpret_cast(k_scale_buf.GetDeviceBuffer()); + args.v_scale_ptr = + reinterpret_cast(v_scale_buf.GetDeviceBuffer()); + + args.nhead_stride_q_scale = nhead_stride_q_scale; + args.nhead_stride_k_scale = nhead_stride_k_scale; + args.nhead_stride_v_scale = nhead_stride_v_scale; + + args.batch_stride_q_scale = batch_stride_q_scale; + args.batch_stride_k_scale = batch_stride_k_scale; + args.batch_stride_v_scale = batch_stride_v_scale; + + args.block_scale_m = block_scale_m_; + args.block_scale_n = block_scale_n_; + } + args.rand_val_ptr = randval_buf.GetDeviceBuffer(); args.stride_randval = stride_randval; diff --git 
a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index dafe99febe..4ba64db6de 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -156,6 +156,27 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_o; }; + struct FmhaFwdCommonBlockScaleKargs + { + const float* q_scale_ptr = nullptr; + const float* k_scale_ptr = nullptr; + const float* v_scale_ptr = nullptr; + + ck_tile::index_t nhead_stride_q_scale; + ck_tile::index_t nhead_stride_k_scale; + ck_tile::index_t nhead_stride_v_scale; + + ck_tile::index_t block_scale_m; + ck_tile::index_t block_scale_n; + }; + + struct FmhaFwdBatchBlockScaleKargs : public FmhaFwdCommonBlockScaleKargs + { + ck_tile::index_t batch_stride_q_scale; + ck_tile::index_t batch_stride_k_scale; + ck_tile::index_t batch_stride_v_scale; + }; + struct FmhaFwdLogitsSoftCapKargs { FmhaFwdLogitsSoftCapKargs() = default; @@ -287,7 +308,8 @@ struct FmhaFwdKernel std::conditional_t>, std::conditional_t>, std::conditional_t>, - std::conditional_t> + std::conditional_t>, + std::conditional_t> { ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_k; @@ -342,6 +364,9 @@ struct FmhaFwdKernel void* rand_val_ptr, void* lse_ptr, void* o_ptr, + const float* q_scale_ptr, + const float* k_scale_ptr, + const float* v_scale_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, @@ -365,6 +390,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_scale, + ck_tile::index_t nhead_stride_k_scale, + ck_tile::index_t nhead_stride_v_scale, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -372,6 +400,9 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + 
ck_tile::index_t batch_stride_q_scale, + ck_tile::index_t batch_stride_k_scale, + ck_tile::index_t batch_stride_v_scale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, @@ -379,6 +410,8 @@ struct FmhaFwdKernel bool s_randval, std::variant, std::pair> drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { @@ -411,6 +444,7 @@ struct FmhaFwdKernel {}, // placeholder for fp8_static_quant args {}, // placeholder for dropout {}, // placeholder for logits_soft_cap + {}, // placeholder for quant scale batch_stride_q, batch_stride_k, batch_stride_v, @@ -471,6 +505,24 @@ struct FmhaFwdKernel kargs.init_logits_soft_cap(logits_soft_cap); } + if constexpr(kDoFp8StaticQuant) + { + kargs.q_scale_ptr = q_scale_ptr; + kargs.k_scale_ptr = k_scale_ptr; + kargs.v_scale_ptr = v_scale_ptr; + + kargs.nhead_stride_q_scale = nhead_stride_q_scale; + kargs.nhead_stride_k_scale = nhead_stride_k_scale; + kargs.nhead_stride_v_scale = nhead_stride_v_scale; + + kargs.batch_stride_q_scale = batch_stride_q_scale; + kargs.batch_stride_k_scale = batch_stride_k_scale; + kargs.batch_stride_v_scale = batch_stride_v_scale; + + kargs.block_scale_m = block_scale_m; + kargs.block_scale_n = block_scale_n; + } + kargs.cu_seqlen_q_ptr = cu_seqlen_q_ptr; kargs.cu_seqlen_kv_ptr = cu_seqlen_kv_ptr; return kargs; @@ -486,6 +538,9 @@ struct FmhaFwdKernel void* rand_val_ptr, void* lse_ptr, void* o_ptr, + const float* q_scale_ptr, + const float* k_scale_ptr, + const float* v_scale_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, @@ -509,6 +564,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_scale, + ck_tile::index_t nhead_stride_k_scale, + ck_tile::index_t nhead_stride_v_scale, 
ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -516,12 +574,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_scale, + ck_tile::index_t batch_stride_k_scale, + ck_tile::index_t batch_stride_v_scale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { @@ -533,6 +596,9 @@ struct FmhaFwdKernel rand_val_ptr, lse_ptr, o_ptr, + q_scale_ptr, + k_scale_ptr, + v_scale_ptr, seqlen_q, seqlen_k, hdim_q, @@ -556,6 +622,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_scale, + nhead_stride_k_scale, + nhead_stride_v_scale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -563,12 +632,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_scale, + batch_stride_k_scale, + batch_stride_v_scale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_kv_ptr); } @@ -583,6 +657,9 @@ struct FmhaFwdKernel void* rand_val_ptr, void* lse_ptr, void* o_ptr, + const float* q_scale_ptr, + const float* k_scale_ptr, + const float* v_scale_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, @@ -606,6 +683,9 @@ struct FmhaFwdKernel ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, + ck_tile::index_t nhead_stride_q_scale, + ck_tile::index_t nhead_stride_k_scale, + ck_tile::index_t nhead_stride_v_scale, ck_tile::index_t 
batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, @@ -613,12 +693,17 @@ struct FmhaFwdKernel ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, + ck_tile::index_t batch_stride_q_scale, + ck_tile::index_t batch_stride_k_scale, + ck_tile::index_t batch_stride_v_scale, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple& drop_seed_offset, + ck_tile::index_t block_scale_m, + ck_tile::index_t block_scale_n, const ck_tile::index_t* cu_seqlen_q_ptr = nullptr, const ck_tile::index_t* cu_seqlen_kv_ptr = nullptr) { @@ -630,6 +715,9 @@ struct FmhaFwdKernel rand_val_ptr, lse_ptr, o_ptr, + q_scale_ptr, + k_scale_ptr, + v_scale_ptr, seqlen_q, seqlen_k, hdim_q, @@ -653,6 +741,9 @@ struct FmhaFwdKernel nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_scale, + nhead_stride_k_scale, + nhead_stride_v_scale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -660,12 +751,17 @@ struct FmhaFwdKernel batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_scale, + batch_stride_k_scale, + batch_stride_v_scale, window_size_left, window_size_right, mask_type, p_drop, s_randval, std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)), + block_scale_m, + block_scale_n, cu_seqlen_q_ptr, cu_seqlen_kv_ptr); }