Undo removing necessary value-overwrite logic

This commit is contained in:
PoYen, Chen
2024-06-12 04:21:31 +00:00
parent e1b4ac293e
commit fcf5cd5e57

View File

@@ -90,8 +90,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline
static_cast<LSEDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto lse_acc_lds_for_write = make_tensor_view<address_space_enum::lds>(
lse_acc_lds_ptr, Policy::template MakeLSEaccLdsBlockDescriptor<Problem>());
#if 0
auto lse_acc_lds_for_read = make_tensor_view<address_space_enum::lds>(
lse_acc_lds_ptr, Policy::template MakeLSEaccTLdsBlockDescriptor<Problem>());
#endif
auto lse_acc_lds_write_window = make_tile_window(
lse_acc_lds_for_write, make_tuple(number<kMaxSplits>{}, number<kM0>{}), {0, 0});
#if 0
auto lse_acc_lds_read_window = make_tile_window(
lse_acc_lds_for_read, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
#endif
auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
@@ -101,11 +109,15 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_acc_dram_block_window_tmp.get_window_origin(),
lse_acc_dist);
// copy lse_acc to LDS
auto lse_acc = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc);
// use LDS to transpose lse_accum from [kMaxSplits, kM0] to [kM0, kMaxSplits]
auto lse_acc_tile = load_tile(lse_acc_dram_window);
store_tile(lse_acc_lds_write_window, lse_acc_tile);
block_sync_lds();
#if 0
auto lse_accum = load_tile(lse_acc_lds_read_window,
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
#else
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
@@ -127,10 +139,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
if(col < num_splits)
{
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
}
else
{
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
}
});
});
}
#endif
// calculate row_max of lse_accum
const auto f_max = [](auto e0, auto e1) { return ck_tile::max(e0, e1); };
@@ -141,8 +161,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
static const auto get_validated_m = [](LSEDataType raw_m) {
/// NOTICE: bias might be materialized mask including -inf values, need
/// consideration
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
: raw_m;
};