mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
fix k origin
This commit is contained in:
@@ -480,9 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const auto seqstart_q_host = to_seqstarts(seqlen_qs);
|
||||
const auto seqstart_k_host = to_seqstarts(seqlen_ks);
|
||||
const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads);
|
||||
std::vector<int32_t> page_idx_host(seqstart_q_host.back(), 0);
|
||||
std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
// iota_shuffle(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
// std::vector<int32_t> page_idx_host(seqstart_k_host.back(), 0);
|
||||
ck_tile::HostTensor<int32_t> page_idx_host({seqstart_k_host.back()});
|
||||
// std::iota(page_idx_host.begin(), page_idx_host.end(), 0);
|
||||
iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0);
|
||||
page_idx_host.savetxt("page_idx_host.txt", "int");
|
||||
// for (int i = 0; i < page_idx_host.get_element_space_size(); i++) {
|
||||
// page_idx_host(i) = (i + 19) % page_idx_host.size();
|
||||
// }
|
||||
|
||||
using TypeConfig = FmhaFwdTypeConfig<DataTypeConfig>;
|
||||
|
||||
@@ -605,7 +610,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q)
|
||||
: get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q));
|
||||
printf("shape %d %d %d %d\n", shape_batch, nhead_k, shape_seqlen_k, seqstart_q_host.back());
|
||||
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_q_host.back(), nhead_k, hdim_q});
|
||||
ck_tile::HostTensor<KDataType> k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q});
|
||||
|
||||
/// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode
|
||||
ck_tile::HostTensor<KDataType> knew_host(
|
||||
@@ -748,10 +753,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
}
|
||||
k_host_sgl.ForEach([&](auto& self, auto i) {
|
||||
self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
|
||||
|
||||
k_host.ForEach([&](auto& self, auto i) {
|
||||
k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i);
|
||||
// self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]);
|
||||
});
|
||||
// k_host.savetxt("k_host.txt");
|
||||
// k_host_sgl.savetxt("k_host_sgl.txt");
|
||||
iota_shuffle(block_table_host.begin(), block_table_host.end(), 0);
|
||||
iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0);
|
||||
ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes());
|
||||
@@ -1185,7 +1192,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
auto o_naive_ref = o_naive_buf.ToHost<ODataType>();
|
||||
o_buf.FromDevice(o_host.data()); // TODO: ugly
|
||||
|
||||
auto [rtol_, atol_] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool pass_ = ck_tile::check_err(
|
||||
o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_);
|
||||
@@ -1527,6 +1533,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); });
|
||||
// clang-format on
|
||||
|
||||
// o_host_result.savetxt("o_host_result.txt");
|
||||
// o_host_ref.savetxt("o_host_ref.txt");
|
||||
auto [rtol, atol] = get_elimit<DataTypeConfig>(init_method);
|
||||
bool cur_pass = ck_tile::check_err(
|
||||
o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
@@ -188,10 +188,10 @@ struct page_tile_with_static_distribution
|
||||
array<index_t, NDimY>{0}));
|
||||
#endif
|
||||
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp ={0,
|
||||
window_origin[1] + window_adaptor_thread_coord_tmp.get_bottom_index()[1]};
|
||||
// BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
// window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
|
||||
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
|
||||
window_origin + tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
|
||||
// tuple<index_t, index_t>(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]);
|
||||
const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
|
||||
bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
|
||||
|
||||
|
||||
@@ -988,6 +988,8 @@ struct FmhaFwdKernel
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
|
||||
kargs.page_idx += key_start;
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
|
||||
@@ -243,11 +243,10 @@ struct BlockFmhaPipelineQRKSVS
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
{seqlen_k_start, 0}); //todo fixme felix
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
@@ -283,7 +282,7 @@ struct BlockFmhaPipelineQRKSVS
|
||||
statically_indexed_array<index_t, NR> offsets;
|
||||
|
||||
static_for<0, NR, 1>{}([&](auto n0) {
|
||||
offsets[n0] = page_idx[c_coord[0] + 64 * n0.value] * 128; // Problem::kN_;
|
||||
offsets[n0] = page_idx[i_total_loops * kN0 + c_coord[0] + 64 * n0.value] * kQKHeaddim; // Problem::kN_;
|
||||
});
|
||||
auto k_dram_window = make_tile_window_paged(
|
||||
k_dram_block_window.get_bottom_tensor_view(),
|
||||
|
||||
Reference in New Issue
Block a user