mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
start group
This commit is contained in:
@@ -451,9 +451,11 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
std::size_t flop = 0, num_byte = 0;
|
||||
auto max_seqlen_q =
|
||||
std::numeric_limits<int32_t>::min(); // we will use max seqlen to decide grid size
|
||||
size_t num_block_scale_q = 0;
|
||||
size_t num_block_scale_k = 0;
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
size_t num_block_scale_q = 0;
|
||||
size_t num_block_scale_k = 0;
|
||||
std::vector<int32_t> bseqstart_q_host = {0};
|
||||
std::vector<int32_t> bseqstart_k_host = {0};
|
||||
auto max_seqlen_k = std::numeric_limits<int32_t>::min();
|
||||
{
|
||||
for(ck_tile::index_t wb = 0; wb < batch; ++wb)
|
||||
{
|
||||
@@ -471,6 +473,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
}
|
||||
num_block_scale_q += ck_tile::integer_divide_ceil(real_seqlen_q, block_scale_m_);
|
||||
num_block_scale_k += ck_tile::integer_divide_ceil(real_seqlen_k, block_scale_n_);
|
||||
bseqstart_q_host.push_back(num_block_scale_q);
|
||||
bseqstart_k_host.push_back(num_block_scale_k);
|
||||
|
||||
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
|
||||
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
|
||||
@@ -481,6 +485,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
sizeof(VDataType) * hdim_v * real_seqlen_k);
|
||||
}
|
||||
}
|
||||
// std::cout << "bseqstart_q_host: " << bseqstart_q_host
|
||||
// << "bseqstart_k_host: " << bseqstart_k_host << std::endl;
|
||||
|
||||
const ck_tile::index_t max_num_page_blocks =
|
||||
(0 < page_block_size
|
||||
@@ -720,6 +726,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(q_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(k_descale_host);
|
||||
ck_tile::FillUniformDistribution<float>{0.012f, 0.015f, next_seed()}(v_descale_host);
|
||||
// return fwd_result::no_instance;
|
||||
}
|
||||
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0, random_engine);
|
||||
@@ -1611,6 +1618,10 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
// reference
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t q_offset =
|
||||
(mode == mode_enum::batch) ? 0 : bseqstart_q_host[wb];
|
||||
const ck_tile::index_t k_offset =
|
||||
(mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb];
|
||||
ck_tile::reference_batched_quant_gemm<QDataType,
|
||||
KDataType,
|
||||
SaccDataType,
|
||||
@@ -1622,10 +1633,12 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
ck_tile::idx_identity{},
|
||||
[&](auto idx, auto value) {
|
||||
return value * scale_s *
|
||||
q_descale_host(
|
||||
b_idx, std::get<0>(idx), std::get<1>(idx) / block_scale_m_) *
|
||||
k_descale_host(
|
||||
b_idx, std::get<0>(idx) / nr, std::get<2>(idx) / block_scale_n_);
|
||||
q_descale_host(b_idx,
|
||||
std::get<0>(idx),
|
||||
q_offset + std::get<1>(idx) / block_scale_m_) *
|
||||
k_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
k_offset + std::get<2>(idx) / block_scale_n_);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -1798,6 +1811,8 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
|
||||
if(qscale.type == quant_scale_enum::blockscale)
|
||||
{
|
||||
const ck_tile::index_t v_offset =
|
||||
(mode == mode_enum::batch) ? 0 : bseqstart_k_host[wb];
|
||||
ck_tile::
|
||||
reference_batched_quant_gemm<PDataType, VDataType, OaccDataType, ODataType>(
|
||||
p_host_ref,
|
||||
@@ -1808,7 +1823,7 @@ fwd_result fmha_fwd_run(mode_enum mode,
|
||||
return ck_tile::type_convert<float>(value) *
|
||||
v_descale_host(b_idx,
|
||||
std::get<0>(idx) / nr,
|
||||
std::get<2>(idx) / block_scale_n_);
|
||||
v_offset + std::get<2>(idx) / block_scale_n_);
|
||||
},
|
||||
ck_tile::idx_identity{});
|
||||
}
|
||||
|
||||
@@ -1294,6 +1294,12 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_randval = query_start * kargs.stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
batch_offset_q_descale = query_start/128;
|
||||
batch_offset_k_descale = key_start/128;
|
||||
batch_offset_v_descale = key_start/128;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
|
||||
Reference in New Issue
Block a user