diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 6e9133806e..2389e763dd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -603,8 +603,8 @@ bool run(const ck_tile::ArgParser& arg_parser) : std::array{batch, nhead}) : std::array{1, 1}); - auto [rotary_cos_host, rotary_sin_host] = - generate_rotary_cos_sin(shape_seqlen_k, rotary_dim, seed); + auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin( + std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed); ck_tile::HostTensor lse_acc_host( 1 < num_splits || use_kvcache