From 4626bace60fd6ab159eb77f08e660c100aa47d7e Mon Sep 17 00:00:00 2001 From: ltqin Date: Wed, 5 Nov 2025 14:31:22 +0000 Subject: [PATCH] alloc divice memory for scale --- example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 101 ++++++++++++-------- 1 file changed, 59 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index c06616c5e1..b8c86819d6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -170,7 +170,9 @@ class BlockQuantizer std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len << " hdim: " << hdim << " dtype_max: " << dtype_max << " num_blocks_: " << num_blocks_ << std::endl; - + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution dis(0.5, 2.0f); for(size_t b = 0; b < batch; ++b){ for(size_t h = 0; h < head; ++h) { @@ -193,6 +195,7 @@ class BlockQuantizer } } // calculate block scale + max_value += dis(gen); float scale = dtype_max / max_value; block_scale(b,h,block) = scale; std::cout << "block: " << block << " scale: " << scale << " max_value: " << max_value << " block_scale: " << block_scale << std::endl; @@ -635,16 +638,21 @@ fwd_result fmha_fwd_run(mode_enum mode, : (seqlen_kpads[0] < 0 ? seqstart_k_host.back() : seqstart_k_with_padding_host.back())); + const ck_tile::index_t num_block_scale_m = + ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_); + const ck_tile::index_t num_block_scale_n = + ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_); + ck_tile::HostTensor q_host( get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q)); - ck_tile::HostTensor q_scale(std::array{ - shape_batch, nhead, ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_)}); + ck_tile::HostTensor q_scale( + std::array{shape_batch, nhead, num_block_scale_m}); ck_tile::HostTensor k_host( 0 < page_block_size ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); ck_tile::HostTensor k_scale(std::array{ - shape_batch, nhead_k, ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)}); + shape_batch, nhead_k, num_block_scale_n}); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( 0 < seqlen_knew @@ -658,7 +666,7 @@ fwd_result fmha_fwd_run(mode_enum mode, : (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v) : get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k))); ck_tile::HostTensor v_scale(std::array{ - shape_batch, nhead_k, ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)}); + shape_batch, nhead_k, num_block_scale_n}); ck_tile::HostTensor vnew_host( 0 < seqlen_knew ? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v) @@ -779,43 +787,6 @@ fwd_result fmha_fwd_run(mode_enum mode, iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine); - ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() - ? 0 - : seqstart_q_with_padding_host.size() * - sizeof(int32_t)); - ck_tile::DeviceMem seqstart_k_padded_buf( - seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); - ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 - : cuq_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem cu_seqlen_kv_buf( - cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); - ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || - 0 <= seqlen_kpads[0] - ? seqlen_ks.size() * sizeof(int32_t) - : 0); - ck_tile::DeviceMem cache_seqlen_k_buf( - need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); - ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); - ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); - ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); - float scale_p = 1.f; float scale_o = 1.f; if(quant == 2) @@ -894,6 +865,48 @@ fwd_result fmha_fwd_run(mode_enum mode, } } + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty() + ? 0 + : seqstart_q_with_padding_host.size() * + sizeof(int32_t)); + ck_tile::DeviceMem seqstart_k_padded_buf( + seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t)); + ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0 + : cuq_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem cu_seqlen_kv_buf( + cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t)); + ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) || + 0 <= seqlen_kpads[0] + ? seqlen_ks.size() * sizeof(int32_t) + : 0); + ck_tile::DeviceMem cache_seqlen_k_buf( + need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0); + ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0); + ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem q_scale_buf(q_scale.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_scale_buf(k_scale.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_scale_buf(v_scale.get_element_space_size_in_bytes()); + + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); v_buf.ToDevice(v_host.data()); @@ -921,6 +934,10 @@ fwd_result fmha_fwd_run(mode_enum mode, block_table_buf.ToDevice(block_table_host.data()); cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data()); + q_scale_buf.ToDevice(q_scale.data()); + k_scale_buf.ToDevice(k_scale.data()); + v_scale_buf.ToDevice(v_scale.data()); + if(quant == 2) { //dequant data for host