// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

// Use a macro to minimize code changes: the including example may provide its own
// ComputeDataType; otherwise fall back to the accumulator type.
#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE
using ComputeDataType = AccDataType;
#endif

struct ProblemSize final
{
    std::vector<ck::index_t> Ms;
    std::vector<ck::index_t> Ns;
    std::vector<ck::index_t> Ks;

    std::vector<ck::index_t> stride_As;
    std::vector<ck::index_t> stride_Bs;
    std::vector<ck::index_t> stride_Cs;

    ck::index_t group_count;
#if defined(EXAMPLE_USE_SPLITK)
    ck::index_t k_batch;
#endif
};

struct ExecutionConfig final
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;
    bool async_hargs     = false;
};

bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
{
#if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
    static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
    static_assert(sizeof(ADataType) == sizeof(KernelADataType));
    static_assert(sizeof(BDataType) == sizeof(KernelBDataType));
    static_assert(sizeof(EDataType) == sizeof(KernelEDataType));
#endif

    int group_count = problem_size.group_count;

    // GEMM shape
    std::vector<ck::tensor_operation::device::GemmDesc> gemm_descs;
    std::vector<const void*> p_a, p_b;
    std::vector<void*> p_c;

    gemm_descs.reserve(group_count);

    for(int i = 0; i < group_count; i++)
    {
        int M = problem_size.Ms[i];
        int N = problem_size.Ns[i];
        int K = problem_size.Ks[i];

        int stride_A = problem_size.stride_As[i];
        int stride_B = problem_size.stride_Bs[i];
        int stride_C = problem_size.stride_Cs[i];

        gemm_descs.push_back({M, N, K, stride_A, stride_B, stride_C, {}});
    }

    auto f_host_tensor_descriptor =
        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
            using namespace ck::literals;

            if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
            {
                return HostTensorDescriptor({row, col}, {stride, 1_uz});
            }
            else
            {
                return HostTensorDescriptor({row, col}, {1_uz, stride});
            }
        };

    std::vector<Tensor<ADataType>> a_tensors;
    std::vector<Tensor<BDataType>> b_tensors;
    std::vector<Tensor<EDataType>> c_host_tensors;
#ifdef BUILD_INT4_EXAMPLE
    std::vector<Tensor<KernelEDataType>> c_device_tensors;
#else
    std::vector<Tensor<EDataType>> c_device_tensors;
#endif

    a_tensors.reserve(group_count);
    b_tensors.reserve(group_count);
    c_host_tensors.reserve(group_count);
    c_device_tensors.reserve(group_count);

    using DeviceMemPtr = std::unique_ptr<DeviceMem>;

    std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;

    a_tensors_device.reserve(group_count);
    b_tensors_device.reserve(group_count);
    c_tensors_device.reserve(group_count);

    std::size_t flop = 0, num_btype = 0;

    for(std::size_t i = 0; i < gemm_descs.size(); i++)
    {
        a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
        b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
        c_host_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#ifdef BUILD_INT4_EXAMPLE
        c_device_tensors.push_back(Tensor<KernelEDataType>(f_host_tensor_descriptor(
            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#else
        c_device_tensors.push_back(Tensor<EDataType>(f_host_tensor_descriptor(
            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, ELayout{})));
#endif
        std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                  << " b_k_n: " << b_tensors[i].mDesc
                  << " c_m_n: " << c_device_tensors[i].mDesc << std::endl;

        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
        num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                     sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                     sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize();
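
        // Fill the host input tensors; init_method selects the generator:
        //   0 - leave the data uninitialized,
        //   1 - random integers in [-5, 5],
        //   2 - random floating-point values,
        //   anything else - sequential values.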
        switch(config.init_method)
        {
        case 0: break;
        case 1:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
            break;
        case 2:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
            break;
        default:
            a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<ADataType, 0>{});
            b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<BDataType, 1>{});
        }
    }

    for(std::size_t i = 0; i < gemm_descs.size(); i++)
    {
        a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(ADataType) * a_tensors[i].mDesc.GetElementSpaceSize()));
        b_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(BDataType) * b_tensors[i].mDesc.GetElementSpaceSize()));
        c_tensors_device.emplace_back(std::make_unique<DeviceMem>(
            sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSpaceSize()));

#ifdef BUILD_INT4_EXAMPLE
        const Tensor<KernelADataType> a_converted(a_tensors[i]);
        const Tensor<KernelBDataType> b_converted(b_tensors[i]);

        a_tensors_device[i]->ToDevice(a_converted.mData.data());
        b_tensors_device[i]->ToDevice(b_converted.mData.data());
#else
        a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
        b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
        c_tensors_device[i]->SetZero();
#endif

        p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
        p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
        p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
    }

    auto a_element_op = AElementOp{};
    auto b_element_op = BElementOp{};
    auto c_element_op = CDEElementOp{};

    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    // This example has no auxiliary D tensors, so the per-group D pointer arrays are empty.
    std::vector<std::array<const void*, 0>> p_Ds = {};

    // do GEMM
    auto argument = gemm.MakeArgument(
        p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op);

#if defined(EXAMPLE_USE_SPLITK)
    gemm.SetKBatchSize(&argument, problem_size.k_batch);
#endif

    std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
    std::size_t kargs_size     = gemm.GetDeviceKernelArgSize(&argument);
    std::size_t hargs_size     = gemm.GetHostKernelArgSize(&argument);

    DeviceMem gemm_workspace, gemm_kargs;
    void* gemm_hargs = nullptr;

    // The following is necessary because the TwoStage kernel uses additional device memory
    // for both the workspace and the kernel arguments.
    if(kargs_size > 0)
    {
        gemm_kargs.Realloc(kargs_size);
        gemm.SetDeviceKernelArgs(&argument, gemm_kargs.GetDeviceBuffer());
    }
    if(workspace_size > 0 && workspace_size != kargs_size)
    {
        gemm_workspace.Realloc(workspace_size);
        gemm.SetWorkSpacePointer(&argument, gemm_workspace.GetDeviceBuffer());
    }
    if(config.async_hargs && hargs_size > 0)
    {
        // Pinned host memory, so the kernel arguments can be copied to the device
        // asynchronously.
        hip_check_error(hipHostMalloc(&gemm_hargs, hargs_size));
        gemm.SetHostKernelArgsPointer(&argument, gemm_hargs);
    }

    if(!gemm.IsSupportedArgument(argument))
    {
        throw std::runtime_error(
            "wrong! device_gemm with the specified compilation parameters does "
            "not support this GEMM problem");
    }
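
    // Launch paths: the default path runs the invoker on the null stream and returns after
    // the kernel completes. The async_hargs path passes a dedicated stream and event to
    // Run() so the pinned host kernel arguments can be transferred to the device
    // asynchronously; the synchronization that follows waits for the transfer and launch
    // to finish before any results are read back.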
    if(!config.async_hargs)
    {
        invoker.Run(argument, StreamConfig{nullptr, false});
    }
    else
    {
        hipStream_t stream0 = nullptr;
        hip_check_error(hipStreamCreate(&stream0));
        hipEvent_t event0 = nullptr;
        hip_check_error(hipEventCreate(&event0));

        invoker.Run(argument, StreamConfig{nullptr, false}, stream0, event0);

        hip_check_error(hipEventSynchronize(event0));
        hip_check_error(hipStreamSynchronize(stream0));

        // Release the stream and event once the launch has completed.
        hip_check_error(hipEventDestroy(event0));
        hip_check_error(hipStreamDestroy(stream0));
    }

    bool pass = true;
    if(config.do_verification)
    {
        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                BDataType,
                                                                                EDataType,
                                                                                AccDataType,
                                                                                AElementOp,
                                                                                BElementOp,
                                                                                CDEElementOp,
                                                                                ComputeDataType>;

        for(std::size_t i = 0; i < gemm_descs.size(); i++)
        {
            c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
            auto ref_gemm    = ReferenceGemmInstance{};
            auto ref_invoker = ref_gemm.MakeInvoker();

            auto ref_argument = ref_gemm.MakeArgument(a_tensors[i],
                                                      b_tensors[i],
                                                      c_host_tensors[i],
                                                      a_element_op,
                                                      b_element_op,
                                                      c_element_op);

            ref_invoker.Run(ref_argument);

#ifdef BUILD_INT4_EXAMPLE
            const Tensor<EDataType> c_device_result_converted(c_device_tensors[i]);
            pass &= ck::utils::check_err(c_device_result_converted, c_host_tensors[i]);
#else
            pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]);
#endif
        }

        std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl;
    }

    if(config.time_kernel)
    {
        float ave_time   = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
        float gb_per_sec = num_btype / 1.E6 / ave_time;

        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
                  << " GB/s, " << gemm.GetTypeString() << std::endl;
    }

    if(gemm_hargs != nullptr)
    {
        // Pinned memory from hipHostMalloc must be released explicitly.
        hip_check_error(hipHostFree(gemm_hargs));
    }

    return pass;
}

bool run_grouped_gemm_example(int argc, char* argv[])
{
    ProblemSize problem_size;
    ExecutionConfig config;

    problem_size.group_count = 16;
#if defined(EXAMPLE_USE_SPLITK)
    problem_size.k_batch = 1;
#endif

    if(argc == 1)
    {
        // use default cases
    }
    else if(argc == 4 || argc == 6 || argc == 7)
    {
        config.do_verification = std::stoi(argv[1]);
        config.init_method     = std::stoi(argv[2]);
        config.time_kernel     = std::stoi(argv[3]);

        if(argc >= 6)
        {
            config.async_hargs       = std::stoi(argv[4]);
            problem_size.group_count = std::stoi(argv[5]);
        }
#if defined(EXAMPLE_USE_SPLITK)
        if(argc == 7)
        {
            problem_size.k_batch = std::stoi(argv[6]);
        }
#endif
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4: async hargs (0=no, 1=yes)\n");
        printf("arg5: group count (default=16)\n");
#if defined(EXAMPLE_USE_SPLITK)
        printf("arg6: k-batch count (default=1)\n");
#endif
        exit(1);
    }

    // Pick the leading-dimension stride for a row x col matrix based on its layout:
    // row-major matrices are strided by the column count, column-major by the row count.
    auto get_stride = [](auto layout, auto row_dim, auto col_dim) {
        if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
        {
            return col_dim;
        }
        else
        {
            return row_dim;
        }
    };

    for(int i = 0; i < problem_size.group_count; i++)
    {
        problem_size.Ms.push_back(256 + 256 * i);
        problem_size.Ns.push_back(128 + 128 * i);
        problem_size.Ks.push_back(128 + 64 * i);

        problem_size.stride_As.push_back(
            get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i]));
        problem_size.stride_Bs.push_back(
            get_stride(BLayout{}, problem_size.Ks[i], problem_size.Ns[i]));
        problem_size.stride_Cs.push_back(
            get_stride(ELayout{}, problem_size.Ms[i], problem_size.Ns[i]));
    }

    return run_grouped_gemm(problem_size, config);
}
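
// Minimal usage sketch. This file is an example body: the including translation unit is
// expected to define the data types (ADataType, BDataType, EDataType, AccDataType), the
// layouts (ALayout, BLayout, ELayout), the element-wise ops, and DeviceGemmInstance before
// the #include. The EXAMPLE_STANDALONE_MAIN guard below is hypothetical (not part of the
// build) and only illustrates how an entry point is typically wired to this function.
#ifdef EXAMPLE_STANDALONE_MAIN
int main(int argc, char* argv[])
{
    // Return a non-zero exit code on verification failure so the example doubles as a test.
    return run_grouped_gemm_example(argc, argv) ? 0 : 1;
}
#endif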