Fix reference data copying: index q and o_ref by the per-batch offset cu_query_lens[b]

This commit is contained in:
Tianxing Wu
2025-11-20 11:34:39 +00:00
parent de995fea71
commit f552cd7841

View File

@@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
{
ck_tile::ArgParser arg_parser;
arg_parser
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("prec", "bf16", "data type. fp16/bf16")
// .insert("b", "3", "batch size")
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
// .insert("h_k",
@@ -475,11 +475,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
ck_tile::index_t seq_q_off = cu_query_lens[b];
// Copy effective region
q_b.ForEach([&](auto& self, auto idx) {
// idx: [0, s, h, d]
self(idx) = q(b, idx[1], idx[2], idx[3]);
self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]);
});
k_b.ForEach([&](auto& self, auto idx) {
// kv cache is paged
@@ -516,7 +517,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
{
for(int d = 0; d < problem.hdim; ++d)
{
o_ref(b, s, h, d) = o_b(0, s, h, d);
o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d);
}
}
}