diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index a45a0a7174..77f236422a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -383,9 +383,9 @@ bool run(const ck_tile::ArgParser& arg_parser) } const ck_tile::index_t rotary_dim = arg_parser.get_int("rotary_dim"); - if(!(rotary_dim < hdim_q) || !(rotary_dim < hdim_v)) + if(!(rotary_dim <= hdim_q)) { - std::cerr << "rotary_dim should be less than head dim for q/q" << std::endl; + std::cerr << "rotary_dim should be less than or equal to head dim for q" << std::endl; return false; } @@ -1053,7 +1053,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // copy Knew to the end of K if(0 < seqlen_knew) { - ck_tile::HostTensor knew_host_ref({nhead, real_seqlen_k, hdim_q}); + ck_tile::HostTensor knew_host_ref({nhead, seqlen_knew, hdim_q}); if(i_perm) knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[0] / nr, i[1], i[2]); }); else knew_host_ref.ForEach([&](auto& self, auto i) { self(i) = knew_host(b, i[1], i[0] / nr, i[2]); });