Added wmma support for gemm quantization: (#2841)

- profiler for gemm quantization for DL/XDL
- tests for gemm quantization for DL/XDL
- implementation for gemm quantization for WMMA
- profiler/tests for gemm qunatization for WMMA

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Wojciech Laskowski
2025-09-17 01:23:29 +02:00
committed by GitHub
parent 2723dbd332
commit f97b2a3f5d
21 changed files with 1167 additions and 8 deletions

View File

@@ -0,0 +1,231 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename ELayout>
bool profile_gemm_quantization_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideE,
float requant_scale = 0.03f)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MulClamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using ActivationOp = PassThrough;
using CDEElementOp = MulClamp;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
ck::Tuple<>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// run reference
if(do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 0>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init E to zero before profiling a kernel
e_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck