diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index e21cc0bcf2..4e7f5f8ab7 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -90,8 +90,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline static_cast(static_cast(static_cast(smem_ptr))); auto lse_acc_lds_for_write = make_tensor_view( lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor()); +#if 0 + auto lse_acc_lds_for_read = make_tensor_view( + lse_acc_lds_ptr, Policy::template MakeLSEaccTLdsBlockDescriptor()); +#endif auto lse_acc_lds_write_window = make_tile_window( lse_acc_lds_for_write, make_tuple(number{}, number{}), {0, 0}); +#if 0 + auto lse_acc_lds_read_window = make_tile_window( + lse_acc_lds_for_read, make_tuple(number{}, number{}), {0, 0}); +#endif auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor(); auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution(); @@ -101,11 +109,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline lse_acc_dram_block_window_tmp.get_window_origin(), lse_acc_dist); - // copy lse_acc to LDS - auto lse_acc = load_tile(lse_acc_dram_window); - store_tile(lse_acc_lds_write_window, lse_acc); + // use LDS to transpose lse_accum from [kMaxSplits, kM0] to [kM0, kMaxSplits] + auto lse_acc_tile = load_tile(lse_acc_dram_window); + store_tile(lse_acc_lds_write_window, lse_acc_tile); block_sync_lds(); +#if 0 + auto lse_accum = load_tile(lse_acc_lds_read_window, + Policy::template MakeLSEaccTDramTileDistribution()); +#else auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution(); auto lse_accum = make_static_distributed_tensor(lse_accum_dist); @@ -127,10 +139,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto offset = lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); - lse_accum(distributed_indices) = lse_acc_lds_ptr[offset]; + if(col < num_splits) + { + lse_accum(distributed_indices) = lse_acc_lds_ptr[offset]; + } + else + { + lse_accum(distributed_indices) = -numeric::infinity(); + } }); }); } +#endif // calculate row_max of lse_accum const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); }; @@ -141,8 +161,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline block_tile_reduce_sync(lse_max, f_max, bool_constant{}); static const auto get_validated_m = [](LSEDataType raw_m) { - /// NOTICE: bias might be materialized mask including -inf values, need - /// consideration return raw_m == -numeric::infinity() ? type_convert(0.f) : raw_m; };