mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
Input/output permutation for fused attention (#460)
* reopen masking attention instance now that CI has been upgraded
* re-enable instances that previously failed on 9110
* enable ksize-kpadding pair validity test
* add non-masked attention+permute test; expose masking boolean to attention kernel handles
* disable bench
* fix test
* move files
* bulk rename batched_gemm_masking_scale_softmax_gemm_permute to batched_gemm_softmax_gemm_permute
* format
* amend rename
* disable bench in test
* add mask/no-mask test for non-permute attention kernels
* disable broken kernel instance
* example working
add non-permuted problem statement
evaluating whether overhead comes from permutation or the extra kernel arg
* interface for bias addition without implementing it
* test and profiler running
* tidy
* mask type determined by enum class
* unify example code
* move masking specialization to its own header
* align formats
* extract helper functions
* experiment with merging dims for attention with permute; shows perf parity with attention without permute
* add tensor specialization to template args
since tensor spec packed shows perf parity when permutation isn't needed
remove redundant template args
comment on 'packed' tensor specialization
* grouped attention with input/output permute example
* format
* clean up
* refactor acc0 tile visitor
Co-authored-by: shaojiewang <wsjmessi@163.com>
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: de37550f72]
This commit is contained in:
@@ -41,7 +41,7 @@ add_subdirectory(batched_gemm)
|
||||
add_subdirectory(batched_gemm_reduce)
|
||||
add_subdirectory(batched_gemm_gemm)
|
||||
add_subdirectory(batched_gemm_softmax_gemm)
|
||||
add_subdirectory(batched_gemm_masking_scale_softmax_gemm_permute)
|
||||
add_subdirectory(batched_gemm_softmax_gemm_permute)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(reduce)
|
||||
add_subdirectory(convnd_fwd)
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
add_custom_target(test_batched_gemm_masking_scale_softmax_gemm_permute)
|
||||
|
||||
add_gtest_executable(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp)
|
||||
target_link_libraries(test_batched_gemm_masking_scale_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_masking_scale_softmax_gemm_permute_instance)
|
||||
add_dependencies(test_batched_gemm_masking_scale_softmax_gemm_permute test_batched_gemm_masking_scale_softmax_gemm_permute_fp16)
|
||||
@@ -9,9 +9,13 @@ class TestBatchedGemmSoftmaxGemmFP16 : public TestBatchedGemmSoftmaxGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
using Masked = std::true_type;
|
||||
using NoMask = std::false_type;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, NoMask>,
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row, Masked>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -120,7 +124,6 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
// TODO: enable KPadding tests when it is implemented
|
||||
TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMatch)
|
||||
{
|
||||
int P = 120; // requires padding
|
||||
@@ -152,12 +155,12 @@ TEST(TestBatchedGemmSoftmaxGemmInterface, GemmSpecializationSizeMismatch)
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
|
||||
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
|
||||
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -169,6 +172,5 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, AdhocTest)
|
||||
{1020, 1020, 64, 128, 24},
|
||||
{576, 576, 64, 64, 24},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
@@ -20,14 +20,15 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <typename Tuple>
|
||||
struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
|
||||
{
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<1, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<2, Tuple>;
|
||||
using CDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ALayout = std::tuple_element_t<4, Tuple>;
|
||||
using B0Layout = std::tuple_element_t<5, Tuple>;
|
||||
using B1Layout = std::tuple_element_t<6, Tuple>;
|
||||
using CLayout = std::tuple_element_t<7, Tuple>;
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<1, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<2, Tuple>;
|
||||
using CDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ALayout = std::tuple_element_t<4, Tuple>;
|
||||
using B0Layout = std::tuple_element_t<5, Tuple>;
|
||||
using B1Layout = std::tuple_element_t<6, Tuple>;
|
||||
using CLayout = std::tuple_element_t<7, Tuple>;
|
||||
using MaskingType = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
|
||||
{256, 256, 128, 128, 4},
|
||||
@@ -54,7 +55,8 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout>(
|
||||
CLayout,
|
||||
MaskingType::value>(
|
||||
verify_, 1, false, bench_, M, N, K, O, BatchCount);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
|
||||
5
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
Normal file
5
test/batched_gemm_softmax_gemm_permute/CMakeLists.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
add_custom_target(test_batched_gemm_softmax_gemm_permute)
|
||||
|
||||
add_gtest_executable(test_batched_gemm_softmax_gemm_permute_fp16 test_batched_gemm_softmax_gemm_permute_fp16.cpp)
|
||||
target_link_libraries(test_batched_gemm_softmax_gemm_permute_fp16 PRIVATE utility device_batched_gemm_softmax_gemm_permute_instance)
|
||||
add_dependencies(test_batched_gemm_softmax_gemm_permute test_batched_gemm_softmax_gemm_permute_fp16)
|
||||
@@ -2,7 +2,7 @@
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
|
||||
#include "test_batched_gemm_softmax_gemm_permute_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
|
||||
@@ -10,13 +10,18 @@ class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
|
||||
{
|
||||
};
|
||||
|
||||
using I1_t = ck::Number<1>;
|
||||
using I2_t = ck::Number<2>;
|
||||
|
||||
using MaskDisabled_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskDisabled>;
|
||||
using MaskOutUpperTriangle_t =
|
||||
ck::integral_constant<MaskingSpecialization, MaskingSpecialization::MaskOutUpperTriangle>;
|
||||
|
||||
// clang-format off
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using CPermuteNumDims_G_M_O =
|
||||
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, CPermuteNumDims_G_M_O>
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskDisabled_t>,
|
||||
std::tuple<I2_t, I1_t, I1_t, I1_t, I1_t, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, MaskOutUpperTriangle_t>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
@@ -91,7 +96,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Test_FP16_OddO)
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, Bench_FP16_IrregularK)
|
||||
TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP16_IrregularK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 1, 16},
|
||||
{256, 64, 160, 64, 1, 16},
|
||||
@@ -125,7 +130,6 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, DISABLED_Bench_FP1
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
|
||||
// TODO: enable KPadding tests when it is implemented
|
||||
TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationSizeMatch)
|
||||
{
|
||||
int P = 120; // requires padding
|
||||
@@ -133,22 +137,22 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
|
||||
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
|
||||
EXPECT_TRUE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -156,13 +160,13 @@ TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteInterface, GemmSpecializationS
|
||||
{
|
||||
// IsSupported(M, N, K, O)
|
||||
// clang-format off
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
|
||||
// Kernel can't support odd K size because SrcVectorDim == KDim and must satisfy SizeKRaw % ABSrcScalarPerVector == 0
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 130, 128));
|
||||
// Kernel can't support odd O size because SrcVectorDim == ODim and must satisfy SizeORaw % B1SrcScalarPerVector == 0
|
||||
// EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
|
||||
EXPECT_FALSE(DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -174,6 +178,5 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
|
||||
{1020, 1020, 64, 128, 4, 6},
|
||||
{576, 576, 64, 64, 4, 6},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->Run();
|
||||
}
|
||||
@@ -4,10 +4,14 @@
|
||||
#include <iostream>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
|
||||
#include "profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp"
|
||||
#include "profiler/include/profile_batched_gemm_softmax_gemm_permute_impl.hpp"
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
using ck::tensor_operation::device::MaskingSpecialization;
|
||||
using ck::tensor_operation::device::TensorSpecialization;
|
||||
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
@@ -20,14 +24,18 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
template <typename Tuple>
|
||||
struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
|
||||
{
|
||||
using ADataType = std::tuple_element_t<0, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<1, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<2, Tuple>;
|
||||
using CDataType = std::tuple_element_t<3, Tuple>;
|
||||
using ALayout = std::tuple_element_t<4, Tuple>;
|
||||
using B0Layout = std::tuple_element_t<5, Tuple>;
|
||||
using B1Layout = std::tuple_element_t<6, Tuple>;
|
||||
using CPermuteNumDims_G_M_O = std::tuple_element_t<7, Tuple>;
|
||||
using NumDimGType = std::tuple_element_t<0, Tuple>;
|
||||
using NumDimMType = std::tuple_element_t<1, Tuple>;
|
||||
using NumDimNType = std::tuple_element_t<2, Tuple>;
|
||||
using NumDimKType = std::tuple_element_t<3, Tuple>;
|
||||
using NumDimOType = std::tuple_element_t<4, Tuple>;
|
||||
using ADataType = std::tuple_element_t<5, Tuple>;
|
||||
using B0DataType = std::tuple_element_t<6, Tuple>;
|
||||
using B1DataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
using Acc0BiasDataType = std::tuple_element_t<9, Tuple>;
|
||||
using Acc1BiasDataType = std::tuple_element_t<10, Tuple>;
|
||||
using MaskingType = std::tuple_element_t<11, Tuple>;
|
||||
|
||||
std::vector<std::vector<int>> lengths_ = {
|
||||
{256, 256, 64, 64, 6, 4},
|
||||
@@ -42,15 +50,20 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
|
||||
|
||||
void RunSingle(int M, int N, int K, int O, int G0, int G1)
|
||||
{
|
||||
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
|
||||
bool pass =
|
||||
ck::profiler::profile_batched_gemm_softmax_gemm_permute_impl<NumDimGType::value,
|
||||
NumDimMType::value,
|
||||
NumDimNType::value,
|
||||
NumDimKType::value,
|
||||
NumDimOType::value,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
MaskingType::value>(
|
||||
verify_, 1, false, bench_, M, N, K, O, G0, G1);
|
||||
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -72,19 +85,13 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
|
||||
};
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using CPermuteNumDims_G_M_O =
|
||||
S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
@@ -103,14 +110,17 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
|
||||
using DeviceGemmGemmInstance =
|
||||
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CPermuteNumDims_G_M_O,
|
||||
2,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
ck::Tuple<>,
|
||||
ck::Tuple<>,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
@@ -119,6 +129,10 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
TensorSpecialization::Default, // ATensorSpec
|
||||
TensorSpecialization::Default, // B0TensorSpec
|
||||
TensorSpecialization::Default, // B1TensorSpec
|
||||
TensorSpecialization::Default, // CTensorSpec
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
@@ -159,29 +173,48 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
true>; // Masking
|
||||
MaskingSpecialization::MaskOutUpperTriangle>; // MaskOutUpperTriangle
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
const int G0 = 1, G1 = 1;
|
||||
|
||||
// A layout [G0, M, G1, K]
|
||||
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
|
||||
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
|
||||
|
||||
// B0 layout [G0, N, G1, K]
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
|
||||
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
|
||||
|
||||
// B1 layout [G0, N, G1, O]
|
||||
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
|
||||
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
|
||||
|
||||
// C layout [G0, M, G1, O]
|
||||
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
|
||||
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
|
||||
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
{0, 0, M, O}, // gs ms ns lengths
|
||||
{0, O, 0, 1}, // gs ms ns strides
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
{}, // p_acc0_biases
|
||||
{}, // p_acc1_biases
|
||||
a_gs_ms_ks_lengths,
|
||||
a_gs_ms_ks_strides,
|
||||
b0_gs_ns_ks_lengths,
|
||||
b0_gs_ns_ks_strides,
|
||||
b1_gs_os_ns_lengths,
|
||||
b1_gs_os_ns_strides,
|
||||
c_gs_ms_os_lengths,
|
||||
c_gs_ms_os_strides,
|
||||
{}, // acc0_biases_gs_ms_ns_lengths
|
||||
{}, // acc0_biases_gs_ms_ns_strides
|
||||
{}, // acc1_biases_gs_ms_os_lengths
|
||||
{}, // acc1_biases_gs_ms_os_strides
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
Scale{1.f}, // acc0_element_op
|
||||
@@ -12,28 +12,91 @@
|
||||
|
||||
using namespace ck;
|
||||
|
||||
void traverse_using_space_filling_curve();
|
||||
void traverse_using_space_filling_curve_linear();
|
||||
void traverse_using_space_filling_curve_snakecurved();
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
(void)argc;
|
||||
(void)argv;
|
||||
|
||||
traverse_using_space_filling_curve();
|
||||
traverse_using_space_filling_curve_linear();
|
||||
traverse_using_space_filling_curve_snakecurved();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void traverse_using_space_filling_curve()
|
||||
void traverse_using_space_filling_curve_linear()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using TensorLengths = Sequence<16, 10, 9>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<4, 2, 3>;
|
||||
using SpaceFillingCurve = SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess>;
|
||||
using TensorLengths = Sequence<3, 2, 2>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<1, 1, 1>;
|
||||
using SpaceFillingCurve =
|
||||
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, false>;
|
||||
|
||||
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
|
||||
make_tuple(0, 1, 0),
|
||||
make_tuple(1, 0, 0),
|
||||
make_tuple(1, 1, 0),
|
||||
make_tuple(2, 0, 0),
|
||||
make_tuple(2, 1, 0),
|
||||
make_tuple(0, 0, 1),
|
||||
make_tuple(0, 1, 1),
|
||||
make_tuple(1, 0, 1),
|
||||
make_tuple(1, 1, 1),
|
||||
make_tuple(2, 0, 1),
|
||||
make_tuple(2, 1, 1));
|
||||
|
||||
constexpr index_t num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_assert(num_access == reduce_on_sequence(TensorLengths{} / ScalarsPerAccess{},
|
||||
math::multiplies{},
|
||||
Number<1>{}));
|
||||
|
||||
static_for<1, num_access, 1>{}([&](auto i) {
|
||||
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
|
||||
|
||||
static_assert(idx_curr[I0] == expected[i][I0]);
|
||||
static_assert(idx_curr[I1] == expected[i][I1]);
|
||||
static_assert(idx_curr[I2] == expected[i][I2]);
|
||||
|
||||
constexpr auto backward_step = SpaceFillingCurve::GetBackwardStep(i);
|
||||
constexpr auto expected_step = expected[i - I1] - expected[i];
|
||||
static_assert(backward_step[I0] == expected_step[I0]);
|
||||
static_assert(backward_step[I1] == expected_step[I1]);
|
||||
static_assert(backward_step[I2] == expected_step[I2]);
|
||||
});
|
||||
|
||||
static_for<0, num_access - 1, 1>{}([&](auto i) {
|
||||
constexpr auto idx_curr = SpaceFillingCurve::GetIndex(i);
|
||||
|
||||
static_assert(idx_curr[I0] == expected[i][I0]);
|
||||
static_assert(idx_curr[I1] == expected[i][I1]);
|
||||
static_assert(idx_curr[I2] == expected[i][I2]);
|
||||
|
||||
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(i);
|
||||
constexpr auto expected_step = expected[i + I1] - expected[i];
|
||||
static_assert(forward_step[I0] == expected_step[I0]);
|
||||
static_assert(forward_step[I1] == expected_step[I1]);
|
||||
static_assert(forward_step[I2] == expected_step[I2]);
|
||||
});
|
||||
}
|
||||
|
||||
void traverse_using_space_filling_curve_snakecurved()
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using TensorLengths = Sequence<16, 10, 9>;
|
||||
using DimAccessOrder = Sequence<2, 0, 1>;
|
||||
using ScalarsPerAccess = Sequence<4, 2, 3>;
|
||||
using SpaceFillingCurve =
|
||||
SpaceFillingCurve<TensorLengths, DimAccessOrder, ScalarsPerAccess, true>;
|
||||
|
||||
constexpr auto expected = make_tuple(make_tuple(0, 0, 0),
|
||||
make_tuple(0, 2, 0),
|
||||
|
||||
Reference in New Issue
Block a user