From 1f1d11e933965c4b0530d00797ce1909c9f10fde Mon Sep 17 00:00:00 2001 From: Wojciech Laskowski <77888887+wj-laskowski@users.noreply.github.com> Date: Wed, 17 Sep 2025 01:23:29 +0200 Subject: [PATCH] Added wmma support for gemm quantization: (#2841) - profiler for gemm quantization for DL/XDL - tests for gemm quantization for DL/XDL - implementation for gemm quantization for WMMA - profiler/tests for gemm quantization for WMMA Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: f97b2a3f5d331009188c4601bd986a7b53a1ce2b] --- example/14_gemm_quantization/CMakeLists.txt | 1 + .../gemm_wmma_quantization_int8.cpp | 211 ++++++++++++++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 5 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 14 ++ .../gpu/quantization/gemm_quantization.hpp | 180 +++++++++++++- .../gpu/quantization/CMakeLists.txt | 6 + ...ation_wmma_c_shuffle_i8_i8_i8_instance.hpp | 79 ++++++ ...a_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp | 41 ++++ ...a_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp | 41 ++++ .../gemm/gemm_quantization_common.hpp | 5 +- .../profile_gemm_quantization_impl.hpp | 231 ++++++++++++++++++ profiler/src/CMakeLists.txt | 9 + profiler/src/profile_gemm_quantization.cpp | 115 +++++++++ test/CMakeLists.txt | 1 + test/quantization/CMakeLists.txt | 2 + test/quantization/gemm/CMakeLists.txt | 9 + .../gemm/test_gemm_quantization.cpp | 40 +++ .../gemm/test_gemm_quantization_ut_cases.inc | 41 ++++ .../gemm/test_gemm_quantization_util.hpp | 62 +++++ 21 files changed, 1167 insertions(+), 8 deletions(-) create mode 100644 example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp create mode 100644 
library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp create mode 100644 profiler/include/profiler/profile_gemm_quantization_impl.hpp create mode 100644 profiler/src/profile_gemm_quantization.cpp create mode 100644 test/quantization/CMakeLists.txt create mode 100644 test/quantization/gemm/CMakeLists.txt create mode 100644 test/quantization/gemm/test_gemm_quantization.cpp create mode 100644 test/quantization/gemm/test_gemm_quantization_ut_cases.inc create mode 100644 test/quantization/gemm/test_gemm_quantization_util.hpp diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 8703fa3ed7..b058e7b0fa 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp new file mode 100644 index 0000000000..a3023997a1 --- /dev/null +++ b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -0,0 +1,211 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Col; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + ActivationOp, + ActivationOp, + CDEElementOp, + GemmDefault, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, 
+ 1, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + S<1>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + I8, + I8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int /* argc */, char* /* argv */[]) +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + // device GEMM + auto gemm = 
DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 55aa7b59ee..72191632d8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -171,8 +172,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot // be odd. constexpr bool AtomicsImplementationExists = - !(std::is_same_v || - std::is_same_v) || + !(std::is_same_v || std::is_same_v || + std::is_same_v) || (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index b226730a09..59d3a6a4c5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1065,6 +1065,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + if constexpr(is_same, int8_t>::value) + { + if(karg.KBatch > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "int8_t does not support KBatch > 1. 
KBatch: " << karg.KBatch + << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp index 19600a90f8..9f148618ae 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -77,6 +77,8 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances( Activation_Mul_Clamp>>>& instances); #endif + +#ifdef CK_USE_XDL // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>>>& instances); +#endif + +#ifdef CK_USE_WMMA +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>>& + instances); +#endif template && is_same_v && @@ -195,7 +258,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -206,7 +271,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -217,12 +284,117 @@ struct DeviceOperationInstanceFactory>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + 
for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif + + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_WMMA + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + op_ptrs); + } + } + } +#endif + return op_ptrs; } }; @@ -230,4 +402,4 @@ struct DeviceOperationInstanceFactory +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| 
Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, 
Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, 
Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, 
int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | 
Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..a3838bb398 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..31ff723166 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..07a632a77c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..ed9cc908ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+
+#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
+
+namespace ck {
+namespace tensor_operation {
+namespace device {
+namespace instance {
+
+// Appends the mk_nk_mn WMMA C-shuffle i8 quantization GEMM instances to `instances`.
+// Two instance lists are added, in this order: BlockGemmPipelineVersion::v3, then v1.
+// Both use the Mul_Clamp element-wise op and the Intrawave scheduler.
+// NOTE(review): the element type of `instances` is not readable in this copy (the
+// template arguments after std::vector are missing) - confirm against the declaration.
+void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
+    std::vector>>& instances)
+{
+    add_device_operation_instances(
+        instances,
+        device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
+            Mul_Clamp,
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v3>{});
+    add_device_operation_instances(
+        instances,
+        device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
+            Mul_Clamp,
+            BlockGemmPipelineScheduler::Intrawave,
+            BlockGemmPipelineVersion::v1>{});
+}
+
+} // namespace instance
+} // namespace device
+} // namespace tensor_operation
+} // namespace ck
diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp
index e7c2500fef..a4eb29c7a1 100644
--- a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp
+++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp
@@ -1,5 +1,5 @@
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
 
