mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
ref data copying
This commit is contained in:
@@ -32,7 +32,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParse
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
// .insert("b", "3", "batch size")
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is " + std::to_string(num_queries_per_kv) + " times this")
|
||||
// .insert("h_k",
|
||||
@@ -475,11 +475,12 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::index_t seq_q_off = cu_query_lens[b];
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(b, idx[1], idx[2], idx[3]);
|
||||
self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) {
|
||||
// kv cache is paged
|
||||
@@ -516,7 +517,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(b, s, h, d) = o_b(0, s, h, d);
|
||||
o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user