Fix the integer overflow in total_flops calculation

This commit is contained in:
Qianfeng Zhang
2025-04-17 10:34:13 +00:00
parent 6086ead2f9
commit b0ae27046f

View File

@@ -285,14 +285,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(is_jagged)
{
for(auto len : seq_lengths)
total_flops += (static_cast<long>(len) * len * hdim_qk + len * hdim_v * len) * 2;
total_flops +=
(static_cast<long>(len) * len * hdim_qk + static_cast<long>(len) * hdim_v * len) *
2;
total_flops *= num_head;
}
else
{
total_flops = static_cast<long>(num_batch) * num_head *
(seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen) * 2;
(static_cast<long>(seqlen) * seqlen * hdim_qk +
static_cast<long>(seqlen) * hdim_v * seqlen) *
2;
};
int batches_for_alloc = is_jagged ? 1 : num_batch;
@@ -402,7 +406,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
params.philox_offset = 0UL;
};
show_hstu_attention_fwd_param(std::cout, params);
// show_hstu_attention_fwd_param(std::cout, params);
std::ignore = show_hstu_attention_fwd_param;
hipStream_t stream;