Files
composable_kernel/test/grouped_gemm/test_grouped_gemm_util.hpp
Erwin Terpstra 46f1d740f0 Add grouped gemm instances for RDNA4 (#3237)
* wip: grouped_gemm implementation based on wmma kernel + example for fp16

* chore: clean up grouped_gem_wmma_splitk_fp16 example

* chore: add cmake options to fully disable XDL or WMMA kernels

* feat: add tests for grouped gemma wmma instances for f16 and bf16 (all layouts)

* chore: add grouped gemm wmma bf16 example

* refactor: reuse more code between instance factory functions

* chore: turn test failure if not all batch sizes are supported into a warning

* chore: made failing of test on unsupported instances conditional to not break old tests

* chore: add log message to failure case where AK1/BK1/KBatch is too high for K value

* fix: issue with new overloads of GridwiseGemm_wmma_cshuffle_v3::Run()

* fix: stray comma after parameter list

* fix: compilation issues on RDNA3 and tests failing due to unsupported problems still being ran

* chore: update copyright in header comments

* nit: minor feebdack

* refactor: unified XDL / wma tests

* fix: properly disable FP8 instances when ONLY targeting gfx11

* refactor: add v3 suffix to grouped_gemm device struct name

* fix: small typos in example code

* fix: fully exclude xdl/wmma instances when using the corresponding cmake flags

* chore: remove unused destructor and added pipeline support checks to remove unnecessary paths

* fix: make sure to not add instance library to group if library was skipped

* fix: make sure xdl grouped gemm doesnt fail the new test

* fix: explicitly exclude test if no xdl/wmma support, as pattern matching fails in this case

* fix: examples not working since dependent types and functions were moved to ck namespace in develop

* fix: tests failing when compiling for just gfx11 due to trying to run unsupported instances

* chore: replace/add copyright headers with new format
2025-12-01 15:32:10 -08:00

171 lines
5.7 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
extern ck::index_t param_mask;
extern ck::index_t instance_index;
namespace ck {
namespace test {
template <typename Range>
std::string serialize_range(const Range& range)
{
std::stringstream ss;
for(auto& r : range)
{
ss << r << ", ";
}
std::string str = ss.str();
return std::string(str.begin(), str.end() - 2);
}
template <typename Tuple, bool FailIfNoSupportedInstances = false>
class TestGroupedGemm : public testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using ELayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
std::vector<int> k_batches_;
void SetUp() override
{
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
if(require_16bit_atomic_add && ck::is_gfx11_supported())
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
k_batches_ = {1};
}
else
{
k_batches_ = {1, 2, 3, 5, 8};
}
}
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
std::vector<int> k_batches;
for(size_t i = 0; i < k_batches_.size(); i++)
{
if(param_mask & (1 << i))
{
k_batches.push_back(k_batches_[i]);
}
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass =
ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
float,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck