mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4
* fix: removed extra parameter from grouped gemm example instance
* fix: FP8 check incorrectly enabling FP8 on RDNA3
[ROCm/composable_kernel commit: eb041079a3]
This commit is contained in:
18
test/grouped_gemm_tile_loop/CMakeLists.txt
Normal file
18
test/grouped_gemm_tile_loop/CMakeLists.txt
Normal file
@@ -0,0 +1,18 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_custom_target(test_grouped_gemm_tile_loop)
|
||||
|
||||
if (CK_USE_XDL OR CK_USE_WMMA)
|
||||
add_gtest_executable(test_grouped_gemm_tile_loop_vanilla test_grouped_gemm_tile_loop.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_tile_loop_vanilla PRIVATE utility device_grouped_gemm_tile_loop_instance)
|
||||
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_vanilla)
|
||||
endif()
|
||||
|
||||
add_gtest_executable(test_grouped_gemm_tile_loop_multiply test_grouped_gemm_tile_loop_multiply.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_gemm_tile_loop_multiply PRIVATE utility device_grouped_gemm_tile_loop_instance)
|
||||
add_dependencies(test_grouped_gemm_tile_loop test_grouped_gemm_tile_loop_multiply)
|
||||
endif()
|
||||
endif()
|
||||
52
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop.cpp
Normal file
52
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test_grouped_gemm_tile_loop_util.hpp"
|
||||
|
||||
ck::index_t param_mask = 0xffffff;
|
||||
ck::index_t instance_index = -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
ck::Tuple<Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>,
|
||||
ck::Tuple<Row, Col, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "example/68_gemm_add/common.hpp"
|
||||
#include "test_grouped_gemm_tile_loop_util.hpp"
|
||||
|
||||
ck::index_t param_mask = 0xffffff;
|
||||
ck::index_t instance_index = -1;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using I8 = int8_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Multiply = ck::tensor_operation::element_wise::Multiply;
|
||||
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
|
||||
using MultiplyAddFastGelu = ck::tensor_operation::element_wise::MultiplyAddFastGelu;
|
||||
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedGemmTileLoop : public ck::test::TestGroupedGemmTileLoop<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, Multiply>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAdd>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row, Row>, Row, BF16, I8, ck::Tuple<BF16, BF16>, BF16, PassThrough, PassThrough, MultiplyAddFastGelu>,
|
||||
ck::Tuple<Row, Row, ck::Tuple<Row>, Row, BF16, I8, ck::Tuple<BF16>, BF16, PassThrough, PassThrough, MultiplyFastGelu>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedGemmTileLoop, KernelTypes);
|
||||
|
||||
#include "test_grouped_gemm_tile_loop_ut_cases.inc"
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if(argc == 1) {}
|
||||
else if(argc == 3)
|
||||
{
|
||||
param_mask = strtol(argv[1], nullptr, 0);
|
||||
instance_index = atoi(argv[2]);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Usage of " << argv[0] << std::endl;
|
||||
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, TinyCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, SmallCases)
|
||||
{
|
||||
const std::vector<int> Ms{2, 1, 3, 4, 5};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, MidCases)
|
||||
{
|
||||
const std::vector<int> Ms{167, 183, 177, 153, 139, 204};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 544;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, Regular)
|
||||
{
|
||||
const std::vector<int> Ms{64, 128, 256};
|
||||
constexpr int N = 768;
|
||||
constexpr int K = 320;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedGemmTileLoop, MNKPadded)
|
||||
{
|
||||
const std::vector<int> Ms{127, 150, 188, 210};
|
||||
constexpr int N = 136;
|
||||
constexpr int K = 280;
|
||||
|
||||
const std::vector<int> Ns(Ms.size(), N);
|
||||
const std::vector<int> Ks(Ms.size(), K);
|
||||
|
||||
this->Run(Ms, Ns, Ks);
|
||||
}
|
||||
173
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop_util.hpp
Normal file
173
test/grouped_gemm_tile_loop/test_grouped_gemm_tile_loop_util.hpp
Normal file
@@ -0,0 +1,173 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/tuple.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp"
|
||||
#include "profiler/profile_grouped_gemm_tile_loop_generic_impl.hpp"
|
||||
|
||||
extern ck::index_t param_mask;
|
||||
extern ck::index_t instance_index;
|
||||
|
||||
namespace ck {
|
||||
namespace test {
|
||||
|
||||
template <typename Tuple, bool FailIfNoSupportedInstances = false>
|
||||
class TestGroupedGemmTileLoop : public testing::Test
|
||||
{
|
||||
protected:
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ALayout = tuple_element_t<0, Tuple>;
|
||||
using BLayout = tuple_element_t<1, Tuple>;
|
||||
using DsLayout = tuple_element_t<2, Tuple>;
|
||||
using ELayout = tuple_element_t<3, Tuple>;
|
||||
using ADataType = tuple_element_t<4, Tuple>;
|
||||
using BDataType = tuple_element_t<5, Tuple>;
|
||||
using DsDataType = tuple_element_t<6, Tuple>;
|
||||
using EDataType = tuple_element_t<7, Tuple>;
|
||||
using AElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
|
||||
using BElementOp = tuple_element_or_t<9, Tuple, PassThrough>;
|
||||
using CDEElementOp = tuple_element_or_t<10, Tuple, PassThrough>;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
static constexpr auto NumDTensor = DsLayout::Size();
|
||||
|
||||
public:
|
||||
static constexpr bool verify_ = true;
|
||||
static constexpr int init_method_ = 1; // integer value initialization
|
||||
static constexpr bool log_ = false;
|
||||
static constexpr bool bench_ = false; // measure kernel performance
|
||||
static constexpr int n_warmup_ = 0;
|
||||
static constexpr int n_iter_ = 1;
|
||||
|
||||
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
|
||||
|
||||
private:
|
||||
template <typename Layout>
|
||||
void SetStrides(std::vector<int>& strides,
|
||||
const std::vector<int>& rows,
|
||||
const std::vector<int>& cols) const
|
||||
{
|
||||
if(std::is_same_v<Layout, Row>)
|
||||
{
|
||||
for(const auto c : cols)
|
||||
{
|
||||
strides.emplace_back(c);
|
||||
}
|
||||
}
|
||||
else if(std::is_same_v<Layout, Col>)
|
||||
{
|
||||
for(const auto r : rows)
|
||||
{
|
||||
strides.emplace_back(r);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs = {},
|
||||
const std::vector<int>& StrideBs = {},
|
||||
const std::vector<std::array<int, NumDTensor>>& StrideDs = {},
|
||||
const std::vector<int>& StrideEs = {})
|
||||
{
|
||||
std::vector<int> stride_as = StrideAs;
|
||||
std::vector<int> stride_bs = StrideBs;
|
||||
std::vector<std::array<int, NumDTensor>> stride_ds = StrideDs;
|
||||
std::vector<int> stride_es = StrideEs;
|
||||
|
||||
if(stride_as.empty())
|
||||
{
|
||||
SetStrides<ALayout>(stride_as, Ms, Ks);
|
||||
}
|
||||
if(stride_bs.empty())
|
||||
{
|
||||
SetStrides<BLayout>(stride_bs, Ks, Ns);
|
||||
}
|
||||
|
||||
if(stride_ds.empty())
|
||||
{
|
||||
for(size_t group = 0; group < Ms.size(); ++group)
|
||||
{
|
||||
std::array<int, NumDTensor> d_strides;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = tuple_element_t<i, DsLayout>;
|
||||
|
||||
if(std::is_same_v<DLayout, Row>)
|
||||
{
|
||||
d_strides[i] = Ns[group];
|
||||
}
|
||||
else if(std::is_same_v<DLayout, Col>)
|
||||
{
|
||||
d_strides[i] = Ms[group];
|
||||
}
|
||||
});
|
||||
|
||||
stride_ds.emplace_back(d_strides);
|
||||
}
|
||||
}
|
||||
|
||||
if(stride_es.empty())
|
||||
{
|
||||
SetStrides<ELayout>(stride_es, Ms, Ns);
|
||||
}
|
||||
|
||||
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_ds, stride_es);
|
||||
}
|
||||
|
||||
void RunSingle(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
const std::vector<int>& Ks,
|
||||
const std::vector<int>& StrideAs,
|
||||
const std::vector<int>& StrideBs,
|
||||
const std::vector<std::array<int, NumDTensor>>& StrideDs,
|
||||
const std::vector<int>& StrideEs)
|
||||
{
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_gemm_tile_loop_generic_impl<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CDEElementOp>(verify_,
|
||||
init_method_,
|
||||
log_,
|
||||
bench_,
|
||||
Ms,
|
||||
Ns,
|
||||
Ks,
|
||||
StrideAs,
|
||||
StrideBs,
|
||||
StrideDs,
|
||||
StrideEs,
|
||||
n_warmup_,
|
||||
n_iter_);
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user