diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index f2e966597c..01ec88b945 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1106,9 +1106,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::reference_batched_rotary_position_embedding( q_host_ref, rotary_cos_host, rotary_sin_host, is_rotary_interleaved, q_host_ref_ro); - #if 0 q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); }); - #endif } #if 0 HOST_DEBUG_STMTS { diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp index 7c3447ae2e..4a19116a0e 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp @@ -494,7 +494,7 @@ struct FmhaFwdAppendKVKernel const auto rotary_cos_dram_native = make_naive_tensor_view( reinterpret_cast(kargs.rotary_cos_ptr), - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), make_tuple(kargs.rotary_dim / 2, 1), number<8>{}, number<1>{}); @@ -519,7 +519,7 @@ struct FmhaFwdAppendKVKernel const auto rotary_sin_dram_native = make_naive_tensor_view( reinterpret_cast(kargs.rotary_sin_ptr), - make_tuple(kargs.seqlen_k + kargs.seqlen_knew, kargs.rotary_dim / 2), + make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2), make_tuple(kargs.rotary_dim / 2, 1), number<8>{}, number<1>{}); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index a6e012df1e..efd13ed503 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -108,10 +108,6 @@ struct BlockFmhaFwdAppendKVPipeline { #if defined(ENABLE_DEVICE_DEBUG_STMTS) auto* const ksmem = reinterpret_cast(smem_ptr); - if(threadIdx.x == 0) - { - printf("\n"); - } #endif auto print_tile = [&](const auto& tile, index_t num_display_rows = -1) { @@ -140,6 +136,7 @@ struct BlockFmhaFwdAppendKVPipeline DEVICE_DEBUG_STMTS { + printf("\n"); for(int row = 0; row < (0 < num_display_rows ? std::min(num_display_rows, num_rows) : num_rows); ++row) @@ -195,7 +192,7 @@ struct BlockFmhaFwdAppendKVPipeline rotary_dim, thread_end); } - print_tile(knew_tile, 2); + // print_tile(knew_tile, 2); store_tile(k_dram_block_window, knew_tile); auto vnew_window = make_tile_window(