// SPDX-License-Identifier: MIT // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once template static constexpr inline auto is_row_major(Layout layout_) { return ck_tile::bool_constant, ck_tile::tensor_layout::gemm::RowMajor>>{}; } template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } template void permute_tensor_b(Tensor& tensor) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, ck_tile:: sequence, GemmConfig::PermuteA, GemmConfig::PermuteB>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; using GemmPipeline = GEMM_PIPELINE; const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB(); const ck_tile::index_t K0 = K / K1; Tensor tensor_copy = tensor; // int K0, N, K1 for(int j = 0; j < K0; j++) { for(int i = 0; i < N; i++) { for(int jj = 0; jj < K1; jj++) { tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj)); } } } } template void permute_vectors_i4x4_b(Tensor& tensor) { const ck_tile::index_t K = tensor.get_length(0); const ck_tile::index_t N = tensor.get_length(1); // vector pk_i4x4 permute for(int i = 0; i < N; i++) { for(int j = 0; j < K; j += 8) { int8_t input[8]; for(int k = 0; k < 4; k++) { int8_t i4x2 = tensor(j + k * 2, i).data; input[k * 2 + 0] = (i4x2 >> 4) & 0xf; input[k * 2 + 1] = (i4x2 >> 0) & 0xf; } // permute 01234567->20643175 { int8_t hi = input[2]; int8_t lo = input[0]; int8_t i4x2 = (hi << 4) | lo; tensor(j + 0, i) = i4x2; } { int8_t hi = input[6]; int8_t lo = input[4]; int8_t i4x2 = (hi << 4) | lo; tensor(j + 2, i) = i4x2; } { int8_t hi = input[3]; int8_t lo = input[1]; int8_t i4x2 = (hi << 4) | lo; tensor(j + 4, i) = i4x2; } { int8_t hi = input[7]; int8_t lo = input[5]; int8_t i4x2 = (hi << 4) | lo; tensor(j + 6, i) = i4x2; } } } } template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t stride_A, ck_tile::index_t stride_B, ck_tile::index_t stride_C, ck_tile::index_t kbatch, int n_warmup, int n_repeat) { ck_tile::GemmHostArgs args; args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.k_batch = kbatch; args.M = M; args.N = N; args.K = K; args.stride_A = stride_A; args.stride_B = stride_B; args.stride_C = stride_C; float ave_time = gemm_calc( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_byte / 1.E6 / ave_time; std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name << " B_Type=" << DataTypeTraits::name << " C_Type=" << DataTypeTraits::name << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } template int run_gemm_example_with_layouts(int argc, char* argv[], const ALayout a_layout = ALayout{}, const BLayout b_layout = BLayout{}, [[maybe_unused]] const CLayout c_layout = CLayout{}) { auto [result, arg_parser] = create_args(argc, argv); if(!result) return -1; using AccDataType = typename GemmTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); ck_tile::index_t K = arg_parser.get_int("k"); ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); ck_tile::index_t kbatch = arg_parser.get_int("split_k"); int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); ck_tile::index_t init_method = arg_parser.get_int("init"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); ck_tile::HostTensor b_k_n( ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); if(init_method == 0) { ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); } else if(init_method == 1) { ck_tile::FillMonotonicSeq{}(a_m_k); ck_tile::FillMonotonicSeq{}(b_k_n); } else if(init_method == 2) { ck_tile::FillUniformDistribution{1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); } else { a_m_k.SetZero(); b_k_n.SetZero(); } if(GemmConfig::UseStructuredSparsity) { ck_tile::AdjustToStructuredSparsity{}(a_m_k); } ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); static_assert(!GemmConfig::PermuteA, "Not implemented"); if constexpr(std::is_same_v) { // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; if constexpr(GemmConfig::PermuteB) { permute_tensor_b(b_k_n_dev); } permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else { if constexpr(GemmConfig::PermuteB) { std::cout << "Permute for this DataType is not implemented." << std::endl; return false; } b_k_n_dev_buf.ToDevice(b_k_n.data()); } a_m_k_dev_buf.ToDevice(a_m_k.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); invoke_gemm( a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_dev_buf, M, N, K, stride_A, stride_B, stride_C, kbatch, n_warmup, n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; if(arg_parser.get_int("v") == 1) { ck_tile::HostTensor c_m_n_host_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; } else if(arg_parser.get_int("v") == 2) { if constexpr(std::is_same_v) { // Restore input for B for gpu reference b_k_n_dev_buf.ToDevice(b_k_n.data()); } ck_tile::HostTensor c_m_n_gpu_ref( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); ADataType* d_A; BDataType* d_B; CDataType* d_C; ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes())); ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes())); ck_tile::hip_check_error( hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes())); ck_tile::hip_check_error(hipMemcpy(d_A, a_m_k_dev_buf.GetDeviceBuffer(), a_m_k.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::hip_check_error(hipMemcpy(d_B, b_k_n_dev_buf.GetDeviceBuffer(), b_k_n.get_element_space_size_in_bytes(), hipMemcpyHostToDevice)); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(), d_C, c_m_n_dev_result.get_element_space_size_in_bytes(), hipMemcpyDeviceToHost)); ck_tile::hip_check_error(hipFree(d_A)); ck_tile::hip_check_error(hipFree(d_B)); ck_tile::hip_check_error(hipFree(d_C)); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( K, kbatch, max_accumulated_value); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl; } return pass; }