Input/output permutation for fused attention (#460)

* reopen masked attention instance now that CI has been upgraded

* re-enable instances that previously failed on 9110

* enable K-size/K-padding pair validity test

* add non-masked attention+permute test; expose masking boolean to attention kernel handles
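
As a point of reference, a minimal sketch of the mechanism this change uses (the placeholder F16/Row/Col structs and the AttentionTest wrapper are illustrative stand-ins, not the PR's actual test code): the masking flag rides along in the gtest type tuple as std::true_type / std::false_type, so each kernel handle sees it as a compile-time constant.

#include <tuple>
#include <type_traits>
#include <iostream>

using Masked = std::true_type;
using NoMask = std::false_type;

// Hypothetical stand-ins for the real F16/Row/Col tag types.
struct F16 {};
struct Row {};
struct Col {};

template <typename Tuple>
struct AttentionTest
{
    // Tuple element 8 selects the masked or unmasked kernel at compile time.
    using MaskingType = std::tuple_element_t<8, Tuple>;
    static constexpr bool kMasked = MaskingType::value;
};

int main()
{
    using Config = std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, Masked>;
    std::cout << AttentionTest<Config>::kMasked << '\n'; // prints 1
}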

* disable bench

* fix test

* move files

* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute

* format

* amend rename

* disable bench in test

* add mask/no-mask test for non-permute attention kernels

* disable broken kernel instance

* example working

add non-permuted problem statement

evaluating whether overhead comes from permutation or the extra kernel arg

* interface for bias addition without implementing it

* test and profiler running

* tidy

* mask type determined by enum class
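
A sketch of the enum-class approach (the MaskingSpecialization values mirror the diff below; the std::integral_constant tags in place of ck::integral_constant and the IsMaskedOut helper are illustrative only, assuming the enum is consumed via `if constexpr`):

#include <type_traits>

enum class MaskingSpecialization
{
    MaskDisabled,
    MaskOutUpperTriangle, // causal attention: discard scores where col > row
};

// Tag types let ::testing::Types carry the enum value through a type tuple.
using MaskDisabled_t =
    std::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
using MaskOutUpperTriangle_t =
    std::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;

// A kernel template can branch on the enum at compile time:
template <MaskingSpecialization MaskSpec>
constexpr bool IsMaskedOut(int row, int col)
{
    if constexpr (MaskSpec == MaskingSpecialization::MaskOutUpperTriangle)
        return col > row;
    else
        return false;
}

static_assert(IsMaskedOut<MaskingSpecialization::MaskOutUpperTriangle>(0, 1));
static_assert(!IsMaskedOut<MaskingSpecialization::MaskDisabled>(0, 1));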

* unify example code

* move masking specialization to its own header

* align formats

* extract helper functions

* experiment with merging dims for attention w/ permute; shows perf parity with attention w/o permute
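
A rough model of the dim-merging idea, using the A-tensor layout that appears in this PR's test fixture (physical storage [G0, M, G1, K] addressed as logical [G0, G1, M, K]); the OffsetOf helper is an illustrative stand-in, not library code. Because these strides describe a fully packed layout, adjacent dims can be folded into one, so indexing can cost the same as in the non-permuted kernel.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Dot product of a multi-index with per-dimension strides.
std::int64_t OffsetOf(const std::vector<std::int64_t>& idx,
                      const std::vector<std::int64_t>& strides)
{
    std::int64_t off = 0;
    for (std::size_t d = 0; d < idx.size(); ++d)
        off += idx[d] * strides[d];
    return off;
}

int main()
{
    const std::int64_t G1 = 3, M = 4, K = 8;
    // Logical dims [G0, G1, M, K]; physical storage is [G0, M, G1, K]:
    const std::vector<std::int64_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
    // Offset of logical element (g0, g1, m, k) = (1, 2, 3, 5):
    std::cout << OffsetOf({1, 2, 3, 5}, a_gs_ms_ks_strides) << '\n'; // prints 189
}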

* add tensor specialization to template args

since the 'packed' tensor spec shows perf parity when permutation isn't needed

remove redundant template args

comment on 'packed' tensor specialization
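
A sketch of what the tensor-specialization template argument expresses (the enum values mirror this PR's diff, which passes TensorSpecialization::Default for each tensor; NumEffectiveDims is an illustrative stand-in for the real descriptor-transform logic):

enum class TensorSpecialization
{
    Default, // honor every dimension's stride individually
    Packed,  // dims are contiguous, so they can be merged into one
};

template <TensorSpecialization TensorSpec>
constexpr int NumEffectiveDims(int num_dims)
{
    // With Packed, the multi-dim index calculation collapses to a single
    // merged dimension, which is why perf matches the non-permuted kernel
    // when no real permutation is present.
    return TensorSpec == TensorSpecialization::Packed ? 1 : num_dims;
}

static_assert(NumEffectiveDims<TensorSpecialization::Packed>(4) == 1);
static_assert(NumEffectiveDims<TensorSpecialization::Default>(4) == 4);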

* grouped attention with input/output permute example

* format

* clean up

* refactor acc0 tile visitor

Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>

[ROCm/composable_kernel commit: de37550f72]
Anthony Chang
2022-10-28 04:58:20 +08:00
committed by GitHub
parent 0606ddf8e1
commit 276dfdd457
42 changed files with 2654 additions and 2196 deletions

View File

@@ -41,7 +41,7 @@ add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce)
add_subdirectory(batched_gemm_gemm)
add_subdirectory(batched_gemm_softmax_gemm)
-add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
+add_subdirectory(batched_gemm_softmax_gemm_permute)
add_subdirectory(grouped_gemm)
add_subdirectory(reduce)
add_subdirectory(convnd_fwd)

View File

@@ -1,5 +0,0 @@
-add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)
-add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
-target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
-add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)

View File

@@ -9,9 +9,13 @@ class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
{
};
+using Masked = std::true_type;
+using NoMask = std::false_type;
// clang-format off
using KernelTypes = ::testing::Types<
-std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
+std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, NoMask>,
+std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, Masked>
>;
// clang-format on
@@ -120,7 +124,6 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
using ck::tensor_operation::device::GemmSpecialization;
-// TODO: enable KPadding tests when it is implemented
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
@@ -152,12 +155,12 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
// IsSupported(M, N, K, O)
// clang-format off
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
+EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
+EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
+EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
+EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
@@ -169,6 +172,5 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
{1020, 1020, 64, 128, 24},
{576, 576, 64, 64, 24},
};
-this->bench_ = true;
this->Run();
}

View File

@@ -20,14 +20,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
{
-using ADataType = std::tuple_element_t<0, Tuple>;
-using B0DataType = std::tuple_element_t<1, Tuple>;
-using B1DataType = std::tuple_element_t<2, Tuple>;
-using CDataType = std::tuple_element_t<3, Tuple>;
-using ALayout = std::tuple_element_t<4, Tuple>;
-using B0Layout = std::tuple_element_t<5, Tuple>;
-using B1Layout = std::tuple_element_t<6, Tuple>;
-using CLayout = std::tuple_element_t<7, Tuple>;
+using ADataType = std::tuple_element_t<0, Tuple>;
+using B0DataType = std::tuple_element_t<1, Tuple>;
+using B1DataType = std::tuple_element_t<2, Tuple>;
+using CDataType = std::tuple_element_t<3, Tuple>;
+using ALayout = std::tuple_element_t<4, Tuple>;
+using B0Layout = std::tuple_element_t<5, Tuple>;
+using B1Layout = std::tuple_element_t<6, Tuple>;
+using CLayout = std::tuple_element_t<7, Tuple>;
+using MaskingType = std::tuple_element_t<8, Tuple>;
std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
{256, 256, 128, 128, 4},
@@ -54,7 +55,8 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
ALayout,
B0Layout,
B1Layout,
-CLayout>(
+CLayout,
+MaskingType::value>(
verify_, 1, false, bench_, M, N, K, O, BatchCount);
EXPECT_TRUE(pass);

View File

@@ -0,0 +1,5 @@
+add_custom_target(test_batched_gemm_softmax_gemm_permute)
+add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
+target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
+add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)

View File

@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
@@ -10,13 +10,18 @@ class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
{
};
+using I1_t = ck::Number<1>;
+using I2_t = ck::Number<2>;
+using MaskDisabled_t =
+ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
+using MaskOutUpperTriangle_t =
+ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
// clang-format off
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using CPermuteNumDims_G_M_O =
-S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using KernelTypes = ::testing::Types<
-std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
+std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
+std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
>;
// clang-format on
@@ -91,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
this->Run();
}
-TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Bench_FP16_IrregularK)
+TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
{256, 64, 160, 64, 1, 16},
@@ -125,7 +130,6 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP1
using ck::tensor_operation::device::GemmSpecialization;
-// TODO: enable KPadding tests when it is implemented
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
{
int P = 120; // requires padding
@@ -133,22 +137,22 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
// IsSupported(M, N, K, O)
// clang-format off
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
-EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
+EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
// clang-format on
}
@@ -156,13 +160,13 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
{
// IsSupported(M, N, K, O)
// clang-format off
-EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
+EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
+EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
+EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
+EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
-// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
+EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
// clang-format on
}
@@ -174,6 +178,5 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4, 6},
};
-this->bench_ = true;
this->Run();
}

View File

@@ -4,10 +4,14 @@
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
#include "profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
using ck::tensor_operation::device::MaskingSpecialization;
using ck::tensor_operation::device::TensorSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
@@ -20,14 +24,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
{
-using ADataType = std::tuple_element_t<0, Tuple>;
-using B0DataType = std::tuple_element_t<1, Tuple>;
-using B1DataType = std::tuple_element_t<2, Tuple>;
-using CDataType = std::tuple_element_t<3, Tuple>;
-using ALayout = std::tuple_element_t<4, Tuple>;
-using B0Layout = std::tuple_element_t<5, Tuple>;
-using B1Layout = std::tuple_element_t<6, Tuple>;
-using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;
+using NumDimGType = std::tuple_element_t<0, Tuple>;
+using NumDimMType = std::tuple_element_t<1, Tuple>;
+using NumDimNType = std::tuple_element_t<2, Tuple>;
+using NumDimKType = std::tuple_element_t<3, Tuple>;
+using NumDimOType = std::tuple_element_t<4, Tuple>;
+using ADataType = std::tuple_element_t<5, Tuple>;
+using B0DataType = std::tuple_element_t<6, Tuple>;
+using B1DataType = std::tuple_element_t<7, Tuple>;
+using CDataType = std::tuple_element_t<8, Tuple>;
+using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
+using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
+using MaskingType = std::tuple_element_t<11, Tuple>;
std::vector<std::vector<int>> lengths_ = {
{256, 256, 64, 64, 6, 4},
@@ -42,15 +50,20 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
void RunSingle(int M, int N, int K, int O, int G0, int G1)
{
-bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
-ADataType,
-B0DataType,
-B1DataType,
-CDataType,
-ALayout,
-B0Layout,
-B1Layout,
-CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
+bool pass =
+ck::profiler::profile_batched_gemm_softmax_gemm_permute_impl<NumDimGType::value,
+NumDimMType::value,
+NumDimNType::value,
+NumDimKType::value,
+NumDimOType::value,
+ADataType,
+B0DataType,
+B1DataType,
+CDataType,
+ck::Tuple<>,
+ck::Tuple<>,
+MaskingType::value>(
+verify_, 1, false, bench_, M, N, K, O, G0, G1);
EXPECT_TRUE(pass);
}
@@ -72,19 +85,13 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
};
template <GemmSpecialization GemmSpec>
-struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
+struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale = ck::tensor_operation::element_wise::Scale;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using CPermuteNumDims_G_M_O =
-S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
using ADataType = F16;
using B0DataType = F16;
@@ -103,14 +110,17 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
using DeviceGemmGemmInstance =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
-ALayout,
-B0Layout,
-B1Layout,
-CPermuteNumDims_G_M_O,
+2,
+1,
+1,
+1,
+1,
ADataType,
B0DataType,
B1DataType,
CDataType,
+ck::Tuple<>,
+ck::Tuple<>,
AccDataType,
CShuffleDataType,
AElementOp,
@@ -119,6 +129,10 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
B1ElementOp,
CElementOp,
GemmSpec,
+TensorSpecialization::Default, // ATensorSpec
+TensorSpecialization::Default, // B0TensorSpec
+TensorSpecialization::Default, // B1TensorSpec
+TensorSpecialization::Default, // CTensorSpec
1,
256,
128, // MPerBlock
@@ -159,29 +173,48 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
-true>; // Masking
+MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
bool IsSupported(int M, int N, int K, int O)
{
+const int G0 = 1, G1 = 1;
+// A layout [G0, M, G1, K]
+std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
+std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
+// B0 layout [G0, N, G1, K]
+std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
+std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
+// B1 layout [G0, N, G1, O]
+std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
+std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
+// C layout [G0, M, G1, O]
+std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
+std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
-M,
-N,
-K,
-O,
-0, // BatchCount
-{0, 0, M, O}, // gs ms ns lengths
-{0, O, 0, 1}, // gs ms ns strides
-0, // StrideA
-0, // StrideB0
-0, // StrideB1
-0, // BatchStrideA
-0, // BatchStrideB0
-0, // BatchStrideB1
+{}, // p_acc0_biases
+{}, // p_acc1_biases
+a_gs_ms_ks_lengths,
+a_gs_ms_ks_strides,
+b0_gs_ns_ks_lengths,
+b0_gs_ns_ks_strides,
+b1_gs_os_ns_lengths,
+b1_gs_os_ns_strides,
+c_gs_ms_os_lengths,
+c_gs_ms_os_strides,
+{}, // acc0_biases_gs_ms_ns_lengths
+{}, // acc0_biases_gs_ms_ns_strides
+{}, // acc1_biases_gs_ms_os_lengths
+{}, // acc1_biases_gs_ms_os_strides
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
Scale{1.f}, // acc0_element_op

View File
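
Context for the next diff: the space-filling-curve test gains a trailing bool template parameter selecting a plain linear walk (false) versus the snake-curved walk (true), which reverses direction on alternate rows. Below is a standalone sketch of the linear order the test asserts for lengths {3, 2, 2} with access order {2, 0, 1}; it only models the traversal and is not the library's SpaceFillingCurve code.

#include <array>
#include <cstdio>

int main()
{
    constexpr std::array<int, 3> lengths{3, 2, 2};
    constexpr std::array<int, 3> order{2, 0, 1}; // dim 2 slowest, dim 1 fastest
    std::array<int, 3> idx{};
    // Linear walk: loop over dims in `order`, last entry varying fastest.
    for (int a = 0; a < lengths[order[0]]; ++a)
        for (int b = 0; b < lengths[order[1]]; ++b)
            for (int c = 0; c < lengths[order[2]]; ++c)
            {
                idx[order[0]] = a;
                idx[order[1]] = b;
                idx[order[2]] = c;
                // Prints (0, 0, 0), (0, 1, 0), (1, 0, 0), ... matching the
                // `expected` tuple in the test below.
                std::printf("(%d, %d, %d)\n", idx[0], idx[1], idx[2]);
            }
}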

@@ -12,28 +12,91 @@
using namespace ck;
-void traverse_using_space_filling_curve();
+void traverse_using_space_filling_curve_linear();
+void traverse_using_space_filling_curve_snakecurved();
int main(int argc, char** argv)
{
(void)argc;
(void)argv;
-traverse_using_space_filling_curve();
+traverse_using_space_filling_curve_linear();
+traverse_using_space_filling_curve_snakecurved();
return 0;
}
-void traverse_using_space_filling_curve()
+void traverse_using_space_filling_curve_linear()
{
constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
-using TensorLengths = Sequence<16, 10, 9>;
-using DimAccessOrder = Sequence<2, 0, 1>;
-using ScalarsPerAccess = Sequence<4, 2, 3>;
-using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
+using TensorLengths = Sequence<3, 2, 2>;
+using DimAccessOrder = Sequence<2, 0, 1>;
+using ScalarsPerAccess = Sequence<1, 1, 1>;
+using SpaceFillingCurve =
+SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;
+constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
+make_tuple(0, 1, 0),
+make_tuple(1, 0, 0),
+make_tuple(1, 1, 0),
+make_tuple(2, 0, 0),
+make_tuple(2, 1, 0),
+make_tuple(0, 0, 1),
+make_tuple(0, 1, 1),
+make_tuple(1, 0, 1),
+make_tuple(1, 1, 1),
+make_tuple(2, 0, 1),
+make_tuple(2, 1, 1));
+constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
+static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
+math::multiplies{},
+Number<1>{}));
+static_for<1, num_access, 1>{}([&](auto i) {
+constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
+static_assert(idx_curr[I0] == expected[i][I0]);
+static_assert(idx_curr[I1] == expected[i][I1]);
+static_assert(idx_curr[I2] == expected[i][I2]);
+constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
+constexpr auto expected_step = expected[i - I1] - expected[i];
+static_assert(backward_step[I0] == expected_step[I0]);
+static_assert(backward_step[I1] == expected_step[I1]);
+static_assert(backward_step[I2] == expected_step[I2]);
+});
+static_for<0, num_access - 1, 1>{}([&](auto i) {
+constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
+static_assert(idx_curr[I0] == expected[i][I0]);
+static_assert(idx_curr[I1] == expected[i][I1]);
+static_assert(idx_curr[I2] == expected[i][I2]);
+constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i);
+constexpr auto expected_step = expected[i + I1] - expected[i];
+static_assert(forward_step[I0] == expected_step[I0]);
+static_assert(forward_step[I1] == expected_step[I1]);
+static_assert(forward_step[I2] == expected_step[I2]);
+});
+}
+void traverse_using_space_filling_curve_snakecurved()
+{
+constexpr auto I0 = Number<0>{};
+constexpr auto I1 = Number<1>{};
+constexpr auto I2 = Number<2>{};
+using TensorLengths = Sequence<16, 10, 9>;
+using DimAccessOrder = Sequence<2, 0, 1>;
+using ScalarsPerAccess = Sequence<4, 2, 3>;
+using SpaceFillingCurve =
+SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
make_tuple(0, 2, 0),