diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index d3913eaf5b..cbeb8c2a88 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -451,9 +451,11 @@ fwd_result fmha_fwd_run(mode_enum mode, std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = std::numeric_limits::min(); // we will use max seqlen to decide grid size - size_t num_block_scale_q = 0; - size_t num_block_scale_k = 0; - auto max_seqlen_k = std::numeric_limits::min(); + size_t num_block_scale_q = 0; + size_t num_block_scale_k = 0; + std::vector bseqstart_q_host = {0}; + std::vector bseqstart_k_host = {0}; + auto max_seqlen_k = std::numeric_limits::min(); { for(ck_tile::index_t wb = 0; wb < batch; ++wb) { @@ -471,6 +473,8 @@ fwd_result fmha_fwd_run(mode_enum mode, } num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_); num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_); + bseqstart_q_host.push_back(num_block_scale_q); + bseqstart_k_host.push_back(num_block_scale_k); flop += nhead * (static_cast(2) * mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); @@ -481,6 +485,8 @@ fwd_result fmha_fwd_run(mode_enum mode, sizeof(VDataType) * hdim_v * real_seqlen_k); } } + // std::cout << "bseqstart_q_host: " << bseqstart_q_host + // << "bseqstart_k_host: " << bseqstart_k_host << std::endl; const ck_tile::index_t max_num_page_blocks = (0 < page_block_size @@ -720,6 +726,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(q_descale_host); ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(k_descale_host); ck_tile::FillUniformDistribution{0.012f, 0.015f, next_seed()}(v_descale_host); + // return fwd_result::no_instance; } iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); @@ -1611,6 +1618,10 @@ fwd_result fmha_fwd_run(mode_enum mode, // reference if(qscale.type == quant_scale_enum::blockscale) { + const ck_tile::index_t q_offset = + (mode == mode_enum::batch) ? 0 : bseqstart_q_host[wb]; + const ck_tile::index_t k_offset = + (mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb]; ck_tile::reference_batched_quant_gemm(idx), std::get<1>(idx) / block_scale_m_) * - k_descale_host( - b_idx, std::get<0>(idx) / nr, std::get<2>(idx) / block_scale_n_); + q_descale_host(b_idx, + std::get<0>(idx), + q_offset + std::get<1>(idx) / block_scale_m_) * + k_descale_host(b_idx, + std::get<0>(idx) / nr, + k_offset + std::get<2>(idx) / block_scale_n_); }); } else @@ -1798,6 +1811,8 @@ fwd_result fmha_fwd_run(mode_enum mode, if(qscale.type == quant_scale_enum::blockscale) { + const ck_tile::index_t v_offset = + (mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb]; ck_tile:: reference_batched_quant_gemm( p_host_ref, @@ -1808,7 +1823,7 @@ fwd_result fmha_fwd_run(mode_enum mode, return ck_tile::type_convert(value) * v_descale_host(b_idx, std::get<0>(idx) / nr, - std::get<2>(idx) / block_scale_n_); + v_offset + std::get<2>(idx) / block_scale_n_); }, ck_tile::idx_identity{}); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 22fe34d29e..2c5121e6e6 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1294,6 +1294,12 @@ struct FmhaFwdKernel { batch_offset_randval = query_start * kargs.stride_randval; } + if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE) + { + batch_offset_q_descale = query_start/128; + batch_offset_k_descale = key_start/128; + batch_offset_v_descale = key_start/128; + } batch_offset_o = query_start * kargs.stride_o; // real logical lengths (exclude PAD)