// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once struct MultiplyMultiply { template CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void { const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * ck_tile::type_convert(d1); e = ck_tile::type_convert(x0_f); } }; template static constexpr inline auto is_row_major(Layout layout_) { return ck_tile::bool_constant, ck_tile::tensor_layout::gemm::RowMajor>>{}; } auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeTypeAB = std::conditional_t; using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); }