mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Fix the integer overflow in total_flops calculation
This commit is contained in:
@@ -285,14 +285,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(is_jagged)
|
||||
{
|
||||
for(auto len : seq_lengths)
|
||||
total_flops += (static_cast<long>(len) * len * hdim_qk + len * hdim_v * len) * 2;
|
||||
total_flops +=
|
||||
(static_cast<long>(len) * len * hdim_qk + static_cast<long>(len) * hdim_v * len) *
|
||||
2;
|
||||
|
||||
total_flops *= num_head;
|
||||
}
|
||||
else
|
||||
{
|
||||
total_flops = static_cast<long>(num_batch) * num_head *
|
||||
(seqlen * seqlen * hdim_qk + seqlen * hdim_v * seqlen) * 2;
|
||||
(static_cast<long>(seqlen) * seqlen * hdim_qk +
|
||||
static_cast<long>(seqlen) * hdim_v * seqlen) *
|
||||
2;
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
@@ -402,7 +406,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
params.philox_offset = 0UL;
|
||||
};
|
||||
|
||||
show_hstu_attention_fwd_param(std::cout, params);
|
||||
// show_hstu_attention_fwd_param(std::cout, params);
|
||||
std::ignore = show_hstu_attention_fwd_param;
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user