Input/output permutation for fused attention (#460)

* reopen masking att instance due to CI is upgraded

* re-enable instances previously failed on 9110

* enable ksize-kpadding pair validity test

* add non-masked attention+permute test; expose masking boolean to attention kernel handles

* disable bench

* fix test

* move files

* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute

* format

* amend rename

* disable bench in test

* add mask/no-mask test for non-permute attention kernels

* disable broken kernel instance

* example working

add non-permuted problem statement

evaluating whether overhead comes from permutation or the extra kernel arg

* interface for bias addition without implementing it

* test and profiler running

* tidy

* mask type determined by enum class

* unify example code

* move masking specialization to its own header

* align formats

* extract helper functions

* experiment merging dims for attn w/ permute; shows perf parity with attn wo/ permute

* add tensor specialization to template args

since tensor spec packed shows perf parity when permutation isn't needed

remove redundant template args

comment on 'packed' tensor specialization

* grouped attention with input/output permute example

* format

* clean up

* refactor acc0 tile visitor

Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
Anthony Chang
2022-10-28 04:58:20 +08:00
committed by GitHub
parent cd51732690
commit de37550f72
42 changed files with 2654 additions and 2196 deletions

View File

@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(Number<m0>{},
Number<n0>{},
waveId_m,
waveId_n,
blk_idx[I0],
blk_idx[I1],
blk_idx[I2],
blk_idx[I3]);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
return make_tuple(c_thread_m, c_thread_n);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
__device__ static auto
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto waveId_n = wave_idx[I1];
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
return make_tuple(
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),