Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop

This commit is contained in:
assistant-librarian[bot]
2025-09-04 21:11:23 +00:00
parent 5677205f88
commit 7f65be1b3e
51 changed files with 3709 additions and 189 deletions

View File

@@ -1,6 +1,14 @@
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp)
add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp)
if(result EQUAL 0)
add_custom_target(test_batched_gemm_gemm)
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance)
endif()
add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
endif()
add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp)
if(result EQUAL 0)
target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
endif()

View File

@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
template <typename Tuple>
class TestBatchedGemmGemmBF16 : public TestBatchedGemmGemm<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<BF16, BF16, BF16, BF16, Row, Col, Row, Row>,
std::tuple<BF16, BF16, BF16, BF16, Row, Col, Col, Row>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmGemmBF16, KernelTypes);
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16)
{
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 1},
{128, 128, 136, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddM)
{
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 1},
{128, 128, 129, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmBF16, DISABLED_Bench_BF16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 768},
{256, 256, 128, 128, 768},
{512, 512, 64, 64, 768},
{512, 512, 128, 128, 768},
{1024, 1024, 64, 64, 768},
{1024, 1024, 128, 128, 768},
{2048, 2048, 64, 64, 768},
{2048, 2048, 128, 128, 768},
{4096, 4096, 64, 64, 768},
{4096, 4096, 128, 128, 768},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}

View File

@@ -0,0 +1,128 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>,
std::tuple<F16, F16, F16, F16, Row, Col, Col, Row>
>;
// clang-format on
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
{
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
{
this->lengths_ = std::vector<std::vector<int>>{
{136, 128, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 136, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 40, 128, 1},
{128, 128, 136, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 136, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
{
this->lengths_ = std::vector<std::vector<int>>{
{129, 128, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 129, 32, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 33, 128, 1},
{128, 128, 129, 128, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
{
this->lengths_ = std::vector<std::vector<int>>{
{128, 128, 32, 129, 1},
};
this->bench_ = true;
this->verify_ = true;
this->Run();
}
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
{
this->lengths_ = std::vector<std::vector<int>>{
{256, 256, 64, 64, 768},
{256, 256, 128, 128, 768},
{512, 512, 64, 64, 768},
{512, 512, 128, 128, 768},
{1024, 1024, 64, 64, 768},
{1024, 1024, 128, 128, 768},
{2048, 2048, 64, 64, 768},
{2048, 2048, 128, 128, 768},
{4096, 4096, 64, 64, 768},
{4096, 4096, 128, 128, 768},
};
this->bench_ = true;
this->verify_ = false;
this->Run();
}

View File

@@ -1,8 +1,126 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "test_batched_gemm_gemm_util.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
bool IsSupported(int M, int N, int K, int O)
{
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
M,
N,
K,
O,
0, // BatchCount
0, // StrideA
0, // StrideB0
0, // StrideB1
0, // StrideC
0, // BatchStrideA
0, // BatchStrideB0
0, // BatchStrideB1
0, // BatchStrideC
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
PassThrough{}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};
template <typename Tuple>
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>

View File

@@ -1,11 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
#include "profiler/profile_batched_gemm_gemm_impl.hpp"
using ck::tensor_operation::device::GemmSpecialization;
@@ -13,7 +12,8 @@ using ck::tensor_operation::device::GemmSpecialization;
template <ck::index_t N>
using I = ck::Number<N>;
using F16 = ck::half_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
@@ -70,120 +70,3 @@ struct TestBatchedGemmGemm : public ::testing::Test
}
}
};
template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = Row;
using B0Layout = Col;
using B1Layout = Row;
using CLayout = Row;
using ADataType = F16;
using B0DataType = F16;
using B1DataType = F16;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = F16;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = PassThrough;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
B1DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
GemmSpec,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
bool IsSupported(int M, int N, int K, int O)
{
auto gemm = DeviceGemmGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
static_cast<B0DataType*>(nullptr),
static_cast<B1DataType*>(nullptr),
static_cast<CDataType*>(nullptr),
M,
N,
K,
O,
0, // BatchCount
0, // StrideA
0, // StrideB0
0, // StrideB1
0, // StrideC
0, // BatchStrideA
0, // BatchStrideB0
0, // BatchStrideB1
0, // BatchStrideC
PassThrough{}, // a_element_op
PassThrough{}, // b0_element_op
PassThrough{}, // acc0_element_op
PassThrough{}, // b1_element_op
PassThrough{}); // c_element_op
return gemm.IsSupportedArgument(argument);
}
};