fix k origin

This commit is contained in:
coderfeli
2025-04-07 10:04:22 +00:00
parent 57c9d84eb1
commit 4e644a33ab
4 changed files with 23 additions and 14 deletions

View File

@@ -480,9 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
std::vector<int32_t> page_idx_host(seqstart_q_host.back(), 0);
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
// iota_shuffle(page_idx_host.begin(), page_idx_host.end(), 0);
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
page_idx_host.savetxt("page_idx_host.txt", "int");
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
// page_idx_host(i) = (i + 19) % page_idx_host.size();
// }
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
@@ -605,7 +610,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
printf("shape %d %d %d %d\n", shape_batch, nhead_k, shape_seqlen_k, seqstart_q_host.back());
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_q_host.back(), nhead_k, hdim_q});
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
ck_tile::HostTensor<KDataType> knew_host(
@@ -748,10 +753,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
}
}
k_host_sgl.ForEach([&](auto& self, auto i) {
self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
k_host.ForEach([&](auto& self, auto i) {
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
// self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
});
// k_host.savetxt("k_host.txt");
// k_host_sgl.savetxt("k_host_sgl.txt");
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
@@ -1185,7 +1192,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
o_buf.FromDevice(o_host.data()); // TODO: ugly
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
bool pass_ = ck_tile::check_err(
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
@@ -1527,6 +1533,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
// clang-format on
// o_host_result.savetxt("o_host_result.txt");
// o_host_ref.savetxt("o_host_ref.txt");
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
bool cur_pass = ck_tile::check_err(
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);

View File

@@ -188,10 +188,10 @@ struct page_tile_with_static_distribution
array<index_t, NDimY>{0}));
#endif
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp ={0,
window_origin[1] + window_adaptor_thread_coord_tmp.get_bottom_index()[1]};
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
// window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
window_origin + tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
// tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

View File

@@ -988,6 +988,8 @@ struct FmhaFwdKernel
{
batch_offset_v = key_start;
}
kargs.page_idx += key_start;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
batch_offset_bias = query_start * kargs.stride_bias + key_start;

View File

@@ -243,11 +243,10 @@ struct BlockFmhaPipelineQRKSVS
return o_acc;
}
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
{seqlen_k_start, 0}); //todo fixme felix
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
@@ -283,7 +282,7 @@ struct BlockFmhaPipelineQRKSVS
statically_indexed_array<index_t, NR> offsets;
static_for<0, NR, 1>{}([&](auto n0) {
offsets[n0] = page_idx[c_coord[0] + 64 * n0.value] * 128; // Problem::kN_;
offsets[n0] = page_idx[i_total_loops * kN0 + c_coord[0] + 64 * n0.value] * kQKHeaddim; // Problem::kN_;
});
auto k_dram_window = make_tile_window_paged(
k_dram_block_window.get_bottom_tensor_view(),