diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp index 4e7f5f8ab7..386accfa5a 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp @@ -100,7 +100,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline auto lse_acc_lds_read_window = make_tile_window( lse_acc_lds_for_read, make_tuple(number{}, number{}), {0, 0}); #endif - auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor(); + auto lse_acc_lds_m0_ms_for_read = Policy::template MakeLSEaccTLdsBlockDescriptor(); auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution(); auto lse_acc_dram_window = @@ -137,8 +137,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); - auto offset = - lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); + auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col)); if(col < num_splits) { lse_accum(distributed_indices) = lse_acc_lds_ptr[offset]; @@ -180,8 +179,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); const auto col = x_indices.at(number<1>{}); - auto offset = - lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); + auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col)); if(col < num_splits) { // from shared memory @@ -232,8 +230,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline for(index_t col = 0; col < num_splits; ++col) { - auto offset = - lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row)); + auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col)); lse_acc_lds_ptr[offset] = ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices)); } @@ -286,7 +283,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline const auto row = x_indices.at(number<0>{}); auto offset = - lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row)); + lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split)); LSEDataType lse_scale = lse_acc_lds_ptr[offset]; o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);