Input/output permutation for fused attention (#460)
* reopen masking attention instance, since CI has been upgraded
* re-enable instances that previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
* add non-permuted problem statement, evaluating whether overhead comes from the permutation or from the extra kernel arg
* interface for bias addition, without implementing it
* test and profiler running
* tidy
* mask type determined by enum class
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment with merging dims for attention w/ permute; shows perf parity with attention w/o permute
* add tensor specialization to template args, since tensor spec "packed" shows perf parity when permutation isn't needed
* remove redundant template args
* comment on "packed" tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor

Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
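For context on the "mask type determined by enum class" item: the masking behavior is selected at compile time rather than by a runtime boolean. A minimal sketch of the idea, with illustrative names; the commit's actual enum and mask helpers live in the masking specialization header it introduces:

    // Sketch only: compile-time selection of the attention mask predicate.
    enum class MaskingSpecialization
    {
        MaskDisabled,        // no masking
        MaskOutUpperTriangle // causal attention: element (m, n) is masked when n > m
    };

    template <MaskingSpecialization MaskSpec>
    struct ExampleMatrixMask
    {
        __host__ __device__ static bool IsMaskedElement(int m, int n)
        {
            if constexpr(MaskSpec == MaskingSpecialization::MaskOutUpperTriangle)
                return n > m;
            else
                return false;
        }
    };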
@@ -336,36 +336,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
     };
 
-    template <bool Pred>
-    struct ElementOpPredicatedResetNaNToMinusInf;
-
-    template <>
-    struct ElementOpPredicatedResetNaNToMinusInf<true>
-    {
-        template <typename ElementOp, typename OutT, typename InT>
-        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
-        {
-            if(ck::math::isnan(x))
-            {
-                y = -ck::NumericLimits<float>::Infinity();
-            }
-            else
-            {
-                op(y, x);
-            }
-        }
-    };
-
-    template <>
-    struct ElementOpPredicatedResetNaNToMinusInf<false>
-    {
-        template <typename ElementOp, typename OutT, typename InT>
-        __host__ __device__ void Run(OutT& y, const ElementOp& op, const InT& x)
-        {
-            op(y, x);
-        }
-    };
-
     template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
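The ElementOpPredicatedResetNaNToMinusInf<Pred> pair removed above is the pre-C++17 specialization idiom for a compile-time branch; for reference, the same behavior can be written as one function with if constexpr. A sketch only, with an illustrative free-function name:

    // Same behavior as the removed helper: when Pred is set, NaN inputs are
    // reset to -inf (which softmax then weights to zero); otherwise the
    // wrapped element-wise op is applied unchanged.
    template <bool Pred, typename ElementOp, typename OutT, typename InT>
    __host__ __device__ void reset_nan_to_minus_inf(OutT& y, const ElementOp& op, const InT& x)
    {
        if constexpr(Pred)
        {
            if(ck::math::isnan(x))
            {
                y = -ck::NumericLimits<float>::Infinity();
                return;
            }
        }
        op(y, x);
    }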
@@ -406,11 +376,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             return;
         }
 
-        // HACK: this force m/n_block_data_idx_on_grid into SGPR
+        // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
         const index_t m_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
 
-        const index_t n_block_data_idx_on_grid =
+        const index_t gemm1_n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
 
         // A matrix in LDS memory, dst of blockwise copy
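The // HACK comments in this hunk refer to a common AMDGCN idiom: __builtin_amdgcn_readfirstlane broadcasts lane 0's value, which also tells the compiler the result is uniform across the wavefront, so it can be kept in a single scalar register (SGPR) instead of one vector register (VGPR) per lane. A minimal illustration of the idiom, assuming the usual integer form of the builtin:

    // Every lane computes the same block offset; readfirstlane makes the
    // wave-uniformity explicit so the value can live in an SGPR.
    __device__ int uniform_block_offset(int block_id, int elems_per_block)
    {
        return __builtin_amdgcn_readfirstlane(block_id * elems_per_block);
    }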
@@ -627,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             true, // DstResetCoord
             NumGemmKPrefetchStage>(
                 b1_grid_desc_bk0_n_bk1,
-                make_multi_index(0, n_block_data_idx_on_grid, 0),
+                make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
                 b1_element_op,
                 b1_block_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0),
@@ -745,29 +715,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
 
-        // decoder lower triangular mask
-        const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
-            make_multi_index(get_thread_local_1d_id()));
-        const auto thread_m_cluster_id = thread_cluster_idx[I0];
-        const auto thread_n_cluster_id = thread_cluster_idx[I1];
-        const index_t MPerRepeat = MPerBlock / MXdlPerWave;
-        const index_t NPerRepeat = NPerBlock / NXdlPerWave;
-        const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
-
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
-            if constexpr(MaskOutUpperTriangle)
+            auto n_block_data_idx_on_grid =
+                __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+            if(c0_matrix_mask.IsTileSkippable(
+                   m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
             {
-                auto gemm0_n_block_idx =
-                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-                if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
-                   c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
-                                                  gemm0_n_block_idx))
-                {
-                    continue;
-                }
+                continue;
             }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
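IsTileSkippable generalizes the pair of IsUpperTriangle checks it replaces: an MPerBlock x NPerBlock tile of the gemm0 output can be skipped outright when every element in it is masked. For an upper-triangular (causal) mask that reduces to testing the tile's bottom-left corner, since that element has the largest row and smallest column index. A sketch under that assumption, with illustrative names rather than the kernel's actual interface:

    // A tile starting at (m_start, n_start) is entirely above the diagonal
    // (hence fully masked and skippable) iff its bottom-left element is.
    __host__ __device__ bool is_tile_skippable(int m_start, int n_start, int m_per_block, int /*n_per_block*/)
    {
        return n_start > m_start + m_per_block - 1;
    }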
@@ -789,60 +746,58 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             // do MNK padding or upper triangular masking
             if constexpr(MaskOutUpperTriangle || PadN)
             {
-                const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
                 // 8d thread_desc in thread scope
                 constexpr auto c_thread_lengths =
                     blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
 
-                static_for<0, m0, 1>{}([&](auto m0_i) {
-                    const index_t m_global = mstart + m0_i * MPerRepeat;
-                    const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
-                    static_for<0, n0, 1>{}([&](auto n0_i) {
-                        // constexpr auto nrepeat_i = n0_i * NPerRepeat;
-                        // const index_t nstartxdl = nstart + nrepeat_i;
-                        const index_t nstartxdl = nstart + n0_i * NPerRepeat;
-                        const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
-                        static_for<0, n2, 1>{}([&](auto n2_i) {
-                            const index_t nstartgroup =
-                                nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
-                            const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
-                            static_for<0, n4, 1>{}([&](auto n4_i) {
-                                const index_t n_global = nstartgroup + n4_i;
-                                const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
-                                if constexpr(MaskOutUpperTriangle)
-                                {
-                                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
-                                    {
-                                        acc_thread_buf(acc_offset) =
-                                            -ck::NumericLimits<float>::Infinity();
-                                    }
-                                    else
-                                    {
-                                        acc_element_op(acc_thread_buf(acc_offset),
-                                                       acc_thread_buf[acc_offset]);
-                                    }
-                                }
-                                else
-                                {
-                                    // ignore m_global;
-                                    if(c0_matrix_mask.IsNOutOfBound(n_global))
-                                    {
-                                        acc_thread_buf(acc_offset) =
-                                            -ck::NumericLimits<float>::Infinity();
-                                    }
-                                    else
-                                    {
-                                        acc_element_op(acc_thread_buf(acc_offset),
-                                                       acc_thread_buf[acc_offset]);
-                                    }
-                                }
-                            });
-                        });
-                    });
+                // 8d block_desc in block scope
+                constexpr auto c_block_lengths =
+                    blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
+
+                constexpr auto M0 = c_block_lengths[I0];
+                constexpr auto N0 = c_block_lengths[I1];
+                constexpr auto M1 = c_block_lengths[I2];
+                constexpr auto N1 = c_block_lengths[I3];
+                constexpr auto M2 = c_block_lengths[I4];
+                constexpr auto N2 = c_block_lengths[I5];
+                constexpr auto N3 = c_block_lengths[I6];
+                constexpr auto N4 = c_block_lengths[I7];
+
+                // works like multi-dimension static_for (static_ford), but provides both the linear
+                // index as well as n-d index
+                using Acc0TileIterator = SpaceFillingCurve<
+                    decltype(c_thread_lengths),
+                    typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
+                    typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
+                    false>; // SnakeCurved
+
+                auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
+                    Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{});
+
+                constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
+                    make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)),
+                               make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}),
+                    make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{}));
+
+                static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
+                    auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
+                    auto m_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
+                    auto n_local =
+                        block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
+                    auto m_global = m_local + m_block_data_idx_on_grid;
+                    auto n_global = n_local + n_block_data_idx_on_grid;
+                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
+                    {
+                        acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
+                    }
+                    else
+                    {
+                        acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
+                    }
+                });
             }
             else
             {
                 static_for<0, acc_thread_buf.Size(), 1>{}(
                     [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
             }
 
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
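The Acc0TileIterator introduced above replaces the hand-rolled nested static_for loops with a single loop over linear access indices, while GetIndex(i) recovers the n-d coordinate; the commit message calls this the "acc0 tile visitor" refactor. A standalone sketch of that linear-to-n-d decode for a static shape, assuming plain row-major traversal (SnakeCurved = false) and illustrative names:

    #include <array>

    // Decode a linear access index into an n-d index over a static shape,
    // in row-major order, the way the space-filling-curve iterator exposes
    // both the flat index i and its n-d coordinate.
    template <int... Lengths>
    constexpr std::array<int, sizeof...(Lengths)> decode_index(int linear)
    {
        constexpr int lengths[] = {Lengths...};
        std::array<int, sizeof...(Lengths)> idx{};
        for(int d = sizeof...(Lengths) - 1; d >= 0; --d)
        {
            idx[d] = linear % lengths[d];
            linear /= lengths[d];
        }
        return idx;
    }

    // e.g. decode_index<2, 3, 4>(17) yields {1, 1, 1}, since 17 = 1*12 + 1*4 + 1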