alloc divice memory for scale

This commit is contained in:
ltqin
2025-11-05 14:31:22 +00:00
parent 2349af01c7
commit 4626bace60

View File

@@ -170,7 +170,9 @@ class BlockQuantizer
std::cout << "batch: " << batch << " head: " << head << " seq_len: " << seq_len
<< " hdim: " << hdim << " dtype_max: " << dtype_max
<< " num_blocks_: " << num_blocks_ << std::endl;
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<float> dis(0.5, 2.0f);
for(size_t b = 0; b < batch; ++b){
for(size_t h = 0; h < head; ++h)
{
@@ -193,6 +195,7 @@ class BlockQuantizer
}
}
// calculate block scale
max_value += dis(gen);
float scale = dtype_max / max_value;
block_scale(b,h,block) = scale;
std::cout << "block: " << block << " scale: " << scale << " max_value: " << max_value << " block_scale: " << block_scale << std::endl;
@@ -635,16 +638,21 @@ fwd_result fmha_fwd_run(mode_enum mode,
: (seqlen_kpads[0] < 0 ? seqstart_k_host.back()
: seqstart_k_with_padding_host.back()));
const ck_tile::index_t num_block_scale_m =
ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_);
const ck_tile::index_t num_block_scale_n =
ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_);
ck_tile::HostTensor<QDataType> q_host(
get_lengths(i_perm, shape_batch, nhead, shape_seqlen_q, hdim_q));
ck_tile::HostTensor<float> q_scale(std::array<ck_tile::index_t, 3>{
shape_batch, nhead, ck_tile::integer_divide_ceil(shape_seqlen_q, block_scale_m_)});
ck_tile::HostTensor<float> q_scale(
std::array<ck_tile::index_t, 3>{shape_batch, nhead, num_block_scale_m});
ck_tile::HostTensor<KDataType> k_host(
0 < page_block_size
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
ck_tile::HostTensor<float> k_scale(std::array<ck_tile::index_t, 3>{
shape_batch, nhead_k, ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)});
shape_batch, nhead_k, num_block_scale_n});
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
0 < seqlen_knew
@@ -658,7 +666,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
: (is_v_rowmajor ? get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_v)
: get_lengths(i_perm, shape_batch, nhead_k, hdim_v, shape_seqlen_k)));
ck_tile::HostTensor<float> v_scale(std::array<ck_tile::index_t, 3>{
shape_batch, nhead_k, ck_tile::integer_divide_ceil(shape_seqlen_k, block_scale_n_)});
shape_batch, nhead_k, num_block_scale_n});
ck_tile::HostTensor<VDataType> vnew_host(
0 < seqlen_knew
? (is_v_rowmajor ? get_lengths(i_perm, batch, nhead_k, seqlen_knew, hdim_v)
@@ -779,43 +787,6 @@ fwd_result fmha_fwd_run(mode_enum mode,
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0, random_engine);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
? 0
: seqstart_q_with_padding_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
float scale_p = 1.f;
float scale_o = 1.f;
if(quant == 2)
@@ -894,6 +865,48 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
}
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_k(seqstart_k_host.size() * sizeof(int32_t));
ck_tile::DeviceMem seqstart_q_padded_buf(seqstart_q_with_padding_host.empty()
? 0
: seqstart_q_with_padding_host.size() *
sizeof(int32_t));
ck_tile::DeviceMem seqstart_k_padded_buf(
seqlen_kpads[0] < 0 ? 0 : seqstart_k_with_padding_host.size() * sizeof(int32_t));
ck_tile::DeviceMem cu_seqlen_q_buf(cuq_cum.empty() ? 0
: cuq_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem cu_seqlen_kv_buf(
cukv_cum.empty() ? 0 : cukv_cum.size() * sizeof(ck_tile::index_t));
ck_tile::DeviceMem seqlen_k_buf((mode == mode_enum::batch && use_kvcache) ||
0 <= seqlen_kpads[0]
? seqlen_ks.size() * sizeof(int32_t)
: 0);
ck_tile::DeviceMem cache_seqlen_k_buf(
need_append_kvcache ? cache_seqlen_ks.size() * sizeof(int32_t) : 0);
ck_tile::DeviceMem rotary_cos_buf(rotary_cos_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem rotary_sin_buf(rotary_sin_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem drop_seed_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem drop_offset_buf(drop_prefs ? sizeof(uint64_t) : 0);
ck_tile::DeviceMem randval_buf(randval_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem alibi_slope_buf(alibi_slope_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem q_scale_buf(q_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem k_scale_buf(k_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem v_scale_buf(v_scale.get_element_space_size_in_bytes());
q_buf.ToDevice(q_host.data());
k_buf.ToDevice(k_host.data());
v_buf.ToDevice(v_host.data());
@@ -921,6 +934,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
block_table_buf.ToDevice(block_table_host.data());
cache_batch_idx_buf.ToDevice(cache_batch_idx_host.data());
q_scale_buf.ToDevice(q_scale.data());
k_scale_buf.ToDevice(k_scale.data());
v_scale_buf.ToDevice(v_scale.data());
if(quant == 2)
{
//dequant data for host