// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #include #include #include #include #include "gemm_utils.hpp" #include "run_gemm_example.inc" #include "run_gemm_example_common.hpp" #include "universal_gemm_invoker.hpp" // Universal GEMM-specific wrapper that handles test_async flag template int run_gemm_example_with_layouts_universal(ck_tile::ArgParser& arg_parser, const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, const CLayout c_layout = CLayout{}) { using Invoker = UniversalInvoker; using AccDataType = typename GemmTypeConfig::AccDataType; // Check for async input scheduler test mode bool test_async = arg_parser.get_int("test_async"); if(test_async) { // Extract parameters for async test (same as shared implementation) const ck_tile::index_t M = arg_parser.get_int("m"); const ck_tile::index_t N = arg_parser.get_int("n"); const ck_tile::index_t K = arg_parser.get_int("k"); const ck_tile::index_t kbatch = arg_parser.get_int("split_k"); using Row = ck_tile::tensor_layout::gemm::RowMajor; constexpr bool is_a_row_major = std::is_same_v; constexpr bool is_b_row_major = std::is_same_v; constexpr bool is_c_row_major = std::is_same_v; const ck_tile::index_t stride_A = is_a_row_major ? K : M; const ck_tile::index_t stride_B = is_b_row_major ? N : K; const ck_tile::index_t stride_C = is_c_row_major ? N : M; // Allocate and initialize tensors ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( M, K, stride_A, ck_tile::bool_constant{})); ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( K, N, stride_B, ck_tile::bool_constant{})); ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( M, N, stride_C, ck_tile::bool_constant{})); ck_tile::FillUniformDistributionIntegerValue{-5, 5}(a_m_k); ck_tile::FillUniformDistributionIntegerValue{-5, 5}(b_k_n); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), c_m_n_dev_buf.GetDeviceBuffer(), kbatch, M, N, K, stride_A, stride_B, stride_C}; Invoker::template test_async_input_scheduler, AccDataType, CDataType, ALayout, BLayout, ck_tile::tuple<>, CLayout, ck_tile::element_wise::PassThrough>( args, ck_tile::stream_config{nullptr, false, 1}); // Copy result from device for verification c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); // Compute CPU reference ck_tile::HostTensor c_m_n_ref(ck_tile::host_tensor_descriptor( M, N, stride_C, ck_tile::bool_constant{})); c_m_n_ref.SetZero(); ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_ref); // Verify results const float max_accumulated_value = *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( K, kbatch, max_accumulated_value); bool pass = do_verify(c_m_n_dev_result, c_m_n_ref, rtol_atol, "CPU"); std::cout << "Async input scheduler test: " << (pass ? "PASS" : "FAIL") << std::endl; return pass; } // Normal path - delegate to shared implementation return run_gemm_example_with_layouts( arg_parser, a_layout, b_layout, c_layout); } // Universal GEMM-specific prec_type dispatcher that uses the wrapper template int run_gemm_example_prec_type_universal(std::string a_layout, std::string b_layout, ck_tile::ArgParser& arg_parser) { using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; bool preshuffle = GemmConfig::Preshuffle; if(preshuffle && std::is_same_v) { throw std::runtime_error("Preshuffle is not supported for this int4 datatype!"); } if(preshuffle && a_layout != "R" && b_layout != "C") { throw std::runtime_error( "Preshuffle is supported only for A(Row major), B(column major) input matrices!"); } using LayoutVariant = std::variant; auto string_to_layout = [](const std::string& layout) -> LayoutVariant { if(layout == "R") return Row{}; if(layout == "C") return Col{}; throw std::runtime_error("Unsupported layout: " + layout); }; auto a_layout_variant = string_to_layout(a_layout); auto b_layout_variant = string_to_layout(b_layout); return std::visit( [&](auto a_layout_type, auto b_layout_type) -> int { if constexpr(std::is_same_v && std::is_same_v) { throw std::runtime_error("Unsupported memory layout for the input matrices when " "BPrecType is ck_tile::pk_int4_t!"); } else { return run_gemm_example_with_layouts_universal( arg_parser, a_layout_type, b_layout_type, Row{}); } }, a_layout_variant, b_layout_variant); } template