diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 50eac35c3f..5bc6544746 100644 --- a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim}); ck_tile::HostTensor o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim}); + ck_tile::index_t seq_q_off = cu_query_lens[b]; // Copy effective region q_b.ForEach([&](auto& self, auto idx) { // idx: [0, s, h, d] - self(idx) = q(b, idx[1], idx[2], idx[3]); + self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]); }); k_b.ForEach([&](auto& self, auto idx) { // kv cache is paged @@ -516,7 +517,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) { for(int d = 0; d < problem.hdim; ++d) { - o_ref(b, s, h, d) = o_b(0, s, h, d); + o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d); } } }