// -----------------------------------------------------------------------------
// Overview of this file (documentation only; the code text below is unchanged).
//
// NOTE(review): the text below appears to have lost its line breaks and its
// angle-bracketed template argument lists (e.g. `template bool run_gemm`,
// `static_cast(col)`, `Tensor a_m_k`). It looks like an extraction artifact;
// restore the file from version control before building -- TODO confirm.
//
// run_gemm(problem_size, config)
//   Host-side driver for one GEMM run with a pre-shuffled B operand.
//   Steps, in the order they appear in the code:
//     1. Reads M, N, K, StrideA, StrideB, StrideC and KBatch from problem_size.
//        A stride equal to -1 is replaced by a packed default: `col` when the
//        layout tag matches the is_same_v test, `row` otherwise.
//     2. Creates host tensors a_m_k (M x K), b_k_n (K x N), b_k_n_preshuffled
//        (same descriptor as b_k_n) and two M x N result tensors.
//     3. Fills a_m_k and b_k_n according to config.init_method:
//          0        -> no generator is run;
//          1        -> GeneratorTensor_2 with {-2, 2} for A and {0, 2} for B;
//          2        -> GeneratorTensor_1 for both;
//          default  -> GeneratorTensor_3 with {0.0, 1.0} for A and {-0.5, 0.5}
//                      for B.
//     4. Prints the tensor descriptors and allocates device buffers sized from
//        GetElementSpaceSize().
//     5. Reorders B into b_k_n_preshuffled. NPerWmma comes from
//        device_op.GetPreShuffleParameters(), KLane = warp size / NPerWmma, and
//        KPack is defined outside this view. Each (n, k) is split into
//        (n0, n1) and (k0, k1, k2) and written at offset
//          n0*KPack*NPerWmma*KLane*K0 + k0*KPack*NPerWmma*KLane
//            + k1*KPack*NPerWmma + n1*KPack + k2,
//        reading the source element at flat offset n * K + k.
//        NOTE(review): K0 uses integer division and the source offset uses K
//        rather than StrideB -- presumably K must be a multiple of
//        KLane * KPack, N a multiple of NPerWmma, and B packed with stride K
//        for the indices to stay in range; verify against the callers.
//        NOTE(review): the offsets are computed in `int`; very large N * K
//        could overflow -- confirm the supported problem sizes.
//     6. Uploads A, the pre-shuffled B and the C buffer, builds the device
//        argument, and asks device_op.IsSupportedArgument(). When unsupported,
//        a message is written to std::cerr and the function returns true.
//        NOTE(review): returning true here makes an unsupported problem look
//        like a pass to the caller -- presumably intentional (skip); confirm.
//     7. Launches the kernel once via invoker.Run().
//     8. If config.do_verification: runs the host ReferenceGemm on the
//        un-shuffled b_k_n, launches the device kernel again with timing
//        disabled, downloads C and compares with ck::utils::check_err using
//        get_rtol()/get_atol().
//     9. If config.time_kernel: launches again with a different StreamConfig
//        and prints time, TFlops (2*M*N*K / 1e9 / ave_time) and GB/s
//        (bytes of A + B + C / 1e6 / ave_time).
//   Returns: `pass`, which starts as true and is AND-ed with the check_err
//   result when verification is enabled.
//
// run_gemm_splitk_example(argc, argv)
//   Builds a ProblemSizeSplitK initialised with
//   {3840, 4096, 4096, 4096, 4096, 4096, 1} (presumably M, N, K, StrideA,
//   StrideB, StrideC, KBatch -- confirm against the struct definition) and a
//   default ExecutionConfig, lets parse_cmd_args() override them from the
//   command line, and calls run_gemm() only when parsing returned true.
// -----------------------------------------------------------------------------
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { using namespace ck::literals; auto M = problem_size.M; auto N = problem_size.N; auto K = problem_size.K; auto StrideA = problem_size.StrideA; auto StrideB = problem_size.StrideB; auto StrideC = problem_size.StrideC; auto KBatch = problem_size.KBatch; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { if constexpr(std::is_same_v) { return HostTensorDescriptor({row, col}, {stride, 1_uz}); } else { return HostTensorDescriptor({row, col}, {1_uz, stride}); } }; auto f_get_default_stride = [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { if(stride == -1) { // give a chance if stride is -1, return a default packed stride if constexpr(std::is_same_v) { return static_cast(col); } else { return static_cast(row); } } else return static_cast(stride); }; StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); switch(config.init_method) { case 0: break; case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2{0, 2}); break; case 2: a_m_k.GenerateTensorValue(GeneratorTensor_1{}); b_k_n.GenerateTensorValue(GeneratorTensor_1{}); break; default: a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); } Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); std::cout << "a_m_k: " << 
a_m_k.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n_preshuffled: " << b_k_n_preshuffled.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); // do GEMM auto device_op = DeviceOpInstance{}; // weight pre-shuffle int NPerWmma = device_op.GetPreShuffleParameters(); int KLane = ck::get_warp_size() / NPerWmma; int K0 = K / (KLane * KPack); // K -> K0 KLane KPack // N -> N0 NPerWmma // N, K -> N0 K0 KLane NPerWmma KPack int tempk; for(int n = 0; n < N; ++n) { for(int k = 0; k < K; ++k) { int n0 = n / NPerWmma; int n1 = n % NPerWmma; int k0 = k / (KLane * KPack); tempk = k % (KLane * KPack); int k1 = tempk / KPack; int k2 = tempk % KPack; int outputIndex = n0 * KPack * NPerWmma * KLane * K0 + k0 * KPack * NPerWmma * KLane + k1 * KPack * NPerWmma + n1 * KPack + k2; b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k); } } a_m_k_device_buf.ToDevice(a_m_k.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data()); c_m_n_device_buf.ToDevice(c_m_n_device_result.mData.data()); auto a_element_op = AElementOp{}; auto b_element_op = BElementOp{}; auto c_element_op = CElementOp{}; auto invoker = device_op.MakeInvoker(); auto argument = device_op.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), M, N, K, StrideA, StrideB, StrideC, KBatch, a_element_op, b_element_op, c_element_op); if(!device_op.IsSupportedArgument(argument)) { std::cerr << device_op.GetTypeString() << " does not support this problem" << std::endl; return true; } float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 50, 50, false, 1}); 
// Remainder of run_gemm: optional verification against the host ReferenceGemm
// (fed the un-shuffled b_k_n), optional timing/perf report, then `return pass`.
// The run_gemm_splitk_example() entry point follows on the same line.
// NOTE(review): when config.time_kernel is set, the timed launch above and the
// one below both assign ave_time; only the later value is printed.
bool pass = true; if(config.do_verification) { using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; auto ref_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_argument = ref_gemm.MakeArgument( a_m_k, b_k_n, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); ref_invoker.Run(ref_argument); invoker.Run(argument, StreamConfig{nullptr, false, 0}); c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", get_rtol(), get_atol()); } if(config.time_kernel) { ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); std::size_t flop = 2_uz * M * N * K; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << device_op.GetTypeString() << std::endl; } return pass; } bool run_gemm_splitk_example(int argc, char* argv[]) { ProblemSizeSplitK problem_size{3840, 4096, 4096, 4096, 4096, 4096, 1}; ExecutionConfig config; return parse_cmd_args(argc, argv, problem_size, config) && run_gemm(problem_size, config); }