MNKO padding support for bmm+masking+scale+softmax+bmm+permute (#425)

* add lower triangle bmm

* init code for tile skipping

* get functionality right with lower triangle mask

* add decoder lower triangular mask calculation

* use 7*13 group

* fix n2 compute error

* attention with lower triangle mask with tile skipping

* add template to distinguish masking kernel

* rename template and remove default template value

* remove lower triangle gemm reference struct

* add some comments on example

* add 10 instances for masking bmm + scale + softmax + bmm + permute kernels

* add test

* add test file

* add gtest for bmm masking scale softmax bmm permute

* clang-format

* fix compile error

* check left bottom corner for tile skipping

* fix error: check left bottom corner for tile skipping

* add k padding

* add test and instance for MNK padding

* pass a mask struct

* fix instances

* delete unused comments

* format

Co-authored-by: danyao12 <yaodan@dc-smc-13.amd.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
Authored by Shaojie WANG, 2022-09-21 01:43:53 +08:00; committed by GitHub
parent 9f7c193064
commit ebab84b6f9
21 changed files with 1590 additions and 93 deletions
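For orientation: the fused kernel this commit extends computes O = softmax(scale * Q * K^T + mask) * V per batch, with the output permuted on the way out. Below is a minimal host-side reference for that math under a causal (keep n <= m) mask; every name in it is illustrative and none of it comes from the CK sources.

// Hypothetical host-side reference for the fused op; illustrative only.
#include <cmath>
#include <limits>
#include <vector>

void reference_masked_attention(const std::vector<float>& q, // M x K, row-major
                                const std::vector<float>& k, // N x K, row-major
                                const std::vector<float>& v, // N x O, row-major
                                std::vector<float>& out,     // M x O, row-major
                                int M, int N, int K, int O,
                                float scale, bool mask_out_upper_triangle)
{
    std::vector<float> s(N);
    for(int m = 0; m < M; ++m)
    {
        // Gemm0 + scale + mask: s[n] = scale * dot(q[m], k[n]), -inf above diagonal
        float row_max = -std::numeric_limits<float>::infinity();
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int kk = 0; kk < K; ++kk)
                acc += q[m * K + kk] * k[n * K + kk];
            s[n] = (mask_out_upper_triangle && n > m)
                       ? -std::numeric_limits<float>::infinity()
                       : scale * acc;
            row_max = std::fmax(row_max, s[n]);
        }
        // numerically stable softmax; exp(-inf) == 0 drops masked columns
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            s[n] = std::exp(s[n] - row_max);
            sum += s[n];
        }
        // Gemm1: out[m] = softmax row times V
        for(int o = 0; o < O; ++o)
        {
            float acc = 0.f;
            for(int n = 0; n < N; ++n)
                acc += s[n] * v[n * O + o];
            out[m * O + o] = acc / sum;
        }
    }
}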


@@ -76,7 +76,8 @@ template <typename FloatAB,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           LoopScheduler LoopSched,
-          bool PadN>
+          bool PadN,
+          bool MaskOutUpperTriangle>
 struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
 {
     static_assert(LoopSched == LoopScheduler::Default,
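Because MaskOutUpperTriangle is a non-type template parameter rather than a runtime bool, the masked branch can sit behind if constexpr (as the later hunks show) and the dense instantiation carries no masking code at all. A stripped-down illustration of that dispatch pattern, not the CK struct itself:

// Illustrative compile-time dispatch only; not the CK kernel.
#include <cstdio>

template <bool MaskOutUpperTriangle>
struct KernelSketch
{
    static void Run()
    {
        if constexpr(MaskOutUpperTriangle)
            std::printf("causal path: tile skipping + per-element masking\n");
        else
            std::printf("dense path: masking code is never instantiated\n");
    }
};

int main()
{
    KernelSketch<true>::Run();  // decoder / causal attention
    KernelSketch<false>::Run(); // full attention
}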
@@ -97,6 +98,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
     static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
+    static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
+    static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
     // Gemm1
     static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
     static constexpr auto B1K1 = Number<B1K1Value>{};
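A quick worked example for the new wave-count constants; the tile sizes below are assumptions chosen for illustration, not values taken from this commit:

// Worked example with assumed tile sizes.
constexpr int MPerBlock   = 128;
constexpr int NPerBlock   = 128;
constexpr int MPerXdl     = 32;
constexpr int NPerXdl     = 32;
constexpr int MXdlPerWave = 2; // xdlops per wave along M
constexpr int NXdlPerWave = 4; // xdlops per wave along N

constexpr int Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); // 128 / 64  = 2
constexpr int Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); // 128 / 128 = 1

// a 2 x 1 grid of waves covers one 128 x 128 Gemm0 tile
static_assert(Gemm0MWaves * MXdlPerWave * MPerXdl == MPerBlock);
static_assert(Gemm0NWaves * NXdlPerWave * NPerXdl == NPerBlock);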
@@ -361,7 +366,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         }
     };
-    template <bool HasMainKBlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                const FloatAB* __restrict__ p_b1_grid,
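Run is now also templated on a C0MatrixMask type whose three predicates (IsUpperTriangle, IsMaskedElement, IsNOutOfBound) are called further down. The real mask struct is defined elsewhere in this commit; the sketch below only records the behavior the call sites appear to assume, with 0-based global Gemm0 coordinates:

// Behavioral sketch of the mask predicate interface; the bodies here are
// assumptions about C0MatrixMask's semantics, not its source.
struct C0MatrixMaskSketch
{
    explicit C0MatrixMaskSketch(int n_total) : n_total_(n_total) {}

    // strictly above the diagonal => masked under a decoder (causal) layout
    bool IsUpperTriangle(int m, int n) const { return n > m; }

    // N padding: columns beyond the real problem size carry no data
    bool IsNOutOfBound(int n) const { return n >= n_total_; }

    // an element is masked if either condition holds
    bool IsMaskedElement(int m, int n) const
    {
        return IsUpperTriangle(m, n) || IsNOutOfBound(n);
    }

    int n_total_;
};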
@@ -377,22 +382,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                               const Block2CTileMap& block_2_ctile_map)
+                               const Block2CTileMap& block_2_ctile_map,
+                               const C0MatrixMask& c0_matrix_mask)
     {
-        const auto a_grid_buf =
-            conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_a_grid,
-                                       a_grid_desc_ak0_m_ak1.GetElementSpaceSize(),
-                                       NumericLimits<FloatAB>::QuietNaN()),
-                                   make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()));
-        const auto b_grid_buf =
-            conditional_expr<PadN>(make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_b_grid,
-                                       b_grid_desc_bk0_n_bk1.GetElementSpaceSize(),
-                                       NumericLimits<FloatAB>::QuietNaN()),
-                                   make_dynamic_buffer<AddressSpaceEnum::Global>(
-                                       p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()));
+        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
+        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
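The old path faked N padding by constructing the A/B global buffers with a QuietNaN out-of-bounds fill and later rewriting NaN to -inf; the new path reads plain buffers and lets the mask predicate write -inf directly into the accumulator. A standalone illustration, not CK code, of why -inf is the right fill before a softmax:

// exp(-inf) == 0, so -inf-filled columns vanish from both the row max
// and the softmax denominator.
#include <cmath>
#include <cstdio>
#include <limits>

int main()
{
    const float minus_inf = -std::numeric_limits<float>::infinity();
    const float row[4]    = {1.0f, 2.0f, minus_inf, minus_inf}; // 2 padded columns

    float mx = minus_inf;
    for(float x : row)
        mx = std::fmax(mx, x); // mx == 2; -inf never wins

    float sum = 0.f;
    for(float x : row)
        sum += std::exp(x - mx); // exp(-inf - 2) == 0 for padded columns

    std::printf("softmax denominator = %f (only the 2 real columns count)\n", sum);
}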
@@ -749,10 +745,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max     = NumericLimits<FloatGemmAcc>::Lowest();
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
+        // decoder lower triangular mask
+        const auto thread_cluster_idx =
+            threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_thread_local_1d_id()));
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+        const index_t MPerRepeat = MPerBlock / MXdlPerWave;
+        const index_t NPerRepeat = NPerBlock / NXdlPerWave;
+        const index_t mstart     = m_block_data_idx_on_grid + thread_m_cluster_id;
         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
         do
         {
+            if constexpr(MaskOutUpperTriangle)
+            {
+                auto gemm0_n_block_idx =
+                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+                if(c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid, gemm0_n_block_idx) &&
+                   c0_matrix_mask.IsUpperTriangle(m_block_data_idx_on_grid + MPerBlock - 1,
+                                                  gemm0_n_block_idx))
+                {
+                    continue;
+                }
+            }
             // gemm0
             gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                                    a_block_desc_ak0_m_ak1,
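The skip test fires when both the top-left and bottom-left elements of the current MPerBlock x NPerBlock tile lie above the diagonal. For a lower-triangular mask the bottom-left corner alone already decides it, since it is the last element of the tile to become masked; that is what the "check left bottom corner for tile skipping" commits fixed. A standalone restatement of the geometry, with illustrative names:

// Under a keep-n-<=-m mask, the bottom-left corner
// (m_tile_start + m_per_block - 1, n_tile_start) is the least-masked
// element of an M x N tile, so it alone decides whole-tile skipping.
bool tile_fully_masked(int m_tile_start, int m_per_block, int n_tile_start)
{
    const int m_bottom = m_tile_start + m_per_block - 1; // last row of the tile
    return n_tile_start > m_bottom; // even the bottom-left element is above the
                                    // diagonal, hence so is every other element
}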
@@ -770,16 +786,63 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                                                                    acc_thread_buf,
                                                                    num_k_block_main_loop);
             // Acc0 elementwise Op
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-            static_for<0, acc_thread_buf.Size(), 1>{}(
-                [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
-#else
-            static_for<0, acc_thread_buf.Size(), 1>{}([&](auto i) {
-                ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
-                    acc_thread_buf(i), acc_element_op, acc_thread_buf[i]);
-            });
-#endif
+            // do MNK padding or upper triangular masking
+            if constexpr(MaskOutUpperTriangle || PadN)
+            {
+                const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
+                static_for<0, m0, 1>{}([&](auto m0_i) {
+                    const index_t m_global   = mstart + m0_i * MPerRepeat;
+                    const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
+                    static_for<0, n0, 1>{}([&](auto n0_i) {
+                        // constexpr auto nrepeat_i = n0_i * NPerRepeat;
+                        // const index_t nstartxdl = nstart + nrepeat_i;
+                        const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
+                        const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
+                        static_for<0, n2, 1>{}([&](auto n2_i) {
+                            const index_t nstartgroup =
+                                nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
+                            const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
+                            static_for<0, n4, 1>{}([&](auto n4_i) {
+                                const index_t n_global = nstartgroup + n4_i;
+                                const auto acc_offset  = Number<acc_idx_n2 + n4_i>{};
+                                if constexpr(MaskOutUpperTriangle)
+                                {
+                                    if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                                else
+                                {
+                                    // ignore m_global;
+                                    if(c0_matrix_mask.IsNOutOfBound(n_global))
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+                                    }
+                                }
+                            });
+                        });
+                    });
+                });
+            }
+            else
+            {
+                static_for<0, acc_thread_buf.Size(), 1>{}(
+                    [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
+            }
             block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
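The four-deep static_for nest above addresses the accumulator through a flattened (m0, n0, n2, n4) offset. A scalar restatement of that index arithmetic, with illustrative loop bounds (the real m0/n0/n2/n4 come from the blockwise-gemm thread layout):

// Scalar restatement of the flattened accumulator offset; the concrete
// n0/n2/n4 values are illustrative stand-ins.
#include <cassert>

int acc_offset(int m0_i, int n0_i, int n2_i, int n4_i, int n0, int n2, int n4)
{
    const int acc_idx_m0 = m0_i * n0 * n2 * n4;         // one full M repeat
    const int acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4; // one N xdl repeat
    const int acc_idx_n2 = acc_idx_n0 + n2_i * n4;      // one n2 group
    return acc_idx_n2 + n4_i;                           // innermost element
}

int main()
{
    // with n0 = 4, n2 = 2, n4 = 4: 32 accumulator values per M repeat
    assert(acc_offset(0, 0, 0, 0, 4, 2, 4) == 0);
    assert(acc_offset(1, 0, 0, 0, 4, 2, 4) == 32);
    assert(acc_offset(0, 3, 1, 3, 4, 2, 4) == 31); // 3*8 + 1*4 + 3
    return 0;
}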