Simplify pipeline source code

This commit is contained in:
PoYen, Chen
2024-06-12 09:17:04 +00:00
parent ff61463cab
commit e00ff9d246

View File

@@ -118,21 +118,17 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto lse_accum = load_tile(lse_acc_lds_read_window,
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
#else
auto lse_accum_dist = Policy::template MakeLSEaccTDramTileDistribution<Problem>();
auto lse_accum = make_static_distributed_tensor<LSEDataType>(lse_accum_dist);
auto lse_accum = make_static_distributed_tensor<LSEDataType>(
Policy::template MakeLSEaccTDramTileDistribution<Problem>());
// copy LDS to lse_accum (transpose)
{
using DataType = LSEDataType;
using StaticTileDistribution = decltype(lse_accum_dist);
constexpr auto out_spans =
static_distributed_tensor<DataType,
StaticTileDistribution>::get_distributed_spans();
constexpr auto out_spans = decltype(lse_accum)::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
StaticTileDistribution{}, distributed_indices);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_accum.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
@@ -140,11 +136,11 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
if(col < num_splits)
{
lse_accum(distributed_indices) = lse_acc_lds_ptr[offset];
lse_accum(i_j_idx) = lse_acc_lds_ptr[offset];
}
else
{
lse_accum(distributed_indices) = -numeric<LSEDataType>::infinity();
lse_accum(i_j_idx) = -numeric<LSEDataType>::infinity();
}
});
});
@@ -164,17 +160,16 @@ struct BlockFmhaFwdSplitKVCombinePipeline
: raw_m;
};
auto p_compute = make_static_distributed_tensor<LSEDataType>(
lse_accum.get_tile_distribution()); // Pcompute{j}
clear_tile(p_compute);
decltype(lse_accum) lse_exp;
clear_tile(lse_exp);
{
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
constexpr auto p_spans = decltype(lse_exp)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
p_compute.get_tile_distribution(), i_j_idx);
lse_exp.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
const auto col = x_indices.at(number<1>{});
@@ -183,7 +178,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
if(col < num_splits)
{
// from shared memory
p_compute(i_j_idx) =
lse_exp(i_j_idx) =
ck_tile::exp(lse_acc_lds_ptr[offset] - get_validated_m(lse_max(i_idx)));
}
});
@@ -192,39 +187,34 @@ struct BlockFmhaFwdSplitKVCombinePipeline
__syncthreads();
auto lse_sum = block_tile_reduce<LSEDataType>(
p_compute, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
lse_exp, sequence<1>{}, f_sum, type_convert<LSEDataType>(0));
block_tile_reduce_sync(lse_sum, f_sum, bool_constant<false>{});
decltype(lse_max) lse_logsum;
{
constexpr auto out_spans = static_distributed_tensor<
LSEDataType,
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
constexpr auto i_idx = make_tuple(idx0);
if(lse_sum(distributed_indices) == 0.f ||
lse_sum(distributed_indices) != lse_sum(distributed_indices))
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
{
lse_logsum(distributed_indices) = numeric<LSEDataType>::infinity();
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
}
else
{
lse_logsum(distributed_indices) = ck_tile::log(lse_sum(distributed_indices)) +
get_validated_m(lse_max(distributed_indices));
lse_logsum(i_idx) =
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
}
});
}
// write lse scales into LDS
{
constexpr auto out_spans =
static_distributed_tensor<LSEDataType, decltype(lse_sum.get_tile_distribution())>::
get_distributed_spans();
constexpr auto out_spans = decltype(lse_sum)::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
const auto x_indices = get_x_indices_from_distributed_indices(
lse_sum.get_tile_distribution(), distributed_indices);
constexpr auto i_idx = make_tuple(idx0);
const auto x_indices =
get_x_indices_from_distributed_indices(lse_sum.get_tile_distribution(), i_idx);
const auto row = x_indices.at(number<0>{});
@@ -232,7 +222,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
auto offset = lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, col));
lse_acc_lds_ptr[offset] =
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(distributed_indices));
ck_tile::exp(lse_acc_lds_ptr[offset] - lse_logsum(i_idx));
}
});
}
@@ -240,15 +230,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
if constexpr(kStoreLSE)
{
constexpr auto out_spans = static_distributed_tensor<
LSEDataType,
decltype(lse_logsum.get_tile_distribution())>::get_distributed_spans();
constexpr auto out_spans = decltype(lse_logsum)::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
constexpr auto distributed_indices = make_tuple(idx0);
constexpr auto i_idx = make_tuple(idx0);
if(lse_logsum(distributed_indices) == numeric<LSEDataType>::infinity())
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
{
lse_logsum(distributed_indices) = -numeric<LSEDataType>::infinity();
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
}
});
@@ -264,21 +252,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist); // Pcompute{j}
clear_tile(o_acc);
const index_t new_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
for(index_t i_split = 0; i_split < num_splits; ++i_split)
{
auto o_tile = load_tile(o_acc_dram_window);
{
using DataType = OaccDataType;
constexpr auto out_spans =
static_distributed_tensor<DataType,
decltype(o_acc_dist)>::get_distributed_spans();
constexpr auto out_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
const auto x_indices =
get_x_indices_from_distributed_indices(o_acc_dist, distributed_indices);
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
o_acc.get_tile_distribution(), i_j_idx);
const auto row = x_indices.at(number<0>{});
@@ -286,12 +271,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_acc_lds_m0_ms_for_read.calculate_offset(make_tuple(row, i_split));
LSEDataType lse_scale = lse_acc_lds_ptr[offset];
o_acc(distributed_indices) += lse_scale * o_tile(distributed_indices);
o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
});
});
}
move_tile_window(o_acc_dram_window, {new_max_seqlen_q, 0});
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
}
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);