mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
fix k origin
This commit is contained in:
@@ -480,9 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
std::vector<int32_t> page_idx_host(seqstart_q_host.back(), 0);
|
||||
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
// iota_shuffle(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
|
||||
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
|
||||
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
|
||||
page_idx_host.savetxt("page_idx_host.txt", "int");
|
||||
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
|
||||
// page_idx_host(i) = (i + 19) % page_idx_host.size();
|
||||
// }
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
@@ -605,7 +610,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
printf("shape %d %d %d %d\n", shape_batch, nhead_k, shape_seqlen_k, seqstart_q_host.back());
|
||||
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_q_host.back(), nhead_k, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
|
||||
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
@@ -748,10 +753,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
}
|
||||
k_host_sgl.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
|
||||
|
||||
k_host.ForEach([&](auto& self, auto i) {
|
||||
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
|
||||
// self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
|
||||
});
|
||||
// k_host.savetxt("k_host.txt");
|
||||
// k_host_sgl.savetxt("k_host_sgl.txt");
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
@@ -1185,7 +1192,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
@@ -1527,6 +1533,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
// o_host_result.savetxt("o_host_result.txt");
|
||||
// o_host_ref.savetxt("o_host_ref.txt");
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
Reference in New Issue
Block a user