Implement grouped gemm tile loop for RDNA4 (#3304)

* feat: grouped gemm tile loop support for RDNA4

* fix: removed extra parameter from grouped gemm example instance

* fix: FP8 check incorrectly enabling FP8 on RDNA3

[ROCm/composable_kernel commit: eb041079a3]
This commit is contained in:
Erwin Terpstra
2026-01-13 07:14:23 +01:00
committed by GitHub
parent 0d13ef7329
commit d69aeffd0d
44 changed files with 3067 additions and 1223 deletions

View File

@@ -0,0 +1,18 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_custom_target(test_grouped_gemm_tile_loop)
if (CK_USE_XDL OR CK_USE_WMMA)
add_gtest_executable(test_grouped_gemm_tile_loop_vanilla test_grouped_gemm_tile_loop.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_tile_loop_vanilla PRIVATE utility device_grouped_gemm_tile_loop_instance)
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_vanilla)
endif()
add_gtest_executable(test_grouped_gemm_tile_loop_multiply test_grouped_gemm_tile_loop_multiply.cpp)
if(result EQUAL 0)
target_link_libraries(test_grouped_gemm_tile_loop_multiply PRIVATE utility device_grouped_gemm_tile_loop_instance)
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_multiply)
endif()
endif()

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_tile_loop_util.hpp"
ck::index_t param_mask = 0xffffff;
ck::index_t instance_index = -1;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
ck::Tuple<Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>,
ck::Tuple<Row, Col, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
else if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "example/68_gemm_add/common.hpp"
#include "test_grouped_gemm_tile_loop_util.hpp"
ck::index_t param_mask = 0xffffff;
ck::index_t instance_index = -1;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Multiply = ck::tensor_operation::element_wise::Multiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
template <typename Tuple>
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Multiply>,
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAdd>,
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAddFastGelu>,
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, MultiplyFastGelu>
>;
// clang-format on
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
else if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}

View File

@@ -0,0 +1,64 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
TYPED_TEST(TestGroupedGemmTileLoop, TinyCases)
{
const std::vector<int> Ms{2, 1};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmTileLoop, SmallCases)
{
const std::vector<int> Ms{2, 1, 3, 4, 5};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmTileLoop, MidCases)
{
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
constexpr int N = 768;
constexpr int K = 544;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmTileLoop, Regular)
{
const std::vector<int> Ms{64, 128, 256};
constexpr int N = 768;
constexpr int K = 320;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}
TYPED_TEST(TestGroupedGemmTileLoop, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
const std::vector<int> Ns(Ms.size(), N);
const std::vector<int> Ks(Ms.size(), K);
this->Run(Ms, Ns, Ks);
}

View File

@@ -0,0 +1,173 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp"
#include "profiler/profile_grouped_gemm_tile_loop_generic_impl.hpp"
extern ck::index_t param_mask;
extern ck::index_t instance_index;
namespace ck {
namespace test {
template <typename Tuple, bool FailIfNoSupportedInstances = false>
class TestGroupedGemmTileLoop : public testing::Test
{
protected:
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = tuple_element_t<0, Tuple>;
using BLayout = tuple_element_t<1, Tuple>;
using DsLayout = tuple_element_t<2, Tuple>;
using ELayout = tuple_element_t<3, Tuple>;
using ADataType = tuple_element_t<4, Tuple>;
using BDataType = tuple_element_t<5, Tuple>;
using DsDataType = tuple_element_t<6, Tuple>;
using EDataType = tuple_element_t<7, Tuple>;
using AElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
using BElementOp = tuple_element_or_t<9, Tuple, PassThrough>;
using CDEElementOp = tuple_element_or_t<10, Tuple, PassThrough>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
static constexpr auto NumDTensor = DsLayout::Size();
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<std::array<int, NumDTensor>>& StrideDs = {},
const std::vector<int>& StrideEs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<std::array<int, NumDTensor>> stride_ds = StrideDs;
std::vector<int> stride_es = StrideEs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_ds.empty())
{
for(size_t group = 0; group < Ms.size(); ++group)
{
std::array<int, NumDTensor> d_strides;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = tuple_element_t<i, DsLayout>;
if(std::is_same_v<DLayout, Row>)
{
d_strides[i] = Ns[group];
}
else if(std::is_same_v<DLayout, Col>)
{
d_strides[i] = Ms[group];
}
});
stride_ds.emplace_back(d_strides);
}
}
if(stride_es.empty())
{
SetStrides<ELayout>(stride_es, Ms, Ns);
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_es);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<std::array<int, NumDTensor>>& StrideDs,
const std::vector<int>& StrideEs)
{
bool pass =
ck::profiler::profile_grouped_gemm_tile_loop_generic_impl<ADataType,
BDataType,
DsDataType,
EDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideDs,
StrideEs,
n_warmup_,
n_iter_);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck