mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
173 implement device grouped gemm fixed nk for rdna4 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Proposed changes This PR adds an RDNA4 implementation of the device_grouped_gemm_fixed_nk instance library for WMMA. The implementation is based on the existing DeviceGroupedGemm_Xdl_Fixed_NK design and reuses the same high-level structure, but replaces the XDL kernel with a WMMA-based one. It uses the GridwiseGemm_wmma_cshuffle_v3 kernel. At this stage, the focus is functional correctness and compatibility, not performance tuning. ## Technical Details - Device struct for grouped gemm fixed NK - Example code for the WMMA version - Unit tests for both the new WMMA implementation and the reference XDL code (previously missing) - Generic ck profiler interface with the purpose of calling unit tests. ## Checklist Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask. - [x] I have added tests relevant to the introduced functionality, and the unit tests are passing locally - [x] I have added the test to REGRESSION_TESTS list defined at the top of CMakeLists.txt in tests/CMakeLists.txt, **IF** the test takes more than 30 seconds to run. - [ ] I have added inline documentation which enables the maintainers with understanding the motivation - [ ] I have removed the stale documentation which is no longer relevant after this pull request - [x] (If this change is user-facing) I have added release notes which provide the end users with a brief summary of the improvement from this pull request - [x] I have run `clang-format` on all changed files - [x] Any dependent changes have been merged ## Discussion If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered
328 lines
13 KiB
C++
328 lines
13 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <array>
|
|
#include <string>
|
|
#include <sstream>
|
|
#include <tuple>
|
|
#include <type_traits>
|
|
#include <vector>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "ck/ck.hpp"
|
|
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
|
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
|
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
|
#include "profiler/profile_grouped_gemm_impl.hpp"
|
|
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"
|
|
|
|
extern ck::index_t param_mask;
|
|
extern ck::index_t instance_index;
|
|
|
|
namespace ck {
|
|
namespace test {
|
|
|
|
struct DefaultGroupedGemmProfiler
|
|
{
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
typename EDataType,
|
|
typename AccDataType,
|
|
typename ALayout,
|
|
typename BLayout,
|
|
typename ELayout,
|
|
typename AElementOp,
|
|
typename BElementOp,
|
|
typename CDEElementOp>
|
|
static bool Run(bool verify,
|
|
int init_method,
|
|
bool log,
|
|
bool bench,
|
|
const std::vector<int>& Ms,
|
|
const std::vector<int>& Ns,
|
|
const std::vector<int>& Ks,
|
|
const std::vector<int>& StrideAs,
|
|
const std::vector<int>& StrideBs,
|
|
const std::vector<int>& StrideCs,
|
|
const std::vector<int>& kbatches,
|
|
int n_warmup,
|
|
int n_iter,
|
|
int instance_index,
|
|
bool fail_if_no_supported_instances)
|
|
{
|
|
return ck::profiler::profile_grouped_gemm_impl<ADataType,
|
|
BDataType,
|
|
EDataType,
|
|
AccDataType,
|
|
ALayout,
|
|
BLayout,
|
|
ELayout,
|
|
AElementOp,
|
|
BElementOp,
|
|
CDEElementOp>(
|
|
verify,
|
|
init_method,
|
|
log,
|
|
bench,
|
|
Ms,
|
|
Ns,
|
|
Ks,
|
|
StrideAs,
|
|
StrideBs,
|
|
StrideCs,
|
|
kbatches,
|
|
n_warmup,
|
|
n_iter,
|
|
instance_index,
|
|
fail_if_no_supported_instances);
|
|
}
|
|
};
|
|
|
|
struct FixedNKGroupedGemmProfiler
|
|
{
|
|
template <typename ADataType,
|
|
typename BDataType,
|
|
typename EDataType,
|
|
typename AccDataType,
|
|
typename ALayout,
|
|
typename BLayout,
|
|
typename CLayout>
|
|
static bool Run(bool verify,
|
|
int init_method,
|
|
bool log,
|
|
bool bench,
|
|
const std::vector<int>& Ms,
|
|
const std::vector<int>& Ns,
|
|
const std::vector<int>& Ks,
|
|
const std::vector<int>& StrideAs,
|
|
const std::vector<int>& StrideBs,
|
|
const std::vector<int>& StrideCs,
|
|
const std::vector<int>& kbatches,
|
|
int n_warmup,
|
|
int n_iter,
|
|
int /*instance_index*/,
|
|
bool /*fail_if_no_supported_instances*/)
|
|
{
|
|
bool pass = true;
|
|
for(int kbatch : kbatches)
|
|
{
|
|
try
|
|
{
|
|
pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
|
|
BDataType,
|
|
EDataType,
|
|
AccDataType,
|
|
ALayout,
|
|
BLayout,
|
|
CLayout>(verify,
|
|
init_method,
|
|
log,
|
|
bench,
|
|
Ms,
|
|
Ns,
|
|
Ks,
|
|
StrideAs,
|
|
StrideBs,
|
|
StrideCs,
|
|
kbatch,
|
|
n_warmup,
|
|
n_iter);
|
|
}
|
|
catch(const std::exception& e)
|
|
{
|
|
std::cerr << e.what() << std::endl;
|
|
}
|
|
}
|
|
return pass;
|
|
}
|
|
};
|
|
|
|
template <typename Tuple,
|
|
bool FailIfNoSupportedInstances = false,
|
|
typename Profiler = ck::test::DefaultGroupedGemmProfiler>
|
|
class TestGroupedGemm : public testing::Test
|
|
{
|
|
protected:
|
|
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
|
|
|
using ALayout = tuple_element_t<0, Tuple>;
|
|
using BLayout = tuple_element_t<1, Tuple>;
|
|
using ELayout = tuple_element_t<2, Tuple>;
|
|
using ADataType = tuple_element_t<3, Tuple>;
|
|
using BDataType = tuple_element_t<4, Tuple>;
|
|
using EDataType = tuple_element_t<5, Tuple>;
|
|
using AElementOp = tuple_element_or_t<6, Tuple, PassThrough>;
|
|
using BElementOp = tuple_element_or_t<7, Tuple, PassThrough>;
|
|
using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
|
|
|
|
using Row = ck::tensor_layout::gemm::RowMajor;
|
|
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
|
|
|
public:
|
|
static constexpr bool verify_ = true;
|
|
static constexpr int init_method_ = 1; // integer value initialization
|
|
static constexpr bool log_ = false;
|
|
static constexpr bool bench_ = false; // measure kernel performance
|
|
static constexpr int n_warmup_ = 0;
|
|
static constexpr int n_iter_ = 1;
|
|
|
|
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
|
|
std::vector<int> k_batches_;
|
|
|
|
bool IsSplitKSupported()
|
|
{
|
|
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
|
|
// Technically, we could still use split-K for fp32, but we currently don't have
|
|
// instances for it so we disable it entirely
|
|
constexpr bool require_16bit_atomic_add =
|
|
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
|
|
bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
|
|
|
|
// CDE element operators are not supported in combination with split K
|
|
constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;
|
|
|
|
return !missing_atomic_add && !has_cde_element_operator;
|
|
}
|
|
|
|
void SetUp() override
|
|
{
|
|
if(!IsSplitKSupported())
|
|
{
|
|
k_batches_ = {1};
|
|
}
|
|
else
|
|
{
|
|
k_batches_ = {1, 2, 3, 4, 8};
|
|
}
|
|
}
|
|
|
|
private:
|
|
template <typename Layout>
|
|
void SetStrides(std::vector<int>& strides,
|
|
const std::vector<int>& rows,
|
|
const std::vector<int>& cols) const
|
|
{
|
|
if(std::is_same_v<Layout, Row>)
|
|
{
|
|
for(const auto c : cols)
|
|
{
|
|
strides.emplace_back(c);
|
|
}
|
|
}
|
|
else if(std::is_same_v<Layout, Col>)
|
|
{
|
|
for(const auto r : rows)
|
|
{
|
|
strides.emplace_back(r);
|
|
}
|
|
}
|
|
}
|
|
|
|
public:
|
|
void Run(const std::vector<int>& Ms,
|
|
const std::vector<int>& Ns,
|
|
const std::vector<int>& Ks,
|
|
const std::vector<int>& StrideAs = {},
|
|
const std::vector<int>& StrideBs = {},
|
|
const std::vector<int>& StrideCs = {})
|
|
{
|
|
std::vector<int> stride_as = StrideAs;
|
|
std::vector<int> stride_bs = StrideBs;
|
|
std::vector<int> stride_cs = StrideCs;
|
|
|
|
if(stride_as.empty())
|
|
{
|
|
SetStrides<ALayout>(stride_as, Ms, Ks);
|
|
}
|
|
if(stride_bs.empty())
|
|
{
|
|
SetStrides<BLayout>(stride_bs, Ks, Ns);
|
|
}
|
|
if(stride_cs.empty())
|
|
{
|
|
SetStrides<ELayout>(stride_cs, Ms, Ns);
|
|
}
|
|
std::vector<int> k_batches;
|
|
for(size_t i = 0; i < k_batches_.size(); i++)
|
|
{
|
|
if(param_mask & (1 << i))
|
|
{
|
|
k_batches.push_back(k_batches_[i]);
|
|
}
|
|
}
|
|
|
|
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches);
|
|
}
|
|
|
|
void RunSingle(const std::vector<int>& Ms,
|
|
const std::vector<int>& Ns,
|
|
const std::vector<int>& Ks,
|
|
const std::vector<int>& StrideAs,
|
|
const std::vector<int>& StrideBs,
|
|
const std::vector<int>& StrideCs,
|
|
const std::vector<int>& kbatches)
|
|
{
|
|
bool pass = false;
|
|
using AccDataType = float;
|
|
|
|
if constexpr(std::is_same_v<Profiler, FixedNKGroupedGemmProfiler>)
|
|
{
|
|
pass = Profiler::template Run<ADataType,
|
|
BDataType,
|
|
EDataType,
|
|
AccDataType,
|
|
ALayout,
|
|
BLayout,
|
|
ELayout>(verify_,
|
|
init_method_,
|
|
log_,
|
|
bench_,
|
|
Ms,
|
|
Ns,
|
|
Ks,
|
|
StrideAs,
|
|
StrideBs,
|
|
StrideCs,
|
|
kbatches,
|
|
n_warmup_,
|
|
n_iter_,
|
|
instance_index,
|
|
fail_if_no_supported_instances_);
|
|
}
|
|
else
|
|
{
|
|
pass = Profiler::template Run<ADataType,
|
|
BDataType,
|
|
EDataType,
|
|
AccDataType,
|
|
ALayout,
|
|
BLayout,
|
|
ELayout,
|
|
AElementOp,
|
|
BElementOp,
|
|
CDEElementOp>(verify_,
|
|
init_method_,
|
|
log_,
|
|
bench_,
|
|
Ms,
|
|
Ns,
|
|
Ks,
|
|
StrideAs,
|
|
StrideBs,
|
|
StrideCs,
|
|
kbatches,
|
|
n_warmup_,
|
|
n_iter_,
|
|
instance_index,
|
|
fail_if_no_supported_instances_);
|
|
}
|
|
|
|
EXPECT_TRUE(pass);
|
|
}
|
|
};
|
|
|
|
} // namespace test
|
|
} // namespace ck
|