mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Input/output permutation for fused attention (#460)
* reopen masking attention instance because the CI was upgraded * re-enable instances that previously failed on 9110 * enable ksize-kpadding pair validity test * add non-masked attention+permute test; expose masking boolean to attention kernel handles * disable bench * fix test * move files * bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute * format * amend rename * disable bench in test * add mask/no-mask test for non-permute attention kernels * disable broken kernel instance * example working add non-permuted problem statement evaluating whether overhead comes from permutation or the extra kernel arg * interface for bias addition without implementing it * test and profiler running * tidy * mask type determined by enum class * unify example code * move masking specialization to its own header * align formats * extract helper functions * experiment merging dims for attn w/ permute; shows perf parity with attn wo/ permute * add tensor specialization to template args since tensor spec packed shows perf parity when permutation isn't needed remove redundant template args comment on 'packed' tensor specialization * grouped attention with input/output permute example * format * clean up * refactor acc0 tile visitor Co-authored-by: shaojiewang <wsjmessi@163.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -12,28 +12,91 @@
|
||||
|
||||
using namespace ck;
|
||||
|
||||
void traverse_using_space_filling_curve();
|
||||
void traverse_using_space_filling_curve_linear();
|
||||
void traverse_using_space_filling_curve_snakecurved();
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
|
||||
traverse_using_space_filling_curve();
|
||||
traverse_using_space_filling_curve_linear();
|
||||
traverse_using_space_filling_curve_snakecurved();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void traverse_using_space_filling_curve()
|
||||
void traverse_using_space_filling_curve_linear()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using TensorLengths = Sequence<16, 10, 9>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<4, 2, 3>;
|
||||
using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
|
||||
using TensorLengths = Sequence<3, 2, 2>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<1, 1, 1>;
|
||||
using SpaceFillingCurve =
|
||||
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;
|
||||
|
||||
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
|
||||
make_tuple(0, 1, 0),
|
||||
make_tuple(1, 0, 0),
|
||||
make_tuple(1, 1, 0),
|
||||
make_tuple(2, 0, 0),
|
||||
make_tuple(2, 1, 0),
|
||||
make_tuple(0, 0, 1),
|
||||
make_tuple(0, 1, 1),
|
||||
make_tuple(1, 0, 1),
|
||||
make_tuple(1, 1, 1),
|
||||
make_tuple(2, 0, 1),
|
||||
make_tuple(2, 1, 1));
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
|
||||
math::multiplies{},
|
||||
Number<1>{}));
|
||||
|
||||
static_for<1, num_access, 1>{}([&](auto i) {
|
||||
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
|
||||
|
||||
static_assert(idx_curr[I0] == expected[i][I0]);
|
||||
static_assert(idx_curr[I1] == expected[i][I1]);
|
||||
static_assert(idx_curr[I2] == expected[i][I2]);
|
||||
|
||||
constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
|
||||
constexpr auto expected_step = expected[i - I1] - expected[i];
|
||||
static_assert(backward_step[I0] == expected_step[I0]);
|
||||
static_assert(backward_step[I1] == expected_step[I1]);
|
||||
static_assert(backward_step[I2] == expected_step[I2]);
|
||||
});
|
||||
|
||||
static_for<0, num_access - 1, 1>{}([&](auto i) {
|
||||
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
|
||||
|
||||
static_assert(idx_curr[I0] == expected[i][I0]);
|
||||
static_assert(idx_curr[I1] == expected[i][I1]);
|
||||
static_assert(idx_curr[I2] == expected[i][I2]);
|
||||
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i);
|
||||
constexpr auto expected_step = expected[i + I1] - expected[i];
|
||||
static_assert(forward_step[I0] == expected_step[I0]);
|
||||
static_assert(forward_step[I1] == expected_step[I1]);
|
||||
static_assert(forward_step[I2] == expected_step[I2]);
|
||||
});
|
||||
}
|
||||
|
||||
void traverse_using_space_filling_curve_snakecurved()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using TensorLengths = Sequence<16, 10, 9>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<4, 2, 3>;
|
||||
using SpaceFillingCurve =
|
||||
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;
|
||||
|
||||
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
|
||||
make_tuple(0, 2, 0),
|
||||
|
||||
Reference in New Issue
Block a user