// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once template static constexpr inline auto is_row_major(Layout layout_) { return ck_tile::bool_constant, ck_tile::tensor_layout::gemm::RowMajor>>{}; } auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeTypeAB = std::conditional_t; using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); }