mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Undo removing necessary value-overwrite logic
This commit is contained in:
@@ -90,8 +90,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
|
||||
auto lse_acc_lds_for_write = make_tensor_view<address_space_enum::lds>(
|
||||
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor<Problem>());
|
||||
#if 0
|
||||
auto lse_acc_lds_for_read = make_tensor_view<address_space_enum::lds>(
|
||||
lse_acc_lds_ptr, Policy::template MakeLSEaccTLdsBlockDescriptor<Problem>());
|
||||
#endif
|
||||
auto lse_acc_lds_write_window = make_tile_window(
|
||||
lse_acc_lds_for_write, make_tuple(number<kMaxSplits>{}, number<kM0>{}), {0, 0});
|
||||
#if 0
|
||||
auto lse_acc_lds_read_window = make_tile_window(
|
||||
lse_acc_lds_for_read, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
|
||||
#endif
|
||||
auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
|
||||
@@ -101,11 +109,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_acc_dram_block_window_tmp.get_window_origin(),
|
||||
lse_acc_dist);
|
||||
|
||||
// copy lse_acc to LDS
|
||||
auto lse_acc = load_tile(lse_acc_dram_window);
|
||||
store_tile(lse_acc_lds_write_window, lse_acc);
|
||||
// use LDS to transpose lse_accum from [kMaxSplits, kM0] to [kM0, kMaxSplits]
|
||||
auto lse_acc_tile = load_tile(lse_acc_dram_window);
|
||||
store_tile(lse_acc_lds_write_window, lse_acc_tile);
|
||||
block_sync_lds();
|
||||
|
||||
#if 0
|
||||
auto lse_accum = load_tile(lse_acc_lds_read_window,
|
||||
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
|
||||
#else
|
||||
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
|
||||
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
|
||||
|
||||
@@ -127,10 +139,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
auto offset =
|
||||
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
|
||||
if(col < num_splits)
|
||||
{
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
#endif
|
||||
|
||||
// calculate row_max of lse_accum
|
||||
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
|
||||
@@ -141,8 +161,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
/// NOTICE: bias might be materialized mask including -inf values, need
|
||||
/// consideration
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user