From b48fe39bf5763c57ce74ed8807b18e5b2baa17d5 Mon Sep 17 00:00:00 2001 From: Cameron Shinn Date: Mon, 11 Aug 2025 22:44:01 -0700 Subject: [PATCH] Fix num_byte calculations to use nhead_k for K & V size (#2653) Simple fix to correctly calculate the number of bytes reported in the output. The K and V sizes were being multiplied by nhead instead of nhead_k, which overcounts the bytes moved whenever nhead_k is smaller than nhead (8 vs. 32 in the run below). I was getting 6200 GB/s, which is past the SoL (speed-of-light, i.e. peak) bandwidth of MI300. Before: ``` ./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1 [bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.173 ms, 6.20 TFlops, 6202.95 GB/s ``` After: ``` ./bin/tile_example_fmha_fwd -prec=bf16 -b=2 -s=1 -s_k=32768 -h=32 -h_k=8 -d=128 -page_block_size=128 -num_splits=8 -iperm=0 -operm=0 -v=0 -kname=1 [bf16|batch|bshd] b:2, h:32/8, s:1/32768, d:128/128, scale_s:0.0883883, bias:n, p_drop:0, lse:0, squant:0, mask:n, v:r, num_splits:8, page_block_size:128, fmha_fwd_splitkv_d128_bf16_batch_b16x64x64x128x64x128_r1x4x1_r1x4x1_w16x16x16_w16x16x16_qr_nwarp_sshuffle_vr_ps_nlogits_nbias_nmask_lse_nsquant_pagedkv, fmha_fwd_splitkv_combine_d128_bf16_batch_b32_unused_ps_nlse_nsquant, 0.163 ms, 6.58 TFlops, 1644.53 GB/s ``` [ROCm/composable_kernel commit: 352f87e6841f04c83a86eeab6c9718a99f7aad84] --- example/ck_tile/01_fmha/fmha_fwd.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index e9403f4698..48306e35fe 100755 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -525,10 +525,10 @@ bool run(const ck_tile::ArgParser& arg_parser) flop += nhead * (static_cast(2) 
* mask.get_unmaskarea() * hdim_q + static_cast(2) * mask.get_unmaskarea() * hdim_v); - num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + - sizeof(KDataType) * real_seqlen_k * hdim_q + - sizeof(VDataType) * hdim_v * real_seqlen_k + - sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q + + sizeof(ODataType) * real_seqlen_q * hdim_v); + num_byte += nhead_k * (sizeof(KDataType) * real_seqlen_k * hdim_q + + sizeof(VDataType) * hdim_v * real_seqlen_k); } }