Add grouped gemm instances for RDNA4 (#3237)

* wip: grouped_gemm implementation based on wmma kernel + example for fp16

* chore: clean up grouped_gemm_wmma_splitk_fp16 example

* chore: add cmake options to fully disable XDL or WMMA kernels

* feat: add tests for grouped gemm wmma instances for f16 and bf16 (all layouts)

* chore: add grouped gemm wmma bf16 example

* refactor: reuse more code between instance factory functions

* chore: turn test failure if not all batch sizes are supported into a warning

* chore: make test failure on unsupported instances conditional so as not to break old tests

* chore: add log message to failure case where AK1/BK1/KBatch is too high for K value

* fix: issue with new overloads of GridwiseGemm_wmma_cshuffle_v3::Run()

* fix: stray comma after parameter list

* fix: compilation issues on RDNA3 and tests failing due to unsupported problems still being run

* chore: update copyright in header comments

* nit: minor feedback

* refactor: unified XDL / WMMA tests

* fix: properly disable FP8 instances when ONLY targeting gfx11

* refactor: add v3 suffix to grouped_gemm device struct name

* fix: small typos in example code

* fix: fully exclude xdl/wmma instances when using the corresponding cmake flags

* chore: remove unused destructor and add pipeline support checks to eliminate unnecessary paths

* fix: make sure to not add instance library to group if library was skipped

* fix: make sure xdl grouped gemm doesn't fail the new test

* fix: explicitly exclude test if no xdl/wmma support, as pattern matching fails in this case

* fix: examples not working since dependent types and functions were moved to ck namespace in develop

* fix: tests failing when compiling for just gfx11 due to trying to run unsupported instances

* chore: replace/add copyright headers with new format
Author: Erwin Terpstra
Date: 2025-12-02 00:32:10 +01:00
Committed by: GitHub
Parent: 23fb253c4e
Commit: 46f1d740f0
30 changed files with 2291 additions and 268 deletions


@@ -3,10 +3,15 @@
add_custom_target(test_grouped_gemm)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary
# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link
# the instance library if there are no instances present for the current arch.
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance)
add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
endif()
endif()
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)


@@ -9,6 +9,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "test_grouped_gemm_util.hpp"
#include "test_grouped_gemm_interface_xdl.hpp"
class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test
{


@@ -0,0 +1,205 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace ck {
namespace test {
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
16,
16,
8,
4,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
if(kbatch > 1 && ck::is_gfx11_supported())
{
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
return 0;
}
else
{
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
}
};
} // namespace test
} // namespace ck
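
The wrapper above exposes a single split-K grouped GEMM instance behind IsSupported() and Run(). Below is a minimal sketch of how a test might consume it, assuming this new header is the test_grouped_gemm_interface_xdl.hpp included earlier; the template arguments (MNKPadding specialization, KPerBlock = 32, K1 = 8, vector widths of 8) and the problem sizes are illustrative choices, not values taken from this commit.

```cpp
// Illustrative only: template arguments and problem sizes are assumptions.
#include <vector>
#include <gtest/gtest.h>

#include "test_grouped_gemm_interface_xdl.hpp"

TEST(TestGGemmSplitKInterfaceUsage, RowRowRowSingleGroup)
{
    using Row = ck::tensor_layout::gemm::RowMajor;

    // Row/Row/Row layouts, MNKPadding specialization, KPerBlock = 32, K1 = 8,
    // and scalar-per-vector widths of 8 for A, B and the C/D/E shuffle.
    using Wrapper = ck::test::DeviceGroupedGemmSplitkInstanceWrapper<
        Row,
        Row,
        Row,
        ck::tensor_operation::device::GemmSpecialization::MNKPadding,
        32, // KPerBlock
        8,  // K1
        8,  // ABlockTransferSrcScalarPerVector
        8,  // BBlockTransferSrcScalarPerVector
        8>; // CDEBlockTransferScalarPerVector_NPerBlock

    const Wrapper wrapper{};
    const std::vector<int> Ms{256}, Ns{256}, Ks{512};
    // Row-major strides: A is MxK (stride K), B is KxN (stride N), C is MxN (stride N).
    const std::vector<int> StrideAs{512}, StrideBs{256}, StrideCs{256};

    // Only launch when the instance accepts this problem; Run() itself asserts
    // support (or, for split-K on gfx11, lack thereof) via EXPECT_* internally.
    if(wrapper.IsSupported(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, /*kbatch=*/2))
    {
        wrapper.Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, /*kbatch=*/2);
    }
}
```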


@@ -24,21 +24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGroupedGemm : public ck::test::TestGroupedGemm<Tuple>
{
public:
void SetUp() override
{
ck::test::TestGroupedGemm<Tuple>::SetUp();
#if defined(CK_USE_WMMA)
// The old XDL tests didn't fail if instances were not supported, so we want to keep that
// behaviour. When compiling WMMA instances and WMMA is supported, then we'll fail if a
// specific case is not supported
this->fail_if_no_supported_instances_ =
ck::is_gfx11_supported() || ck::is_gfx12_supported();
#endif
}
};
// clang-format off
using KernelTypes = ::testing::Types<
#if defined(CK_USE_WMMA)
// WMMA only. No reason not to have it for XDL, but the instance was not defined and it was not in the original test.
std::tuple< Col, Col, Row, BF16, BF16, BF16>,
#endif
#if defined(CK_USE_XDL) && defined(__gfx9__)
// XDL only at the moment, instances for WMMA not defined
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
#endif
#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__))
std::tuple< Row, Row, Row, F8, F16, F16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
#endif
std::tuple< Row, Row, Row, F16, F16, F16>,
std::tuple< Row, Col, Row, F16, F16, F16>,
std::tuple< Col, Row, Row, F16, F16, F16>,
std::tuple< Col, Col, Row, F16, F16, F16>,
std::tuple< Row, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Col, Row, BF16, BF16, BF16>,
std::tuple< Col, Row, Row, BF16, BF16, BF16>,
std::tuple< Row, Row, Row, BF16, I8, BF16>,
std::tuple< Row, Col, Row, BF16, I8, BF16>,
std::tuple< Row, Row, Row, F16, F8, F16>,
std::tuple< Row, Row, Row, F8, F16, F16>
std::tuple< Col, Row, Row, BF16, BF16, BF16>
>;
// clang-format on
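
For context, a minimal sketch of how a tuple list like the one above is consumed by the typed test suite; the TYPED_TEST_SUITE registration and the trivial test body are assumptions about the unshown parts of this file, not lines from the commit.

```cpp
// Sketch only: register the fixture with every enabled layout/data-type tuple.
// The real file defines the actual cases (MNKPadded, TestLargeKBatch, ...) below.
TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);

TYPED_TEST(TestGroupedGemm, InstantiatesForEveryTuple)
{
    // Each KernelTypes entry supplies <ALayout, BLayout, ELayout, AData, BData, EData>
    // to the fixture, so an empty body already checks that every enabled
    // combination compiles and the fixture's SetUp() runs for it.
}
```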


@@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded)
TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still run the tests for fp32, but we currently don't have instances for
// it so we disable it entirely
if(ck::is_gfx11_supported())
GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add "
"instructions";
const std::vector<int> Ms{188, 210};
constexpr int N = 768;
constexpr int K = 4096;


@@ -11,16 +11,7 @@
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/number.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
extern ck::index_t param_mask;
@@ -41,7 +32,7 @@ std::string serialize_range(const Range& range)
return std::string(str.begin(), str.end() - 2);
}
template <typename Tuple>
template <typename Tuple, bool FailIfNoSupportedInstances = false>
class TestGroupedGemm : public testing::Test
{
protected:
@@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
std::vector<int> k_batches_;
void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; }
void SetUp() override
{
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
if(require_16bit_atomic_add && ck::is_gfx11_supported())
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
k_batches_ = {1};
}
else
{
k_batches_ = {1, 2, 3, 5, 8};
}
}
private:
template <typename Layout>
@@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass = ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index);
bool pass =
ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
EXPECT_TRUE(pass);
}
};
template <typename ALayout,
typename BLayout,
typename ELayout,
tensor_operation::device::GemmSpecialization GemmSpec,
ck::index_t KPerBlock,
ck::index_t K1,
ck::index_t ABlockTransferSrcScalarPerVector,
ck::index_t BBlockTransferSrcScalarPerVector,
index_t CDEBlockTransferScalarPerVector_NPerBlock>
struct DeviceGroupedGemmSplitkInstanceWrapper
{
using F16 = half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = tensor_operation::element_wise::PassThrough;
using EmptyTuple = ck::Tuple<>;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
template <ck::index_t N>
using I = ck::Number<N>;
using ABlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<ALayout, Row>, S<0, 2, 1, 3>, S<0, 1, 3, 2>>;
using ABlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<ALayout, Row>, I<3>, I<2>>;
using ABlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<8>, I<2>>;
using ABlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<1>, I<0>>;
using BBlockTransferThreadClusterArrageOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcAccessOrder =
std::conditional_t<std::is_same_v<BLayout, Row>, S<0, 1, 3, 2>, S<0, 2, 1, 3>>;
using BBlockTransferSrcVectorDim = std::conditional_t<std::is_same_v<BLayout, Row>, I<2>, I<3>>;
using BBlockTransferDstScalarPerVector_K1 =
std::conditional_t<std::is_same_v<ALayout, Row>, I<2>, I<8>>;
using BBlockLdsAddExtraM = std::conditional_t<std::is_same_v<ALayout, Row>, I<0>, I<1>>;
using DeviceGroupedGemmSplitKInstance =
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle<
ALayout,
BLayout,
EmptyTuple,
ELayout,
F16,
F16,
F32,
F16,
EmptyTuple,
F16,
PassThrough,
PassThrough,
PassThrough,
GemmSpec,
1,
128,
128,
128,
KPerBlock,
K1,
K1,
16,
16,
8,
4,
S<1, 4, 16, 1>,
ABlockTransferThreadClusterArrageOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim::value,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K1::value,
ABlockLdsAddExtraM::value,
S<1, 4, 16, 1>,
BBlockTransferThreadClusterArrageOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim::value,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_K1::value,
BBlockLdsAddExtraM::value,
1,
1,
S<1, 16, 1, 8>,
CDEBlockTransferScalarPerVector_NPerBlock>;
bool IsSupported(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
return ggemm_instance.IsSupportedArgument(argument);
}
float Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
int kbatch = 1) const
{
std::size_t n_groups = Ms.size();
EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups &&
StrideBs.size() == n_groups && StrideCs.size() == n_groups)
<< "The number of groups is not consistent!";
std::vector<tensor_operation::device::GemmDesc> gemm_descs;
for(std::size_t i = 0; i < n_groups; ++i)
{
gemm_descs.push_back(tensor_operation::device::GemmDesc{
Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
}
std::vector<const void*> p_As(n_groups, nullptr);
std::vector<const void*> p_Bs(n_groups, nullptr);
std::vector<void*> p_Cs(n_groups, nullptr);
auto p_Ds = std::vector<std::array<const void*, 0>>{};
auto ggemm_instance = DeviceGroupedGemmSplitKInstance{};
auto argument = ggemm_instance.MakeArgument(
p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});
if(kbatch > 1)
{
ggemm_instance.SetKBatchSize(&argument, kbatch);
}
if(kbatch > 1 && ck::is_gfx11_supported())
{
EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument));
return 0;
}
else
{
EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument));
auto invoker = ggemm_instance.MakeInvoker();
DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument));
ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer());
return invoker.Run(argument, StreamConfig{nullptr, false});
}
}
};
} // namespace test
} // namespace ck