mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
Simplify pipeline source code
This commit is contained in:
@@ -118,21 +118,17 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto lse_accum = load_tile(lse_acc_lds_read_window,
|
||||
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
|
||||
#else
|
||||
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
|
||||
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
|
||||
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
|
||||
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
|
||||
|
||||
// copy LDS to lse_accum (transpose)
|
||||
{
|
||||
using DataType = LSEDataType;
|
||||
using StaticTileDistribution = decltype(lse_accum_dist);
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType,
|
||||
StaticTileDistribution>::get_distributed_spans();
|
||||
constexpr auto out_spans = decltype(lse_accum)::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
StaticTileDistribution{}, distributed_indices);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
@@ -140,11 +136,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
|
||||
if(col < num_splits)
|
||||
{
|
||||
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
|
||||
lse_accum(i_j_idx) = lse_acc_lds_ptr[offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
lse_accum(i_j_idx) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -164,17 +160,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
auto p_compute = make_static_distributed_tensor<LSEDataType>(
|
||||
lse_accum.get_tile_distribution()); // Pcompute{j}
|
||||
clear_tile(p_compute);
|
||||
decltype(lse_accum) lse_exp;
|
||||
clear_tile(lse_exp);
|
||||
{
|
||||
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
|
||||
constexpr auto p_spans = decltype(lse_exp)::get_distributed_spans();
|
||||
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
p_compute.get_tile_distribution(), i_j_idx);
|
||||
lse_exp.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
@@ -183,7 +178,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
if(col < num_splits)
|
||||
{
|
||||
// from shared memory
|
||||
p_compute(i_j_idx) =
|
||||
lse_exp(i_j_idx) =
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx)));
|
||||
}
|
||||
});
|
||||
@@ -192,39 +187,34 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
__syncthreads();
|
||||
|
||||
auto lse_sum = block_tile_reduce<LSEDataType>(
|
||||
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
|
||||
lse_exp, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
|
||||
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
|
||||
|
||||
decltype(lse_max) lse_logsum;
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
LSEDataType,
|
||||
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
|
||||
constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_sum(distributed_indices) == 0.f ||
|
||||
lse_sum(distributed_indices) != lse_sum(distributed_indices))
|
||||
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
|
||||
{
|
||||
lse_logsum(distributed_indices) = numeric<LSEDataType>::infinity();
|
||||
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
|
||||
}
|
||||
else
|
||||
{
|
||||
lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
|
||||
get_validated_m(lse_max(distributed_indices));
|
||||
lse_logsum(i_idx) =
|
||||
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// write lse scales into LDS
|
||||
{
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
|
||||
get_distributed_spans();
|
||||
constexpr auto out_spans = decltype(lse_sum)::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_sum.get_tile_distribution(), distributed_indices);
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(lse_sum.get_tile_distribution(), i_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
@@ -232,7 +222,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
{
|
||||
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
|
||||
lse_acc_lds_ptr[offset] =
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
|
||||
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(i_idx));
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -240,15 +230,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
constexpr auto out_spans = static_distributed_tensor<
|
||||
LSEDataType,
|
||||
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
|
||||
constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0);
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_logsum(distributed_indices) == numeric<LSEDataType>::infinity())
|
||||
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(distributed_indices) = -numeric<LSEDataType>::infinity();
|
||||
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -264,21 +252,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // Pcompute{j}
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
auto o_tile = load_tile(o_acc_dram_window);
|
||||
{
|
||||
using DataType = OaccDataType;
|
||||
constexpr auto out_spans =
|
||||
static_distributed_tensor<DataType,
|
||||
decltype(o_acc_dist)>::get_distributed_spans();
|
||||
constexpr auto out_spans = decltype(o_acc)::get_distributed_spans();
|
||||
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto distributed_indices = make_tuple(idx0, idx1);
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
o_acc.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
@@ -286,12 +271,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split));
|
||||
|
||||
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
|
||||
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
|
||||
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
Reference in New Issue
Block a user