Files
composable_kernel/test/grouped_gemm/test_grouped_gemm_util.hpp
Illia Silin c24e528481 [rocm-libraries] ROCm/rocm-libraries#7760 (commit a61bc76)
[CK] suppress compiler warnings while building pytorch. (#7760)

## Motivation

Recently added compiler flags that are required to suppress false
warnings by latest staging compiler are not recognized by older compiler
versions and are triggering an avalanche of warnings. Previous attempt
to suppress them by using -Wno-unknown-warning-option flag didn't help,
because that flag wasn't recognized either and just added more warnings.
I've verified that current approach by checking the clang version
actually works as intended and makes the warnings go away.

## Technical Details

<!-- Explain the changes along with any relevant GitHub links. -->

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-27 06:56:58 -07:00

352 lines
13 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <array>
#include <string>
#include <sstream>
#include <tuple>
#include <type_traits>
#include <vector>
#include <gtest/gtest.h>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"
#if __clang_major__ >= 23
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wlifetime-safety-invalidation"
#endif
static ck::index_t param_mask = 0xffffff;
static ck::index_t instance_index = -1;
namespace ck {
namespace test {
struct DefaultGroupedGemmProfiler
{
template <typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename ELayout,
typename AElementOp,
typename BElementOp,
typename CDEElementOp>
static bool Run(bool verify,
int init_method,
bool log,
bool bench,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches,
int n_warmup,
int n_iter,
int instance_index,
bool fail_if_no_supported_instances)
{
return ck::profiler::profile_grouped_gemm_impl<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(
verify,
init_method,
log,
bench,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup,
n_iter,
instance_index,
fail_if_no_supported_instances);
}
};
struct FixedNKGroupedGemmProfiler
{
template <typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
static bool Run(bool verify,
int init_method,
bool log,
bool bench,
const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches,
int n_warmup,
int n_iter,
int /*instance_index*/,
bool /*fail_if_no_supported_instances*/)
{
bool pass = true;
for(int kbatch : kbatches)
{
try
{
pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
CLayout>(verify,
init_method,
log,
bench,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatch,
n_warmup,
n_iter);
}
catch(const std::exception& e)
{
std::cerr << e.what() << std::endl;
}
}
return pass;
}
};
template <typename Tuple,
bool FailIfNoSupportedInstances = false,
typename Profiler = ck::test::DefaultGroupedGemmProfiler>
class TestGroupedGemm : public testing::Test
{
protected:
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ALayout = tuple_element_t<0, Tuple>;
using BLayout = tuple_element_t<1, Tuple>;
using ELayout = tuple_element_t<2, Tuple>;
using ADataType = tuple_element_t<3, Tuple>;
using BDataType = tuple_element_t<4, Tuple>;
using EDataType = tuple_element_t<5, Tuple>;
using AElementOp = tuple_element_or_t<6, Tuple, PassThrough>;
using BElementOp = tuple_element_or_t<7, Tuple, PassThrough>;
using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
public:
static constexpr bool verify_ = true;
static constexpr int init_method_ = 1; // integer value initialization
static constexpr bool log_ = false;
static constexpr bool bench_ = false; // measure kernel performance
static constexpr int n_warmup_ = 0;
static constexpr int n_iter_ = 1;
bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;
std::vector<int> k_batches_;
bool IsSplitKSupported()
{
// gfx11 does not support split-K due to missing atomic add for fp16/bf16
// Technically, we could still use split-K for fp32, but we currently don't have
// instances for it so we disable it entirely
constexpr bool require_16bit_atomic_add =
std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t>;
bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();
// CDE element operators are not supported in combination with split K
constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;
return !missing_atomic_add && !has_cde_element_operator;
}
void SetUp() override
{
if(!IsSplitKSupported())
{
k_batches_ = {1};
}
else
{
k_batches_ = {1, 2, 3, 4, 8};
}
}
private:
template <typename Layout>
void SetStrides(std::vector<int>& strides,
const std::vector<int>& rows,
const std::vector<int>& cols) const
{
if(std::is_same_v<Layout, Row>)
{
for(const auto c : cols)
{
strides.emplace_back(c);
}
}
else if(std::is_same_v<Layout, Col>)
{
for(const auto r : rows)
{
strides.emplace_back(r);
}
}
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs = {},
const std::vector<int>& StrideBs = {},
const std::vector<int>& StrideCs = {})
{
std::vector<int> stride_as = StrideAs;
std::vector<int> stride_bs = StrideBs;
std::vector<int> stride_cs = StrideCs;
if(stride_as.empty())
{
SetStrides<ALayout>(stride_as, Ms, Ks);
}
if(stride_bs.empty())
{
SetStrides<BLayout>(stride_bs, Ks, Ns);
}
if(stride_cs.empty())
{
SetStrides<ELayout>(stride_cs, Ms, Ns);
}
std::vector<int> k_batches;
for(size_t i = 0; i < k_batches_.size(); i++)
{
if(param_mask & (1 << i))
{
k_batches.push_back(k_batches_[i]);
}
}
RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches);
}
void RunSingle(const std::vector<int>& Ms,
const std::vector<int>& Ns,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs,
const std::vector<int>& kbatches)
{
bool pass = false;
using AccDataType = float;
if constexpr(std::is_same_v<Profiler, FixedNKGroupedGemmProfiler>)
{
pass = Profiler::template Run<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
}
else
{
pass = Profiler::template Run<ADataType,
BDataType,
EDataType,
AccDataType,
ALayout,
BLayout,
ELayout,
AElementOp,
BElementOp,
CDEElementOp>(verify_,
init_method_,
log_,
bench_,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs,
kbatches,
n_warmup_,
n_iter_,
instance_index,
fail_if_no_supported_instances_);
}
EXPECT_TRUE(pass);
}
};
} // namespace test
} // namespace ck
int main(int argc, char** argv)
{
testing::InitGoogleTest(&argc, argv);
if(argc == 1) {}
else if(argc == 3)
{
param_mask = strtol(argv[1], nullptr, 0);
instance_index = atoi(argv[2]);
}
else
{
std::cout << "Usage of " << argv[0] << std::endl;
std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl;
}
return RUN_ALL_TESTS();
}
#if __clang_major__ >= 23
#pragma clang diagnostic pop
#endif