fix k origin

This commit is contained in:
coderfeli
2025-04-07 10:04:22 +00:00
parent 57c9d84eb1
commit 4e644a33ab
4 changed files with 23 additions and 14 deletions

View File

@@ -480,9 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
std::vector<int32_t> page_idx_host(seqstart_q_host.back(), 0);
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
// iota_shuffle(page_idx_host.begin(), page_idx_host.end(), 0);
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
page_idx_host.savetxt("page_idx_host.txt", "int");
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
// page_idx_host(i) = (i + 19) % page_idx_host.size();
// }
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
@@ -605,7 +610,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
printf("shape %d %d %d %d\n", shape_batch, nhead_k, shape_seqlen_k, seqstart_q_host.back());
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_q_host.back(), nhead_k, hdim_q});
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
@@ -748,10 +753,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
}
k_host_sgl.ForEach([&](auto& self, auto i) {
self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
k_host.ForEach([&](auto& self, auto i) {
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
// self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
});
// k_host.savetxt("k_host.txt");
// k_host_sgl.savetxt("k_host_sgl.txt");
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
@@ -1185,7 +1192,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
@@ -1527,6 +1533,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
// o_host_result.savetxt("o_host_result.txt");
// o_host_ref.savetxt("o_host_ref.txt");
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);