From df55e264ad75fca1af5c8b786c4df022cfa32ab1 Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 12 Nov 2025 05:42:59 +0000 Subject: [PATCH] result right --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 54 ++++++------------ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 3 +- .../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 56 +++++++++++++------ .../block_fmha_pipeline_qr_ks_vs_async.hpp | 2 + 4 files changed, 61 insertions(+), 54 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index c545699989..1c9c4d1b76 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -167,9 +167,9 @@ class BlockQuantizer size_t seq_len = in.get_length(i_perm ? 2 : 1); size_t hdim = in.get_length(3); size_t num_blocks_ = (seq_len + block_size_ - 1) / block_size_; - std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len - << " hdim: " << hdim << " dtype_max: " << dtype_max - << " num_blocks_: " << num_blocks_ << std::endl; + // std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len + // << " hdim: " << hdim << " dtype_max: " << dtype_max + // << " num_blocks_: " << num_blocks_ << std::endl; std::random_device rd; std::mt19937 gen(rd()); std::uniform_real_distribution dis(0.5f, 2.0f); @@ -215,9 +215,9 @@ class BlockQuantizer } // save scale to tensor block_scale(b, h, block) = 1.0f / scale; - std::cout << "block: " << block << " scale: " << scale - << " max_value: " << max_value << " block_scale: " << block_scale - << std::endl; + // std::cout << "block: " << block << " scale: " << scale + // << " max_value: " << max_value << " block_scale: " << block_scale + // << std::endl; } } } @@ -806,32 +806,13 @@ fwd_result fmha_fwd_run(mode_enum mode, float scale_o = 1.f; if(quant == 2) { - ck_tile::FillUniformDistributionIntegerValue{1.f, 10.f, next_seed()}(q_scale); - ck_tile::FillUniformDistributionIntegerValue{1.f, 10.f, next_seed()}(k_scale); - ck_tile::FillUniformDistributionIntegerValue{1.f, 10.f, next_seed()}(v_scale); - - { //debug info - std::cout << "q_scale: " << q_scale << " k_scale: " << k_scale - << " v_scale: " << v_scale << std::endl; - - ck_tile::HostTensor q_host_deq( - get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - ck_tile::HostTensor k_host_deq( - 0 < page_block_size - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) - : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - ck_tile::HostTensor v_host_deq( - 0 < page_block_size - ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) - : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); - BlockQuantizer quantizer(i_perm); - quantizer.dequantize(q_host, q_host_deq, q_scale, block_scale_m_); - quantizer.dequantize(k_host, k_host_deq, k_scale, block_scale_n_); - quantizer.dequantize(v_host, v_host_deq, v_scale, block_scale_n_); - q_host_deq.savetxt("./q_deq.txt"); - k_host_deq.savetxt("./k_deq.txt"); - v_host_deq.savetxt("./v_deq.txt"); - } + BlockQuantizer quantizer(i_perm); + quantizer.quantize(q_host, q_host, q_scale, block_scale_m_); + quantizer.quantize(k_host, k_host, k_scale, block_scale_n_); + quantizer.quantize(v_host, v_host, v_scale, block_scale_n_); + q_host.savetxt("./q_quant.txt"); + k_host.savetxt("./k_quant.txt"); + v_host.savetxt("./v_quant.txt"); } else if(quant == 1) { @@ -1737,10 +1718,10 @@ fwd_result fmha_fwd_run(mode_enum mode, s_host_ref, ck_tile::idx_identity{}, ck_tile::idx_identity{}, - [&q_scale, &k_scale, scale_s, wb](auto idx, auto value) { + [&](auto idx, auto value) { return value * scale_s * q_scale(wb, std::get<0>(idx), std::get<1>(idx) / 128) * - k_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128); + k_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128); }); } else @@ -1919,10 +1900,9 @@ fwd_result fmha_fwd_run(mode_enum mode, v_host_ref, o_host_ref, ck_tile::idx_identity{}, - [&v_scale, wb](auto idx, auto value) { - // idx: b, m, n, k --> h, sq, d, sk + [&v_scale, wb, nr](auto idx, auto value) { return ck_tile::type_convert(value) * - v_scale(wb, std::get<0>(idx), std::get<2>(idx) / 128); + v_scale(wb, std::get<0>(idx) / nr, std::get<2>(idx) / 128); }, ck_tile::idx_identity{}); } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 910bab8b26..1219f4d2e9 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1728,12 +1728,13 @@ struct FmhaFwdKernel o_acc_element_func, // o_acc_element_func mask, position_encoding, - kargs.scale_s * q_scale, + kargs.scale_s, variant, variant_params, block_indices, smem_ptr, dropout, + q_scale, k_scale_ptr, v_scale_ptr, kargs.block_scale_m, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 9e3b4ae066..25ed192443 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -176,6 +176,7 @@ struct BlockFmhaPipelineQRKSVS const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float q_scale, const float* k_scale_ptr, const float* v_scale_ptr, index_t, @@ -407,7 +408,7 @@ struct BlockFmhaPipelineQRKSVS { if(k_scale_ptr) { - tile_elementwise_inout([k_scale](auto& x) { x = x * k_scale; }, s_acc); + tile_elementwise_inout([q_scale, k_scale](auto& x) { x = x * q_scale * k_scale; }, s_acc); } } @@ -640,18 +641,39 @@ struct BlockFmhaPipelineQRKSVS const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + auto wrapper_gemm1 = [&](auto& acc, auto a, auto b) { + if constexpr(kDoFp8StaticQuant) + { + auto acc0 = gemm_1(a, b); + tile_elementwise_inout( + [&v_scale](auto& o, auto o0) { + // asm volatile(";wrapper_gemm1\n\tv_mul_f32_e32 %0, %1, %2" + // : "=v"(o) + // : "s"(v_scale), "v"(o0) + // : "memory"); + o += o0 * v_scale; + }, + acc, + acc0); + } + else + { + gemm_1(acc, a, b); + }; + }; // STAGE 3, KV gemm - auto o_acc_tmp = decltype(o_acc){}; - clear_tile(o_acc_tmp); + // auto o_acc_tmp = decltype(o_acc){}; + // clear_tile(o_acc_tmp); if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { const auto v = load_tile(v_dram_window); // load next v block_sync_lds(); - gemm_1(o_acc_tmp, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_window); + wrapper_gemm1(o_acc, + get_slice_tile(p, + sequence<0, i_k1 * kK1>{}, + sequence{}), + v_lds_window); block_sync_lds(); if constexpr(std::is_same_v) { @@ -675,20 +697,21 @@ struct BlockFmhaPipelineQRKSVS // tail { block_sync_lds(); - gemm_1(o_acc_tmp, - get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_window); + wrapper_gemm1( + o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_window); block_sync_lds(); } // o_acc += o_acc_tmp; // o_acc += tile_elementwise_in(scale(1.0f / v_scale), o_acc_tmp); // ck_tile::ignore = v_scale; - sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { - sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale; - }); - }); + // sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // o_acc(i_j_idx) += o_acc_tmp(i_j_idx) * v_scale; + // }); + // }); } while(++i_total_loops < num_total_loop); @@ -796,6 +819,7 @@ struct BlockFmhaPipelineQRKSVS block_indices, smem_ptr, dropout, + 1.0f, nullptr, nullptr, 128, diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 9f60c1f8c1..9751628ca2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -186,6 +186,7 @@ struct BlockFmhaPipelineQRKSVSAsync const BlockIndices& block_indices, void* smem_ptr, DropoutType& dropout, + const float, const float*, const float*, index_t, @@ -850,6 +851,7 @@ struct BlockFmhaPipelineQRKSVSAsync block_indices, smem_ptr, dropout, + 1.0f, nullptr, nullptr, 128,