diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 0a35322cdf..aa22d8fd0d 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -480,9 +480,14 @@ bool run(const ck_tile::ArgParser& arg_parser) const auto seqstart_q_host = to_seqstarts(seqlen_qs); const auto seqstart_k_host = to_seqstarts(seqlen_ks); const auto seqstart_k_with_padding_host = to_seqstarts(seqlen_kpads); - std::vector page_idx_host(seqstart_q_host.back(), 0); - std::iota(page_idx_host.begin(), page_idx_host.end(), 0); - // iota_shuffle(page_idx_host.begin(), page_idx_host.end(), 0); + // std::vector page_idx_host(seqstart_k_host.back(), 0); + ck_tile::HostTensor page_idx_host({seqstart_k_host.back()}); + // std::iota(page_idx_host.begin(), page_idx_host.end(), 0); + iota_shuffle(page_idx_host.mData.begin(), page_idx_host.mData.end(), 0); + page_idx_host.savetxt("page_idx_host.txt", "int"); + // for (int i = 0; i < page_idx_host.get_element_space_size(); i++) { + // page_idx_host(i) = (i + 19) % page_idx_host.size(); + // } using TypeConfig = FmhaFwdTypeConfig; @@ -605,7 +610,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ? get_lengths(i_perm, max_num_page_blocks, nhead_k, page_block_size, hdim_q) : get_lengths(i_perm, shape_batch, nhead_k, shape_seqlen_k, hdim_q)); printf("shape %d %d %d %d\n", shape_batch, nhead_k, shape_seqlen_k, seqstart_q_host.back()); - ck_tile::HostTensor k_host_sgl({seqstart_q_host.back(), nhead_k, hdim_q}); + ck_tile::HostTensor k_host_sgl({seqstart_k_host.back(), nhead_k, hdim_q}); /// NOTICE: always use same shape for knew_host & vnew_host in batch/group mode ck_tile::HostTensor knew_host( @@ -748,10 +753,12 @@ bool run(const ck_tile::ArgParser& arg_parser) } } } - k_host_sgl.ForEach([&](auto& self, auto i) { - self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]); - + k_host.ForEach([&](auto& self, auto i) { + k_host_sgl(page_idx_host(i[1]), i[2], i[3]) = self(i); + // self(i) = k_host(0, page_idx_host[i[0]], i[1], i[2]); }); + // k_host.savetxt("k_host.txt"); + // k_host_sgl.savetxt("k_host_sgl.txt"); iota_shuffle(block_table_host.begin(), block_table_host.end(), 0); iota_shuffle(cache_batch_idx_host.begin(), cache_batch_idx_host.end(), 0); ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); @@ -1185,7 +1192,6 @@ bool run(const ck_tile::ArgParser& arg_parser) auto o_naive_ref = o_naive_buf.ToHost(); o_buf.FromDevice(o_host.data()); // TODO: ugly - auto [rtol_, atol_] = get_elimit(init_method); bool pass_ = ck_tile::check_err( o_host, o_naive_ref, std::string("OUT Error: Incorrect results!"), rtol_, atol_); @@ -1527,6 +1533,8 @@ bool run(const ck_tile::ArgParser& arg_parser) else o_host_result.ForEach([&](auto& self, auto idx) { self(idx) = o_host(b_idx, idx[1] + query_offset, idx[0], idx[2]); }); // clang-format on + // o_host_result.savetxt("o_host_result.txt"); + // o_host_ref.savetxt("o_host_ref.txt"); auto [rtol, atol] = get_elimit(init_method); bool cur_pass = ck_tile::check_err( o_host_result, o_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); diff --git a/include/ck_tile/core/tensor/tile_window_paged.hpp b/include/ck_tile/core/tensor/tile_window_paged.hpp index 5cf77a1ecd..bf80105cc4 100644 --- a/include/ck_tile/core/tensor/tile_window_paged.hpp +++ b/include/ck_tile/core/tensor/tile_window_paged.hpp @@ -188,10 +188,10 @@ struct page_tile_with_static_distribution array{0})); #endif + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp ={0, + window_origin[1] + window_adaptor_thread_coord_tmp.get_bottom_index()[1]}; // BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - // window_origin + window_adaptor_thread_coord_tmp.get_bottom_index(); - BottomTensorIndex bottom_tensor_thread_origin_idx_tmp = - window_origin + tuple(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]); + // tuple(0, window_adaptor_thread_coord_tmp.get_bottom_index()[1]); const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate( bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 2ea97f6e31..f2518b434b 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -988,6 +988,8 @@ struct FmhaFwdKernel { batch_offset_v = key_start; } + + kargs.page_idx += key_start; if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias + key_start; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp index 42a69c019e..aa414d35c1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp @@ -243,11 +243,10 @@ struct BlockFmhaPipelineQRKSVS return o_acc; } } - auto k_dram_block_window = make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), k_dram_block_window_tmp.get_window_lengths(), - {seqlen_k_start, 0}); + {seqlen_k_start, 0}); //todo fixme felix const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); auto bias_dram_window = @@ -283,7 +282,7 @@ struct BlockFmhaPipelineQRKSVS statically_indexed_array offsets; static_for<0, NR, 1>{}([&](auto n0) { - offsets[n0] = page_idx[c_coord[0] + 64 * n0.value] * 128; // Problem::kN_; + offsets[n0] = page_idx[i_total_loops * kN0 + c_coord[0] + 64 * n0.value] * kQKHeaddim; // Problem::kN_; }); auto k_dram_window = make_tile_window_paged( k_dram_block_window.get_bottom_tensor_view(),