Start group-mode support for block-scale descale offsets

This commit is contained in:
ltqin
2025-11-21 02:57:07 +00:00
parent bd5135a83a
commit a7048cf4d4
2 changed files with 29 additions and 8 deletions

View File

@@ -451,9 +451,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::size_t flop = 0, num_byte = 0;
auto max_seqlen_q =
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
size_t num_block_scale_q = 0;
size_t num_block_scale_k = 0;
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
size_t num_block_scale_q = 0;
size_t num_block_scale_k = 0;
std::vector<int32_t> bseqstart_q_host = {0};
std::vector<int32_t> bseqstart_k_host = {0};
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
{
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
{
@@ -471,6 +473,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
}
num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_);
num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_);
bseqstart_q_host.push_back(num_block_scale_q);
bseqstart_k_host.push_back(num_block_scale_k);
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
@@ -481,6 +485,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
sizeof(VDataType) * hdim_v * real_seqlen_k);
}
}
// std::cout << "bseqstart_q_host: " << bseqstart_q_host
// << "bseqstart_k_host: " << bseqstart_k_host << std::endl;
const ck_tile::index_t max_num_page_blocks =
(0 < page_block_size
@@ -720,6 +726,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
// return fwd_result::no_instance;
}
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
@@ -1611,6 +1618,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
// reference
if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t q_offset =
(mode == mode_enum::batch) ? 0 : bseqstart_q_host[wb];
const ck_tile::index_t k_offset =
(mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb];
ck_tile::reference_batched_quant_gemm<QDataType,
KDataType,
SaccDataType,
@@ -1622,10 +1633,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
ck_tile::idx_identity{},
[&](auto idx, auto value) {
return value * scale_s *
q_descale_host(
b_idx, std::get<0>(idx), std::get<1>(idx) / block_scale_m_) *
k_descale_host(
b_idx, std::get<0>(idx) / nr, std::get<2>(idx) / block_scale_n_);
q_descale_host(b_idx,
std::get<0>(idx),
q_offset + std::get<1>(idx) / block_scale_m_) *
k_descale_host(b_idx,
std::get<0>(idx) / nr,
k_offset + std::get<2>(idx) / block_scale_n_);
});
}
else
@@ -1798,6 +1811,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
if(qscale.type == quant_scale_enum::blockscale)
{
const ck_tile::index_t v_offset =
(mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb];
ck_tile::
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
p_host_ref,
@@ -1808,7 +1823,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
return ck_tile::type_convert<float>(value) *
v_descale_host(b_idx,
std::get<0>(idx) / nr,
std::get<2>(idx) / block_scale_n_);
v_offset + std::get<2>(idx) / block_scale_n_);
},
ck_tile::idx_identity{});
}

View File

@@ -1294,6 +1294,12 @@ struct FmhaFwdKernel
{
batch_offset_randval = query_start * kargs.stride_randval;
}
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
{
batch_offset_q_descale = query_start/128;
batch_offset_k_descale = key_start/128;
batch_offset_v_descale = key_start/128;
}
batch_offset_o = query_start * kargs.stride_o;
// real logical lengths (exclude PAD)