mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
Input/output permutation for fused attention (#460)
* reopen masking att instance due to CI is upgraded * re-enable instances previously failed on 9110 * enable ksize-kpadding pair validity test * add non-masked attention+permute test; expose masking boolean to attention kernel handles * disable bench * fix test * move files * bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute * format * amend rename * disable bench in test * add mask/no-mask test for non-permute attention kernels * disable broken kernel instance * example working add non-permuted problem statement evaluating whether overhead comes from permutation or the extra kernel arg * interface for bias addition without implementing it * test and profiler running * tidy * mask type determined by enum class * unify example code * move masking specialization to its own header * align formats * extract helper functions * experiment merging dims for attn w/ permute; shows perf parity with attn wo/ permute * add tensor specialization to template args since tensor spec packed shows perf parity when permutation isn't needed remove redundant template args comment on 'packed' tensor specialization * grouped attention with input/output permute example * format * clean up * refactor acc0 tile visitor Co-authored-by: shaojiewang <wsjmessi@163.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -151,6 +151,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
|
||||
|
||||
return make_tuple(Number<m0>{},
|
||||
Number<n0>{},
|
||||
waveId_m,
|
||||
waveId_n,
|
||||
blk_idx[I0],
|
||||
blk_idx[I1],
|
||||
blk_idx[I2],
|
||||
blk_idx[I3]);
|
||||
}
|
||||
|
||||
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
|
||||
{
|
||||
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
|
||||
@@ -724,6 +745,21 @@ struct BlockwiseGemmXdlops_v2
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
|
||||
__device__ static auto
|
||||
CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
|
||||
|
||||
return make_tuple(
|
||||
m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
|
||||
}
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
__host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
|
||||
Reference in New Issue
Block a user