Use the read descriptor to locate LDS elements

This commit is contained in:
PoYen, Chen
2024-06-12 04:31:33 +00:00
parent fcf5cd5e57
commit ff61463cab

View File

@@ -100,7 +100,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_acc_lds_read_window = make_tile_window(
lse_acc_lds_for_read, make_tuple(number<kM0>{}, number<kMaxSplits>{}), {0, 0});
#endif
auto lse_acc_lds_ms_m0_for_write = Policy::template MakeLSEaccLdsBlockDescriptor<Problem>();
auto lse_acc_lds_m0_ms_for_read = Policy::template MakeLSEaccTLdsBlockDescriptor<Problem>();
auto lse_acc_dist = Policy::template MakeLSEaccDramTileDistribution<Problem>();
auto lse_acc_dram_window =
@@ -137,8 +137,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
if(col < num_splits)
{
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
@@ -180,8 +179,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
if(col < num_splits)
{
// from shared memory
@@ -232,8 +230,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
for(index_t col = 0; col < num_splits; ++col)
{
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(col, row));
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
lse_acc_lds_ptr[offset] =
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
}
@@ -286,7 +283,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const auto row = x_indices.at(number<0>{});
auto offset =
lse_acc_lds_ms_m0_for_write.calculate_offset(make_tuple(i_split, row));
lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split));
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);