@@ -33,7 +33,8 @@ using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<
 using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp;
 using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp;
 
-static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
 
 } // namespace instance
 } // namespace device
diff --git a/profiler/include/profiler/profile_gemm_quantization_impl.hpp b/profiler/include/profiler/profile_gemm_quantization_impl.hpp
new file mode 100644
index 0000000000..a115a41a34
--- /dev/null
+++ b/profiler/include/profiler/profile_gemm_quantization_impl.hpp
@@ -0,0 +1,231 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include // NOTE(review): header name missing in this copy; std::setw is used below
+
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+
+#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp"
+
+#include "ck/library/utility/check_err.hpp"
+#include "ck/library/utility/device_memory.hpp"
+#include "ck/library/utility/host_tensor.hpp"
+#include "ck/library/utility/host_tensor_generator.hpp"
+#include "ck/library/utility/literals.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
+
+namespace ck {
+namespace profiler {
+
+// Runs every DeviceGemmMultipleD instance returned by DeviceOperationInstanceFactory
+// for the requested layouts/data types, prints per-instance and best timings, and
+// optionally checks each result against a host reference.
+//
+// do_verification  non-zero: build the host reference (ReferenceGemm followed by
+//                  cde_element_op per element) and compare every supported
+//                  instance's output with ck::utils::check_err.
+// init_method      0: A and B are not filled; any other value: both are filled via
+//                  GeneratorTensor_2{-5, 5}.
+// do_log           when verifying, also print the device and host E tensors.
+// time_kernel      forwarded to StreamConfig for each invoker run.
+// M, N, K          problem size; A is described as (M, K), B as (K, N), E as (M, N).
+// StrideA/B/E      stride passed to the host descriptors and to MakeArgumentPointer.
+// requant_scale    first constructor argument of the Activation_Mul_Clamp op.
+//
+// Returns false only if some supported instance failed check_err. It therefore also
+// returns true when verification is off or when no instance supports the problem.
+// NOTE(review): the template parameter list is missing in this copy; the body uses
+// ADataType, BDataType, EDataType, ALayout, BLayout and ELayout.
+template
+bool profile_gemm_quantization_impl(int do_verification,
+                                    int init_method,
+                                    bool do_log,
+                                    bool time_kernel,
+                                    int M,
+                                    int N,
+                                    int K,
+                                    int StrideA,
+                                    int StrideB,
+                                    int StrideE,
+                                    float requant_scale = 0.03f)
+{
+    // Builds a 2-D host descriptor with strides {stride, 1} in the first branch and
+    // {1, stride} otherwise.
+    // NOTE(review): the is_same arguments are missing in this copy; presumably the
+    // first branch is the row-major case - confirm.
+    auto f_host_tensor_descriptor =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            using namespace ck::literals;
+
+            if(is_same::value)
+            {
+                return HostTensorDescriptor({row, col}, {stride, 1_uz});
+            }
+            else
+            {
+                return HostTensorDescriptor({row, col}, {1_uz, stride});
+            }
+        };
+
+    Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
+    Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
+    Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+
+    std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
+    std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
+
+    switch(init_method)
+    {
+    case 0: break;
+    default:
+        a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5});
+    }
+
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
+    using MulClamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp;
+
+    using AElementOp = PassThrough;
+    using BElementOp = PassThrough;
+    using ActivationOp = PassThrough;
+    using CDEElementOp = MulClamp;
+
+    const auto a_element_op = AElementOp{};
+    const auto b_element_op = BElementOp{};
+    const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
+
+    // Interface type used as the lookup key for the instance factory below.
+    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
+        ALayout,
+        BLayout,
+        ck::Tuple<>,
+        ELayout,
+        ADataType,
+        BDataType,
+        ck::Tuple<>,
+        EDataType,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::PassThrough,
+        ck::tensor_operation::element_wise::Activation_Mul_Clamp>;
+
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
+
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
+
+    // run reference
+    if(do_verification)
+    {
+        Tensor c_m_n({M, N});
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm;
+
+        auto ref_gemm = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
+
+        ref_invoker.Run(ref_argument);
+
+        // Apply the same CDE element-wise op on the host to get the expected E.
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
+            }
+        }
+    }
+
+    DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
+    DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
+
+    a_device_buf.ToDevice(a_m_k.mData.data());
+    b_device_buf.ToDevice(b_k_n.mData.data());
+
+    // Best-so-far record; stays at these initial values if no instance is supported.
+    std::string best_op_name;
+    float best_ave_time = 0;
+    float best_tflops = 0;
+    float best_gb_per_sec = 0;
+
+    bool pass = true;
+
+    // profile device operation instances
+    for(auto& op_ptr : op_ptrs)
+    {
+        // The two empty std::array arguments are the D-tensor pointers and D strides.
+        // NOTE(review): their template arguments are missing in this copy.
+        auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                        b_device_buf.GetDeviceBuffer(),
+                                                        std::array{},
+                                                        e_device_buf.GetDeviceBuffer(),
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        StrideA,
+                                                        StrideB,
+                                                        std::array{},
+                                                        StrideE,
+                                                        a_element_op,
+                                                        b_element_op,
+                                                        cde_element_op);
+
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
+
+        std::string op_name = op_ptr->GetTypeString();
+
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
+        {
+            // re-init E to zero before profiling a kernel
+            e_device_buf.SetZero();
+
+            float ave_time =
+                invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+            std::size_t flop = std::size_t(2) * M * N * K;
+
+            std::size_t num_btype =
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
+
+            // ave_time is reported in ms below, so flop / 1e9 / ms is TFLOP/s and
+            // bytes / 1e6 / ms is GB/s.
+            float tflops = static_cast(flop) / 1.E9 / ave_time;
+
+            float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
+                      << gb_per_sec << " GB/s, " << op_name << std::endl;
+
+            if(tflops > best_tflops)
+            {
+                best_op_name = op_name;
+                best_tflops = tflops;
+                best_ave_time = ave_time;
+                best_gb_per_sec = gb_per_sec;
+            }
+
+            if(do_verification)
+            {
+                e_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
+                // Accumulate: one failing instance makes the whole call fail.
+                pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
+
+                if(do_log)
+                {
+                    LogRangeAsType(
+                        std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",")
+                        << std::endl;
+
+                    LogRangeAsType(
+                        std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",")
+                        << std::endl;
+                }
+            }
+        }
+        else
+        {
+            std::cout << op_name << " does not support this problem" << std::endl;
+        }
+    }
+
+    // Printed unconditionally, including when no instance ran.
+    std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
+
+    return pass;
+}
+
+} // namespace profiler
+} // namespace ck
diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt
index
7cfdc5bfc9..31f684fe75 100644
--- a/profiler/src/CMakeLists.txt
+++ b/profiler/src/CMakeLists.txt
@@ -112,6 +112,12 @@ if(DL_KERNELS)
     list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
 endif()
 
