Wmma support for gemm_bias_add_reduce (#3316)

* Add tests for gemm_bias_add_reduce

* Initial working implementation

* Generalize implementation of reduce epilogue

* Add tests for all layouts

* Add instances

* Fix test archs

* Fix xdl bug

* Remove library/profiler duplications

* Fix num_byted error profiler

* Fix typos

* Fix copyright

[ROCm/composable_kernel commit: aad4cf0985]
This commit is contained in:
Enrico Degregori
2026-01-07 19:27:16 +01:00
committed by GitHub
parent d074af36c9
commit 6eab5bea54
15 changed files with 1424 additions and 141 deletions

View File

@@ -258,6 +258,7 @@ add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm)
add_subdirectory(gemm_add)
add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_blockscale_wp)
add_subdirectory(gemm_layernorm)
add_subdirectory(gemm_multi_abd)

View File

@@ -0,0 +1,9 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance)
endif()
endif()

View File

@@ -0,0 +1,106 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <tuple>
#include "gtest/gtest.h"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "test_gemm_common.hpp"
using F16 = ck::half_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
namespace {
template <typename X, typename Y>
struct tuple_concat;
template <typename... Xs, typename... Ys>
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
{
using type = std::tuple<Xs..., Ys...>;
};
} // namespace
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_MK_NK
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_MK_KN
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_KM_KN
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
{
};
template <typename Tuple>
class TestGemmBiasAddReduce_FP16_KM_NK
: public ck::test::TestGemmBiasAddReduceCommon<
typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
{
};
// clang-format off
using KernelTypes = ::testing::Types<
std::tuple< F16, F16, F16, F16, F16, F32>
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_KN, KernelTypes);
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_NK, KernelTypes);
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 512;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_KN, Regular)
{
std::vector<int> Ms{512};
constexpr int N = 1024;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_KN, Regular)
{
std::vector<int> Ms{256};
constexpr int N = 512;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_NK, Regular)
{
std::vector<int> Ms{256};
constexpr int N = 1024;
constexpr int K = 1024;
for(int M : Ms)
this->Run(M, N, K);
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_bias_add_reduce_impl.hpp"
namespace ck {
namespace test {
using Row = ck::tensor_layout::gemm::RowMajor;
using F32 = float;
template <typename Tuple>
class TestGemmBiasAddReduceCommon : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = Row;
using ADataType = std::tuple_element_t<2, Tuple>;
using BDataType = std::tuple_element_t<3, Tuple>;
using CDataType = std::tuple_element_t<4, Tuple>;
using BiasDataType = std::tuple_element_t<5, Tuple>;
using D0DataType = std::tuple_element_t<6, Tuple>;
using ReduceDataType = std::tuple_element_t<7, Tuple>;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
void Run(const int M, const int N, const int K)
{
bool all_success = true;
int StrideA = std::is_same_v<remove_cvref_t<ALayout>, Row> ? K : M;
int StrideB = std::is_same_v<remove_cvref_t<BLayout>, Row> ? N : K;
int StrideD0 = std::is_same_v<remove_cvref_t<CLayout>, Row> ? N : M;
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
all_success =
all_success &
ck::profiler::profile_gemm_bias_add_reduce_impl<ADataType,
BDataType,
CDataType,
BiasDataType,
D0DataType,
ReduceDataType,
ALayout,
BLayout,
CLayout>(
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, StrideD0);
EXPECT_TRUE(all_success);
}
};
} // namespace test
} // namespace ck