// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <exception>
#include <iostream>
#include <tuple>
#include <type_traits>
#include <vector>

#include <gtest/gtest.h>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

#include "profiler/profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp"

// Defined by the test translation unit: bit i of param_mask enables the i-th
// k-batch candidate (see Run() below); instance_index is forwarded to the
// profiler entry points.
extern ck::index_t param_mask;
extern ck::index_t instance_index;

namespace ck {
namespace test {

// NOTE: the template parameter/argument lists in this file were lost in the
// extracted source and are reconstructed from the call sites; the
// authoritative signatures are those of the profiler entry points in the
// included headers.
struct DefaultGroupedGemmProfiler
{
    template <typename ADataType,
              typename BDataType,
              typename AccDataType,
              typename EDataType,
              typename ALayout,
              typename BLayout,
              typename ELayout,
              typename AElementOp   = ck::tensor_operation::element_wise::PassThrough,
              typename BElementOp   = ck::tensor_operation::element_wise::PassThrough,
              typename CDEElementOp = ck::tensor_operation::element_wise::PassThrough>
    static bool Run(bool verify,
                    int init_method,
                    bool log,
                    bool bench,
                    const std::vector<int>& Ms,
                    const std::vector<int>& Ns,
                    const std::vector<int>& Ks,
                    const std::vector<int>& StrideAs,
                    const std::vector<int>& StrideBs,
                    const std::vector<int>& StrideCs,
                    const std::vector<int>& kbatches,
                    int n_warmup,
                    int n_iter,
                    int instance_index,
                    bool fail_if_no_supported_instances)
    {
        return ck::profiler::profile_grouped_gemm_impl<ADataType,
                                                       BDataType,
                                                       AccDataType,
                                                       EDataType,
                                                       ALayout,
                                                       BLayout,
                                                       ELayout,
                                                       AElementOp,
                                                       BElementOp,
                                                       CDEElementOp>(
            verify,
            init_method,
            log,
            bench,
            Ms,
            Ns,
            Ks,
            StrideAs,
            StrideBs,
            StrideCs,
            kbatches,
            n_warmup,
            n_iter,
            instance_index,
            fail_if_no_supported_instances);
    }
};

struct FixedNKGroupedGemmProfiler
{
    template <typename ADataType,
              typename BDataType,
              typename AccDataType,
              typename EDataType,
              typename ALayout,
              typename BLayout,
              typename ELayout,
              typename AElementOp   = ck::tensor_operation::element_wise::PassThrough,
              typename BElementOp   = ck::tensor_operation::element_wise::PassThrough,
              typename CDEElementOp = ck::tensor_operation::element_wise::PassThrough>
    static bool Run(bool verify,
                    int init_method,
                    bool log,
                    bool bench,
                    const std::vector<int>& Ms,
                    const std::vector<int>& Ns,
                    const std::vector<int>& Ks,
                    const std::vector<int>& StrideAs,
                    const std::vector<int>& StrideBs,
                    const std::vector<int>& StrideCs,
                    const std::vector<int>& kbatches,
                    int n_warmup,
                    int n_iter,
                    int /*instance_index*/,
                    bool /*fail_if_no_supported_instances*/)
    {
        bool pass = true;
        // The fixed-NK profiler takes a single k-batch, so run it once per
        // requested k-batch and fold the results together.
        for(int kbatch : kbatches)
        {
            try
            {
                pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl<ADataType,
                                                                         BDataType,
                                                                         AccDataType,
                                                                         EDataType,
                                                                         ALayout,
                                                                         BLayout,
                                                                         ELayout>(verify,
                                                                                  init_method,
                                                                                  log,
                                                                                  bench,
                                                                                  Ms,
                                                                                  Ns,
                                                                                  Ks,
                                                                                  StrideAs,
                                                                                  StrideBs,
                                                                                  StrideCs,
                                                                                  kbatch,
                                                                                  n_warmup,
                                                                                  n_iter);
            }
            catch(const std::exception& e)
            {
                std::cerr << e.what() << std::endl;
            }
        }
        return pass;
    }
};
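// The fixture below decodes its whole configuration from a single tuple-like
// type list: entries 0-2 are the A/B/E layouts, entries 3-5 the A/B/E data
// types, and the optional entries 6-8 the A/B/CDE element-wise operators
// (defaulting to PassThrough). A minimal sketch of a conforming type list,
// assuming std::tuple as the carrier (the names here are hypothetical):
//
//   using Row = ck::tensor_layout::gemm::RowMajor;
//   using F16 = ck::half_t;
//   // ALayout, BLayout, ELayout, ADataType, BDataType, EDataType
//   using MyConfig = std::tuple<Row, Row, Row, F16, F16, F16>;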
template <typename Tuple,
          typename Profiler               = DefaultGroupedGemmProfiler, // default is an assumption
          bool FailIfNoSupportedInstances = true>                       // default is an assumption
class TestGroupedGemm : public testing::Test
{
    protected:
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ALayout = tuple_element_t<0, Tuple>;
    using BLayout = tuple_element_t<1, Tuple>;
    using ELayout = tuple_element_t<2, Tuple>;

    using ADataType = tuple_element_t<3, Tuple>;
    using BDataType = tuple_element_t<4, Tuple>;
    using EDataType = tuple_element_t<5, Tuple>;

    using AElementOp   = tuple_element_or_t<6, Tuple, PassThrough>;
    using BElementOp   = tuple_element_or_t<7, Tuple, PassThrough>;
    using CDEElementOp = tuple_element_or_t<8, Tuple, PassThrough>;

    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;

    public:
    static constexpr bool verify_     = true;
    static constexpr int init_method_ = 1; // integer value initialization
    static constexpr bool log_        = false;
    static constexpr bool bench_      = false; // measure kernel performance
    static constexpr int n_warmup_    = 0;
    static constexpr int n_iter_      = 1;

    bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances;

    std::vector<int> k_batches_;

    bool IsSplitKSupported()
    {
        // gfx11 does not support split-K due to missing atomic add for fp16/bf16.
        // Technically, we could still use split-K for fp32, but we currently don't
        // have instances for it, so we disable it entirely.
        constexpr bool require_16bit_atomic_add =
            std::is_same_v<EDataType, half_t> || std::is_same_v<EDataType, bhalf_t>;
        bool missing_atomic_add = require_16bit_atomic_add && ck::is_gfx11_supported();

        // CDE element operators are not supported in combination with split-K.
        constexpr bool has_cde_element_operator = !std::is_same_v<CDEElementOp, PassThrough>;

        return !missing_atomic_add && !has_cde_element_operator;
    }

    void SetUp() override
    {
        if(!IsSplitKSupported())
        {
            k_batches_ = {1};
        }
        else
        {
            k_batches_ = {1, 2, 3, 4, 8};
        }
    }

    private:
    template <typename Layout>
    void SetStrides(std::vector<int>& strides,
                    const std::vector<int>& rows,
                    const std::vector<int>& cols) const
    {
        if(std::is_same_v<Layout, Row>)
        {
            // Row-major: the leading dimension is the column count.
            for(const auto c : cols)
            {
                strides.emplace_back(c);
            }
        }
        else if(std::is_same_v<Layout, Col>)
        {
            // Column-major: the leading dimension is the row count.
            for(const auto r : rows)
            {
                strides.emplace_back(r);
            }
        }
    }

    public:
    void Run(const std::vector<int>& Ms,
             const std::vector<int>& Ns,
             const std::vector<int>& Ks,
             const std::vector<int>& StrideAs = {},
             const std::vector<int>& StrideBs = {},
             const std::vector<int>& StrideCs = {})
    {
        std::vector<int> stride_as = StrideAs;
        std::vector<int> stride_bs = StrideBs;
        std::vector<int> stride_cs = StrideCs;

        // Derive packed strides from the problem sizes when none are given.
        if(stride_as.empty())
        {
            SetStrides<ALayout>(stride_as, Ms, Ks);
        }
        if(stride_bs.empty())
        {
            SetStrides<BLayout>(stride_bs, Ks, Ns);
        }
        if(stride_cs.empty())
        {
            SetStrides<ELayout>(stride_cs, Ms, Ns);
        }

        // Bit i of param_mask enables the i-th k-batch candidate.
        std::vector<int> k_batches;
        for(size_t i = 0; i < k_batches_.size(); i++)
        {
            if(param_mask & (1 << i))
            {
                k_batches.push_back(k_batches_[i]);
            }
        }

        RunSingle(Ms, Ns, Ks, stride_as, stride_bs, stride_cs, k_batches);
    }

    void RunSingle(const std::vector<int>& Ms,
                   const std::vector<int>& Ns,
                   const std::vector<int>& Ks,
                   const std::vector<int>& StrideAs,
                   const std::vector<int>& StrideBs,
                   const std::vector<int>& StrideCs,
                   const std::vector<int>& kbatches)
    {
        bool pass = false;

        using AccDataType = float;

        // NOTE: the original if-constexpr condition was lost in the extracted
        // source; dispatching on whether custom element-wise operators are in
        // use is an assumption that keeps the two branches meaningfully
        // different.
        if constexpr(std::is_same_v<CDEElementOp, PassThrough>)
        {
            pass = Profiler::template Run<ADataType,
                                          BDataType,
                                          AccDataType,
                                          EDataType,
                                          ALayout,
                                          BLayout,
                                          ELayout>(verify_,
                                                   init_method_,
                                                   log_,
                                                   bench_,
                                                   Ms,
                                                   Ns,
                                                   Ks,
                                                   StrideAs,
                                                   StrideBs,
                                                   StrideCs,
                                                   kbatches,
                                                   n_warmup_,
                                                   n_iter_,
                                                   instance_index,
                                                   fail_if_no_supported_instances_);
        }
        else
        {
            pass = Profiler::template Run<ADataType,
                                          BDataType,
                                          AccDataType,
                                          EDataType,
                                          ALayout,
                                          BLayout,
                                          ELayout,
                                          AElementOp,
                                          BElementOp,
                                          CDEElementOp>(verify_,
                                                        init_method_,
                                                        log_,
                                                        bench_,
                                                        Ms,
                                                        Ns,
                                                        Ks,
                                                        StrideAs,
                                                        StrideBs,
                                                        StrideCs,
                                                        kbatches,
                                                        n_warmup_,
                                                        n_iter_,
                                                        instance_index,
                                                        fail_if_no_supported_instances_);
        }

        EXPECT_TRUE(pass);
    }
};

} // namespace test
} // namespace ck
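// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header): a test
// translation unit is expected to define the param_mask/instance_index
// globals declared above and instantiate the fixture through gtest typed
// tests. All names below other than TestGroupedGemm itself are hypothetical,
// and TestGroupedGemm<Tuple> relies on the defaulted Profiler and
// FailIfNoSupportedInstances parameters assumed in the reconstruction above.
//
//   ck::index_t param_mask     = ~ck::index_t{0}; // enable every k-batch candidate
//   ck::index_t instance_index = 0;               // illustrative value
//
//   using Row = ck::tensor_layout::gemm::RowMajor;
//   using F16 = ck::half_t;
//
//   using KernelTypes = ::testing::Types<std::tuple<Row, Row, Row, F16, F16, F16>>;
//
//   template <typename Tuple>
//   class TestGroupedGemmF16 : public ck::test::TestGroupedGemm<Tuple>
//   {
//   };
//
//   TYPED_TEST_SUITE(TestGroupedGemmF16, KernelTypes);
//
//   TYPED_TEST(TestGroupedGemmF16, TwoGroups)
//   {
//       this->Run({64, 128}, {256, 256}, {128, 128});
//   }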