+# profile_gemm_quantization.cpp needs device_quantization_instance, which is only
+# added to DEVICE_INSTANCES under CK_ENABLE_INT8 (below), so gate the op the same way.
+if(CK_ENABLE_INT8)
+    list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
+endif()
+
 set(PROFILER_SOURCES profiler.cpp)
 foreach(SOURCE ${PROFILER_OPS})
     string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE})
@@ -248,6 +254,10 @@ if(DL_KERNELS)
     list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
 endif()
 
+if(CK_ENABLE_INT8)
+    list(APPEND DEVICE_INSTANCES device_quantization_instance)
+endif()
+
 set(PROFILER_LIBS utility getopt::getopt)
 foreach(LIB ${DEVICE_INSTANCES})
     string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB})
diff --git a/profiler/src/profile_gemm_quantization.cpp b/profiler/src/profile_gemm_quantization.cpp
new file mode 100644
index 0000000000..d28dd60dce
--- /dev/null
+++ b/profiler/src/profile_gemm_quantization.cpp
@@ -0,0 +1,115 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include
+#include
+#include
+#include
+#include
+
+#include "profiler/profile_gemm_quantization_impl.hpp"
+#include "profiler_operation_registry.hpp"
+
+#define OP_NAME "gemm_quantization"
+#define OP_DESC "GEMM Quantization"
+
+using INT8 = int8_t;
+using INT32 = int32_t;
+
+// Command-line entry point for the "gemm_quantization" profiler operation.
+// Expects exactly 14 argv entries (see the usage text printed below); on any other
+// count it prints the usage and calls exit(1).
+// Parses the arguments with std::stoi / std::stof, selects the A/B layouts from arg2
+// (E is RowMajor in every branch) and forwards to
+// ck::profiler::profile_gemm_quantization_impl.
+// Returns 0 if the impl reports success, 1 if it reports failure or the layout value
+// matches none of the four enumerators.
+// NOTE(review): the five #include names above, the static_cast target and the explicit
+// template arguments of the impl call are missing in this copy.
+int profile_gemm_quantization(int argc, char* argv[])
+{
+    enum struct MatrixLayout
+    {
+        MK_KN_MN, // 0: E[m, n] = A[m, k] * B[k, n]
+        MK_NK_MN, // 1: E[m, n] = A[m, k] * B[n, k]
+        KM_KN_MN, // 2: E[m, n] = A[k, m] * B[k, n]
+        KM_NK_MN, // 3: E[m, n] = A[k, m] * B[n, k]
+    };
+
+    if(argc != 14)
+    {
+        // clang-format off
+        printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
+        printf("arg2: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n");
+        printf(" 1: E[m, n] = A[m, k] * B[n, k];\n");
+        printf(" 2: E[m, n] = A[k, m] * B[k, n];\n");
+        printf(" 3: E[m, n] = A[k, m] * B[n, k])\n");
+        printf("arg3: verification (0: no; 1: yes)\n");
+        printf("arg4: initialization (0: no init; default: integer value)\n");
+        printf("arg5: print tensor value (0: no; 1: yes)\n");
+        printf("arg6: time kernel (0=no, 1=yes)\n");
+        printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideE\n");
+        printf("arg13: requant_scale (float, e.g., 0.03)\n");
+        // clang-format on
+        exit(1);
+    }
+
+    const auto layout = static_cast(std::stoi(argv[2]));
+    // Stored as bool here, so any non-zero value reaches the impl's int parameter as 1.
+    const bool do_verification = std::stoi(argv[3]);
+    const int init_method = std::stoi(argv[4]);
+    const bool do_log = std::stoi(argv[5]);
+    const bool time_kernel = std::stoi(argv[6]);
+
+    const int M = std::stoi(argv[7]);
+    const int N = std::stoi(argv[8]);
+    const int K = std::stoi(argv[9]);
+
+    const int StrideA = std::stoi(argv[10]);
+    const int StrideB = std::stoi(argv[11]);
+    const int StrideE = std::stoi(argv[12]);
+
+    const float requant_scale = std::stof(argv[13]);
+
+    using Row = ck::tensor_layout::gemm::RowMajor;
+    using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+    // Layout tags are passed as values only to carry their types into the impl call;
+    // the result is mapped to a process-style exit code (0 = pass, 1 = fail).
+    auto profile = [&](auto a_layout, auto b_layout, auto e_layout) {
+        using ALayout = decltype(a_layout);
+        using BLayout = decltype(b_layout);
+        using ELayout = decltype(e_layout);
+
+        bool pass = ck::profiler::profile_gemm_quantization_impl(do_verification,
+                                                                 init_method,
+                                                                 do_log,
+                                                                 time_kernel,
+                                                                 M,
+                                                                 N,
+                                                                 K,
+                                                                 StrideA,
+                                                                 StrideB,
+                                                                 StrideE,
+                                                                 requant_scale);
+
+        return pass ? 0 : 1;
+    };
+
+    if(layout == MatrixLayout::MK_KN_MN)
+    {
+        return profile(Row{}, Row{}, Row{});
+    }
+    else if(layout == MatrixLayout::MK_NK_MN)
+    {
+        return profile(Row{}, Col{}, Row{});
+    }
+    else if(layout == MatrixLayout::KM_KN_MN)
+    {
+        return profile(Col{}, Row{}, Row{});
+    }
+    else if(layout == MatrixLayout::KM_NK_MN)
+    {
+        return profile(Col{}, Col{}, Row{});
+    }
+    else
+    {
+        std::cout << "this layout is not implemented" << std::endl;
+        return 1;
+    }
+}
+
+REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_quantization);
diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt
index f898f67685..cedac568db 100644
--- a/test/CMakeLists.txt
+++ b/test/CMakeLists.txt
@@ -277,6 +277,7 @@ add_subdirectory(conv_tensor_rearrange)
 add_subdirectory(transpose)
 add_subdirectory(permute_scale)
 add_subdirectory(wrapper)
