mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Merge commit '7330ec37ee3b8cf2d54630372dfe9e86a893e4f5' into develop
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16 test_batched_gemm_gemm_fp16_xdl.cpp)
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_xdl test_batched_gemm_gemm_fp16_xdl.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_custom_target(test_batched_gemm_gemm)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16 PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
add_dependencies(test_batched_gemm_gemm test_batched_gemm_gemm_fp16)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_xdl PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_bf16_wmma test_batched_gemm_gemm_bf16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_bf16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_batched_gemm_gemm_fp16_wmma test_batched_gemm_gemm_fp16_wmma_cshuffle_v3.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_batched_gemm_gemm_fp16_wmma PRIVATE utility device_batched_gemm_gemm_instance)
|
||||
endif()
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmBF16 : public TestBatchedGemmGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<BF16, BF16, BF16, BF16, Row, Col, Row, Row>,
|
||||
std::tuple<BF16, BF16, BF16, BF16, Row, Col, Col, Row>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmGemmBF16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16)
|
||||
{
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, Test_BF16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmBF16, DISABLED_Bench_BF16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{256, 256, 64, 64, 768},
|
||||
{256, 256, 128, 128, 768},
|
||||
{512, 512, 64, 64, 768},
|
||||
{512, 512, 128, 128, 768},
|
||||
{1024, 1024, 64, 64, 768},
|
||||
{1024, 1024, 128, 128, 768},
|
||||
{2048, 2048, 64, 64, 768},
|
||||
{2048, 2048, 128, 128, 768},
|
||||
{4096, 4096, 64, 64, 768},
|
||||
{4096, 4096, 128, 128, 768},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Row, Row>,
|
||||
std::tuple<F16, F16, F16, F16, Row, Col, Col, Row>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16)
|
||||
{
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{136, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 136, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 40, 128, 1},
|
||||
{128, 128, 136, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 136, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{129, 128, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 129, 32, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 33, 128, 1},
|
||||
{128, 128, 129, 128, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
// If kernel B1Layout is RowMajor, expect not to support odd O size
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{128, 128, 32, 129, 1},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = true;
|
||||
this->Run();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
|
||||
{
|
||||
this->lengths_ = std::vector<std::vector<int>>{
|
||||
{256, 256, 64, 64, 768},
|
||||
{256, 256, 128, 128, 768},
|
||||
{512, 512, 64, 64, 768},
|
||||
{512, 512, 128, 128, 768},
|
||||
{1024, 1024, 64, 64, 768},
|
||||
{1024, 1024, 128, 128, 768},
|
||||
{2048, 2048, 64, 64, 768},
|
||||
{2048, 2048, 128, 128, 768},
|
||||
{4096, 4096, 64, 64, 768},
|
||||
{4096, 4096, 128, 128, 768},
|
||||
};
|
||||
this->bench_ = true;
|
||||
this->verify_ = false;
|
||||
this->Run();
|
||||
}
|
||||
@@ -1,8 +1,126 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_batched_gemm_gemm_util.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
|
||||
|
||||
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // StrideC
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
0, // BatchStrideC
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
PassThrough{}, // acc0_element_op
|
||||
PassThrough{}, // b1_element_op
|
||||
PassThrough{}); // c_element_op
|
||||
|
||||
return gemm.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestBatchedGemmGemmFP16 : public TestBatchedGemmGemm<Tuple>
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include <vector>
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp"
|
||||
#include "profiler/profile_batched_gemm_gemm_impl.hpp"
|
||||
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
@@ -13,7 +12,8 @@ using ck::tensor_operation::device::GemmSpecialization;
|
||||
template <ck::index_t N>
|
||||
using I = ck::Number<N>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
@@ -70,120 +70,3 @@ struct TestBatchedGemmGemm : public ::testing::Test
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <GemmSpecialization GemmSpec>
|
||||
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = Row;
|
||||
using B0Layout = Col;
|
||||
using B1Layout = Row;
|
||||
using CLayout = Row;
|
||||
|
||||
using ADataType = F16;
|
||||
using B0DataType = F16;
|
||||
using B1DataType = F16;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = F16;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using B0ElementOp = PassThrough;
|
||||
using Acc0ElementOp = PassThrough;
|
||||
using B1ElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
// static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
|
||||
|
||||
using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
|
||||
ALayout,
|
||||
B0Layout,
|
||||
B1Layout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
B0DataType,
|
||||
B1DataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
AElementOp,
|
||||
B0ElementOp,
|
||||
Acc0ElementOp,
|
||||
B1ElementOp,
|
||||
CElementOp,
|
||||
GemmSpec,
|
||||
1,
|
||||
256,
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
32, // KPerBlock
|
||||
128, // Gemm1NPerBlock
|
||||
32, // Gemm1KPerBlock
|
||||
8, // AK1
|
||||
8, // BK1
|
||||
2, // B1K1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
4, // Gemm1NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<4, 64, 1>, // BBlockTransfer
|
||||
S<1, 0, 2>,
|
||||
S<1, 0, 2>,
|
||||
2,
|
||||
8,
|
||||
8,
|
||||
true,
|
||||
S<8, 32, 1>, // B1BlockTransfer
|
||||
S<0, 2, 1>,
|
||||
S<0, 2, 1>,
|
||||
1,
|
||||
4,
|
||||
2,
|
||||
false,
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
2, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
|
||||
bool IsSupported(int M, int N, int K, int O)
|
||||
{
|
||||
auto gemm = DeviceGemmGemmInstance{};
|
||||
auto invoker = gemm.MakeInvoker();
|
||||
auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
|
||||
static_cast<B0DataType*>(nullptr),
|
||||
static_cast<B1DataType*>(nullptr),
|
||||
static_cast<CDataType*>(nullptr),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
O,
|
||||
0, // BatchCount
|
||||
0, // StrideA
|
||||
0, // StrideB0
|
||||
0, // StrideB1
|
||||
0, // StrideC
|
||||
0, // BatchStrideA
|
||||
0, // BatchStrideB0
|
||||
0, // BatchStrideB1
|
||||
0, // BatchStrideC
|
||||
PassThrough{}, // a_element_op
|
||||
PassThrough{}, // b0_element_op
|
||||
PassThrough{}, // acc0_element_op
|
||||
PassThrough{}, // b1_element_op
|
||||
PassThrough{}); // c_element_op
|
||||
|
||||
return gemm.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user