mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
[CK_TILE] Simplify the codes in splitkv_combine pipeline (#1549)
* Simplify the codes in splitkv_combine pipeline * Always set kPadSeqLenK=true for fmha splitkv kernels * Change in Oacc Alignment and TileDistribution to be more adaptable to tile sizes --------- Co-authored-by: Po Yen Chen <PoYen.Chen@amd.com>
This commit is contained in:
@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
lse_accum, sequence<1>{}, f_max, -numeric<LSEDataType>::infinity());
|
||||
block_tile_reduce_sync(lse_max, f_max, bool_constant<false>{});
|
||||
|
||||
static const auto get_validated_m = [](LSEDataType raw_m) {
|
||||
return raw_m == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
|
||||
: raw_m;
|
||||
};
|
||||
|
||||
decltype(lse_accum) lse_exp;
|
||||
{
|
||||
constexpr auto spans = decltype(lse_exp)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if(lse_max[i_idx] == -numeric<LSEDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
lse_exp(i_j_idx) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - get_validated_m(lse_max(i_idx)));
|
||||
});
|
||||
lse_exp(i_j_idx) = ck_tile::type_convert<LSEDataType>(0.0f);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
lse_exp(i_j_idx) = ck_tile::exp(lse_accum(i_j_idx) - lse_max(i_idx));
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_sum(i_idx) == 0.f || lse_sum(i_idx) != lse_sum(i_idx))
|
||||
{
|
||||
lse_logsum(i_idx) = numeric<LSEDataType>::infinity();
|
||||
}
|
||||
if(lse_sum[i_idx] == ck_tile::type_convert<LSEDataType>(0.0f))
|
||||
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
|
||||
else
|
||||
{
|
||||
lse_logsum(i_idx) =
|
||||
ck_tile::log(lse_sum(i_idx)) + get_validated_m(lse_max(i_idx));
|
||||
}
|
||||
lse_logsum(i_idx) = ck_tile::log(lse_sum(i_idx)) + lse_max(i_idx);
|
||||
});
|
||||
}
|
||||
|
||||
@@ -218,37 +218,47 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
constexpr auto spans = decltype(lse_accum)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
if(lse_logsum(i_idx) == -numeric<LSEDataType>::infinity())
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
lse_acc_lds(row, col) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
|
||||
}
|
||||
});
|
||||
lse_acc_lds(row, col) = ck_tile::type_convert<LSEDataType>(0.0f);
|
||||
}
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
lse_accum.get_tile_distribution(), i_j_idx);
|
||||
|
||||
const auto col = x_indices.at(number<1>{});
|
||||
if(col < num_splits)
|
||||
{
|
||||
const auto row = x_indices.at(number<0>{});
|
||||
|
||||
lse_acc_lds(row, col) =
|
||||
ck_tile::exp(lse_accum(i_j_idx) - lse_logsum(i_idx));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
constexpr auto spans = decltype(lse_logsum)::get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
|
||||
if(lse_logsum(i_idx) == numeric<LSEDataType>::infinity())
|
||||
{
|
||||
lse_logsum(i_idx) = -numeric<LSEDataType>::infinity();
|
||||
}
|
||||
});
|
||||
|
||||
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
|
||||
}
|
||||
|
||||
|
||||
@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
return 16 / sizeof(OaccDataType);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
|
||||
return min(N1, static_cast<index_t>(16 / sizeof(OaccDataType)));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
|
||||
{
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
return 16 / sizeof(ODataType);
|
||||
return GetAlignmentOacc<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOaccDramTileDistribution()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::kN1;
|
||||
|
||||
constexpr index_t N1 = 16 / sizeof(OaccDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t M2 = get_warp_size() / N0;
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = kNPerBlock / N0;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
|
||||
Reference in New Issue
Block a user