mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4425 (commit 513cf9f)
[CK] Implement device grouped gemm fixed nk multi abd for rdna4 (#4425) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation Add support for grouped gemm multi ABD fixed NK. MR ## Technical Details Changes from the reverted PR: - Device struct for grouped gemm with multiple ABD and fixed NK (DeviceGroupedGemm_Wmma_Multi_ABD_Fixed_NK). - Wmma versions of existing example codes: 59_grouped_gemm_multi_ABD - Unit tests for both new wmma implementation and the reference xdl code (previously missing) - Note: Some Xdl instances were commented out because of unit test failures. As mentioned apparently for xdl this feature was missing tests so our assumption is either there is an implemenetation bug or these instances were not set up correctly. Has the potential for a follow-up issue. - Generic ck profiler interface with the purpose of calling unit tests. - Gemm instances with specific elementwise operations for gemm bias gelu calculations. - Added class for grouped gemm multi ABD reference calculations. Fix epilogue selection in device implementation that caused unit test failures ## Test Plan Covered by added unit tests ## Test Result CI successfully passing ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
1a2c0d835a
commit
a32d704d89
@@ -24,6 +24,12 @@ if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
target_link_libraries(test_grouped_gemm_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_multi_abd_fixed_nk test_grouped_gemm_multi_abd_fixed_nk.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_multi_abd_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_multi_abd_instance)
|
||||
add_dependencies(test_grouped_gemm test_grouped_gemm_multi_abd_fixed_nk)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
|
||||
|
||||
256
test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
Normal file
256
test/grouped_gemm/test_grouped_gemm_multi_abd_fixed_nk.cpp
Normal file
@@ -0,0 +1,256 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "profiler/profile_grouped_gemm_multi_abd_fixed_nk_impl.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
static ck::index_t param_mask = 0xffffff;
|
||||
static ck::index_t instance_index = -1;
|
||||
|
||||
using FP32 = float;
|
||||
using FP16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using I8 = int8_t;
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddFastGelu = ck::tensor_operation::element_wise::AddFastGelu;
|
||||
using Add = ck::tensor_operation::element_wise::Add;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using FastGelu = ck::tensor_operation::element_wise::FastGelu;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, AddFastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<BF16>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<Row>, Row, Add>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, PassThrough>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Col, Col>, ck::Tuple<>, Row, FastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Col>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, FastGelu>,
|
||||
std::tuple<ck::Tuple<BF16>, ck::Tuple<I8, BF16>, ck::Tuple<>, BF16, ck::Tuple<Row>, ck::Tuple<Row, Row>, ck::Tuple<>, Row, FastGelu>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmMultiABDFixedNK : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using AsDataType = std::tuple_element_t<0, Tuple>;
|
||||
using BsDataType = std::tuple_element_t<1, Tuple>;
|
||||
using DsDataType = std::tuple_element_t<2, Tuple>;
|
||||
using EDataType = std::tuple_element_t<3, Tuple>;
|
||||
using AccDataType = float;
|
||||
using AsLayout = std::tuple_element_t<4, Tuple>;
|
||||
using BsLayout = std::tuple_element_t<5, Tuple>;
|
||||
using DsLayout = std::tuple_element_t<6, Tuple>;
|
||||
using ELayout = std::tuple_element_t<7, Tuple>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = Multiply;
|
||||
using CDEElementOp = std::tuple_element_t<8, Tuple>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // integer value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
static constexpr int n_warmup_ = 0;
|
||||
static constexpr int n_iter_ = 1;
|
||||
|
||||
std::vector<int> k_batches_ = {1};
|
||||
|
||||
private:
|
||||
template <typename Layout>
|
||||
void SetStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if(std::is_same_v<Layout, Row>)
|
||||
{
|
||||
for(const auto c : cols)
|
||||
{
|
||||
strides.emplace_back(c);
|
||||
}
|
||||
}
|
||||
else if(std::is_same_v<Layout, Col>)
|
||||
{
|
||||
for(const auto r : rows)
|
||||
{
|
||||
strides.emplace_back(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Layouts>
|
||||
void SetTupleStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if constexpr(Layouts::Size() > 0)
|
||||
{
|
||||
// As of now multi ABD implementation supports only tensors with matching layouts.
|
||||
using Layout = ck::remove_cvref_t<ck::tuple_element_t<ck::Number<0>{}, Layouts>>;
|
||||
SetStrides<Layout>(strides, rows, cols);
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs = {},
|
||||
const std::vector<int>& StrideBs = {},
|
||||
const std::vector<int>& StrideDs = {},
|
||||
const std::vector<int>& StrideE = {})
|
||||
{
|
||||
std::vector<int> stride_as = StrideAs;
|
||||
std::vector<int> stride_bs = StrideBs;
|
||||
std::vector<int> stride_ds = StrideDs;
|
||||
std::vector<int> stride_e = StrideE;
|
||||
|
||||
if(stride_as.empty())
|
||||
{
|
||||
SetTupleStrides<AsLayout>(stride_as, Ms, Ks);
|
||||
}
|
||||
if(stride_bs.empty())
|
||||
{
|
||||
SetTupleStrides<BsLayout>(stride_bs, Ks, Ns);
|
||||
}
|
||||
if(stride_ds.empty())
|
||||
{
|
||||
SetTupleStrides<DsLayout>(stride_ds, Ms, Ns);
|
||||
}
|
||||
if(stride_e.empty())
|
||||
{
|
||||
SetStrides<ELayout>(stride_e, Ms, Ns);
|
||||
}
|
||||
|
||||
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_e);
|
||||
}
|
||||
|
||||
void RunSingle(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<int>& StrideDs,
|
||||
const std::vector<int>& StrideE)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_gemm_multi_abd_fixed_nk_impl<AsDataType,
|
||||
BsDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AccDataType,
|
||||
AsLayout,
|
||||
BsLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
k_batches_,
|
||||
n_warmup_,
|
||||
n_iter_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmMultiABDFixedNK, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{3, 4};
|
||||
constexpr int N = 8;
|
||||
constexpr int K = 64;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{3, 5, 16, 7, 8};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmMultiABDFixedNK, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1)
|
||||
{
|
||||
// Run with default arguments.
|
||||
}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
Reference in New Issue
Block a user