+add_subdirectory(quantization)
 if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt
new file mode 100644
index 0000000000..89a99f5e5d
--- /dev/null
+++ b/test/quantization/CMakeLists.txt
@@ -0,0 +1,2 @@
+add_custom_target(test_quantization)
+add_subdirectory(gemm)
diff --git a/test/quantization/gemm/CMakeLists.txt b/test/quantization/gemm/CMakeLists.txt
new file mode 100644
index 0000000000..630e6e09c9
--- /dev/null
+++ b/test/quantization/gemm/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_custom_target(test_gemm_quantization_targets)
+
+add_gtest_executable(test_gemm_quantization test_gemm_quantization.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_gemm_quantization PRIVATE utility device_quantization_instance)
+    add_dependencies(test_gemm_quantization_targets test_gemm_quantization)
+endif()
+
+add_dependencies(test_quantization test_gemm_quantization_targets)
diff
--git a/test/quantization/gemm/test_gemm_quantization.cpp b/test/quantization/gemm/test_gemm_quantization.cpp
new file mode 100644
index 0000000000..9981ae8a41
--- /dev/null
+++ b/test/quantization/gemm/test_gemm_quantization.cpp
@@ -0,0 +1,42 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#include "gtest/gtest.h"
+#include "ck/ck.hpp"
+#include "profiler/profile_gemm_quantization_impl.hpp"
+#include "test_gemm_quantization_util.hpp"
+
+using I8 = int8_t;
+using I32 = int32_t;
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+// Typed fixture: GetImpl() returns the profile_gemm_quantization_impl instantiation
+// built from the data types and layouts exposed by TestGemmQuantizationCommon.
+template
+class TestGemmQuantization : public ck::test::TestGemmQuantizationCommon
+{
+    protected:
+    using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float);
+
+    ProfileCall GetImpl() override
+    {
+        return &ck::profiler::profile_gemm_quantization_impl<
+            typename ck::test::TestGemmQuantizationCommon::ADataType,
+            typename ck::test::TestGemmQuantizationCommon::BDataType,
+            typename ck::test::TestGemmQuantizationCommon::AccDataType,
+            typename ck::test::TestGemmQuantizationCommon::EDataType,
+            typename ck::test::TestGemmQuantizationCommon::ALayout,
+            typename ck::test::TestGemmQuantizationCommon::BLayout,
+            typename ck::test::TestGemmQuantizationCommon::ELayout>;
+    }
+};
+
+using KernelTypes = ::testing::Types,
+                                     std::tuple,
+                                     std::tuple,
+                                     std::tuple>;
+
+TYPED_TEST_SUITE(TestGemmQuantization, KernelTypes);
+
+#include "test_gemm_quantization_ut_cases.inc"
diff --git a/test/quantization/gemm/test_gemm_quantization_ut_cases.inc b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc
new file mode 100644
index 0000000000..83a13e4a85
--- /dev/null
+++ b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc
@@ -0,0 +1,45 @@
+#pragma once
+
+// M = 1..6 with N = 512, K = 320.
+TYPED_TEST(TestGemmQuantization, SmallM)
+{
+    std::vector Ms{1, 2, 3, 4, 5, 6};
+    constexpr int N = 512;
+    constexpr int K = 320;
+
+    for(int M : Ms)
+        this->Run({{M, N, K}});
+}
+
+// M = 127, 255, 312, 799, 1573 with N = 1024, K = 320.
+TYPED_TEST(TestGemmQuantization, MidLargeM)
+{
+    std::vector Ms{127, 255, 312, 799, 1573};
+    constexpr int N = 1024;
+    constexpr int K = 320;
+
+    for(int M : Ms)
+        this->Run({{M, N, K}});
+}
+
+// M = 127, 150, 188, 210 with N = 136, K = 280.
+TYPED_TEST(TestGemmQuantization, MNKPadded)
+{
+    const std::vector Ms{127, 150, 188, 210};
+    constexpr int N = 136;
+    constexpr int K = 280;
+
+    for(int M : Ms)
+        this->Run({{M, N, K}});
+}
+
+// Single 512 x 512 x 512 problem.
+TYPED_TEST(TestGemmQuantization, Regular)
+{
+    constexpr int M = 512;
+    constexpr int N = 512;
+    std::vector Ks{512};
+
+    for(int K : Ks)
+        this->Run({{M, N, K}});
+}
diff --git a/test/quantization/gemm/test_gemm_quantization_util.hpp b/test/quantization/gemm/test_gemm_quantization_util.hpp
new file mode 100644
index 0000000000..e1ca0de2db
--- /dev/null
+++ b/test/quantization/gemm/test_gemm_quantization_util.hpp
@@ -0,0 +1,77 @@
+// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+
+#pragma once
+
+#include <tuple>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "ck/ck.hpp"
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/utility/data_type.hpp"
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+using I8 = int8_t;
+using I32 = int32_t;
+
+namespace ck {
+namespace test {
+
+// One problem size per entry; Run() reads each entry as {M, N, K}.
+using TestMatrixSizes = std::vector>;
+
+static const TestMatrixSizes DefaultTestMatrixSizes = {
+    {16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}};
+
+// Base fixture for the GEMM quantization tests. The Tuple parameter supplies, by
+// index: A, B, accumulator and E data types (0-3), then A, B and E layouts (4-6).
+template
+class TestGemmQuantizationCommon : public ::testing::Test
+{
+    protected:
+    using ADataType = std::tuple_element_t<0, Tuple>;
+    using BDataType = std::tuple_element_t<1, Tuple>;
+    using AccDataType = std::tuple_element_t<2, Tuple>;
+    using EDataType = std::tuple_element_t<3, Tuple>;
+    using ALayout = std::tuple_element_t<4, Tuple>;
+    using BLayout = std::tuple_element_t<5, Tuple>;
+    using ELayout = std::tuple_element_t<6, Tuple>;
+
+    // Pointer type of the profiler function that Run() invokes.
+    using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float);
+
+    // Supplied by the derived fixture: the profiler function to call for this Tuple.
+    virtual ProfileCall GetImpl() = 0;
+
+    // Calls GetImpl() once per entry of `lengths` with do_verification = 1,
+    // init_method = 1, do_log = false, time_kernel = true and requant_scale = 0.03f,
+    // then expects that every call returned true.
+    void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes)
+    {
+        bool all_success = true;
+
+        for(auto length : lengths)
+        {
+            int M = length[0];
+            int N = length[1];
+            int K = length[2];
+            int StrideA = ck::is_same_v ? K : M;
+            int StrideB = ck::is_same_v ? N : K;
+            int StrideE = ck::is_same_v ? N : M;
+            float requant_scale = 0.03f;
+
+            // Bitwise & evaluates both operands, so later sizes still run after a failure.
+            all_success =
+                all_success &
+                GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE, requant_scale);
+        }
+
+        EXPECT_TRUE(all_success);
+    }
+};
+
+} // namespace test
+} // namespace ck