mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Use the read descriptor to locate LDS elements
This commit is contained in:
@@ -100,7 +100,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto lse_acc_lds_read_window = make_tile_window(
|
||||
lse_acc_lds_for_read, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
|
||||
#endif
|
||||
-auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
+auto lse_acc_lds_m0_ms_for_read = Policy::template MakeLSEaccTLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
|
||||
auto lse_acc_dram_window =
|
||||
@@ -137,8 +137,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
-auto offset =
-lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
+auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
|
||||
if(col < num_splits)
|
||||
{
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
|
||||
@@ -180,8 +179,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
|
||||
-auto offset =
-lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
+auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
|
||||
if(col < num_splits)
|
||||
{
|
||||
// from shared memory
|
||||
@@ -232,8 +230,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
for(index_t col = 0; col < num_splits; ++col)
|
||||
{
|
||||
-auto offset =
-lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
+auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
|
||||
lse_acc_lds_ptr[offset] =
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
|
||||
}
|
||||
@@ -286,7 +283,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
auto offset =
|
||||
-lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row));
+lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split));
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
|
||||
Reference in New Issue
Block a user