mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Wmma support for gemm_bias_add_reduce (#3316)
* Add tests for gemm_bias_add_reduce
* Initial working implementation
* Generalize implementation of reduce epilogue
* Add tests for all layouts
* Add instances
* Fix test archs
* Fix xdl bug
* Remove library/profiler duplications
* Fix num_byted error profiler
* Fix typos
* Fix copyright
[ROCm/composable_kernel commit: aad4cf0985]
This commit is contained in:
@@ -258,6 +258,7 @@ add_subdirectory(conv_util)
|
||||
add_subdirectory(reference_conv_fwd)
|
||||
add_subdirectory(gemm)
|
||||
add_subdirectory(gemm_add)
|
||||
add_subdirectory(gemm_bias_add_reduce)
|
||||
add_subdirectory(gemm_blockscale_wp)
|
||||
add_subdirectory(gemm_layernorm)
|
||||
add_subdirectory(gemm_multi_abd)
|
||||
|
||||
9
test/gemm_bias_add_reduce/CMakeLists.txt
Normal file
9
test/gemm_bias_add_reduce/CMakeLists.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
|
||||
add_gtest_executable(test_gemm_bias_add_reduce_fp16 test_gemm_bias_add_reduce_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_gemm_bias_add_reduce_fp16 PRIVATE utility device_gemm_bias_add_reduce_instance)
|
||||
endif()
|
||||
endif()
|
||||
106
test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
Normal file
106
test/gemm_bias_add_reduce/test_gemm_bias_add_reduce_fp16.cpp
Normal file
@@ -0,0 +1,106 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "test_gemm_common.hpp"
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename X, typename Y>
|
||||
struct tuple_concat;
|
||||
|
||||
template <typename... Xs, typename... Ys>
|
||||
struct tuple_concat<std::tuple<Xs...>, std::tuple<Ys...>>
|
||||
{
|
||||
using type = std::tuple<Xs..., Ys...>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBiasAddReduce_FP16_MK_NK
|
||||
: public ck::test::TestGemmBiasAddReduceCommon<
|
||||
typename tuple_concat<std::tuple<Row, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBiasAddReduce_FP16_MK_KN
|
||||
: public ck::test::TestGemmBiasAddReduceCommon<
|
||||
typename tuple_concat<std::tuple<Row, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBiasAddReduce_FP16_KM_KN
|
||||
: public ck::test::TestGemmBiasAddReduceCommon<
|
||||
typename tuple_concat<std::tuple<Col, Row>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBiasAddReduce_FP16_KM_NK
|
||||
: public ck::test::TestGemmBiasAddReduceCommon<
|
||||
typename tuple_concat<std::tuple<Col, Col>, Tuple>::type>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple< F16, F16, F16, F16, F16, F32>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_NK, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_MK_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_KN, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGemmBiasAddReduce_FP16_KM_NK, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmBiasAddReduce_FP16_MK_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{512};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 1024;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_KN, Regular)
|
||||
{
|
||||
std::vector<int> Ms{256};
|
||||
constexpr int N = 512;
|
||||
constexpr int K = 1024;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGemmBiasAddReduce_FP16_KM_NK, Regular)
|
||||
{
|
||||
std::vector<int> Ms{256};
|
||||
constexpr int N = 1024;
|
||||
constexpr int K = 1024;
|
||||
|
||||
for(int M : Ms)
|
||||
this->Run(M, N, K);
|
||||
}
|
||||
61
test/gemm_bias_add_reduce/test_gemm_common.hpp
Normal file
61
test/gemm_bias_add_reduce/test_gemm_common.hpp
Normal file
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ck/ck.hpp"
|
||||
#include "profiler/profile_gemm_bias_add_reduce_impl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using F32 = float;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGemmBiasAddReduceCommon : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = Row;
|
||||
using ADataType = std::tuple_element_t<2, Tuple>;
|
||||
using BDataType = std::tuple_element_t<3, Tuple>;
|
||||
using CDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BiasDataType = std::tuple_element_t<5, Tuple>;
|
||||
using D0DataType = std::tuple_element_t<6, Tuple>;
|
||||
using ReduceDataType = std::tuple_element_t<7, Tuple>;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // integer value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
|
||||
void Run(const int M, const int N, const int K)
|
||||
{
|
||||
bool all_success = true;
|
||||
|
||||
int StrideA = std::is_same_v<remove_cvref_t<ALayout>, Row> ? K : M;
|
||||
int StrideB = std::is_same_v<remove_cvref_t<BLayout>, Row> ? N : K;
|
||||
int StrideD0 = std::is_same_v<remove_cvref_t<CLayout>, Row> ? N : M;
|
||||
int StrideC = std::is_same_v<CLayout, Row> ? N : M;
|
||||
|
||||
all_success =
|
||||
all_success &
|
||||
ck::profiler::profile_gemm_bias_add_reduce_impl<ADataType,
|
||||
BDataType,
|
||||
CDataType,
|
||||
BiasDataType,
|
||||
D0DataType,
|
||||
ReduceDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
verify_, init_method_, log_, bench_, M, N, K, StrideA, StrideB, StrideC, StrideD0);
|
||||
|
||||
EXPECT_TRUE(all_success);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user