Added wmma support for gemm quantization: (#2841)

- profiler for gemm quantization for DL/XDL
- tests for gemm quantization for DL/XDL
- implementation for gemm quantization for WMMA
- profiler/tests for gemm qunatization for WMMA

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>

[ROCm/composable_kernel commit: f97b2a3f5d]
This commit is contained in:
Wojciech Laskowski
2025-09-17 01:23:29 +02:00
committed by GitHub
parent 748bdafb9d
commit 1f1d11e933
21 changed files with 1167 additions and 8 deletions

View File

@@ -1,3 +1,4 @@
add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp)
add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)

View File

@@ -0,0 +1,211 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ActivationOp = PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<ActivationOp>;
using ADataType = I8;
using BDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I32;
using DsDataType = ck::Tuple<>;
using EDataType = I8;
using ALayout = Col;
using BLayout = Row;
using DsLayout = ck::Tuple<>;
using ELayout = Row;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
ActivationOp,
ActivationOp,
CDEElementOp,
GemmDefault,
256,
128,
128,
64,
8,
8,
16,
16,
4,
2,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
S<4, 64, 1>,
S<0, 2, 1>,
S<0, 2, 1>,
1,
1,
8,
true,
1,
1,
S<1, 32, 1, 8>,
S<1>,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v1,
I8,
I8>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, EDataType, float, PassThrough, PassThrough, CDEElementOp>;
int main(int /* argc */, char* /* argv */[])
{
bool do_verification = true;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = K;
ck::index_t StrideB = N;
ck::index_t StrideE = N;
float requant_scale = 0.03;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
auto a_element_op = PassThrough{};
auto b_element_op = PassThrough{};
auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
// device GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 0>{},
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 0>{},
StrideE,
1,
a_element_op,
b_element_op,
cde_element_op);
if(!gemm.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op);
ref_invoker.Run(ref_argument);
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}

View File

@@ -3,6 +3,7 @@
#pragma once
#include <cstdint>
#include <iostream>
#include <sstream>
@@ -171,8 +172,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
// be odd.
constexpr bool AtomicsImplementationExists =
!(std::is_same_v<EDataType, ck::half_t> ||
std::is_same_v<EDataType, ck::bhalf_t>) ||
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
std::is_same_v<EDataType, int8_t>) ||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
if(has_main_k_block_loop)

View File

@@ -1065,6 +1065,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
}
}
if constexpr(is_same<remove_cvref_t<EDataType>, int8_t>::value)
{
if(karg.KBatch > 1)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch
<< " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return true;
}

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -77,6 +77,8 @@ void add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
Activation_Mul_Clamp<PassThrough>>>>&
instances);
#endif
#ifdef CK_USE_XDL
// Layout(A, B, C) = [Col, Row, Row]
void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
@@ -136,6 +138,65 @@ void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
#endif
#ifdef CK_USE_WMMA
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Activation_Mul_Clamp<PassThrough>>>>&
instances);
#endif
template <typename ALayout,
typename BLayout,
@@ -184,7 +245,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
#endif
#ifdef CK_USE_XDL
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances(op_ptrs);
#endif
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
@@ -195,7 +258,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
#endif
#ifdef CK_USE_XDL
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(op_ptrs);
#endif
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
@@ -206,7 +271,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
#endif
#ifdef CK_USE_XDL
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(op_ptrs);
#endif
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
@@ -217,12 +284,117 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
#ifdef DL_KERNELS
add_device_gemm_quantization_dl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
#endif
#ifdef CK_USE_XDL
add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances(op_ptrs);
#endif
}
}
return op_ptrs;
}
#ifdef CK_USE_WMMA
using Wrapper =
DeviceGemmMultipleDSplitKWrapper<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
auto new_op_ptrs =
DeviceOperationInstanceFactory<typename Wrapper::DeviceOp>::GetInstances();
for(auto& op_ptr : new_op_ptrs)
{
op_ptrs.emplace_back(std::make_unique<Wrapper>(std::move(op_ptr)));
}
#endif
return op_ptrs;
}
};
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename Activation>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>>
{
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Empty_Tuple,
ELayout,
ADataType,
BDataType,
Empty_Tuple,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Activation_Mul_Clamp<Activation>>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_WMMA
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<ELayout, Row>)
{
if constexpr(is_same_v<Activation, PassThrough>)
{
add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
op_ptrs);
}
}
}
#endif
return op_ptrs;
}
};
@@ -230,4 +402,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#endif

View File

@@ -20,6 +20,12 @@ list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
add_instance_library(device_quantization_instance
${CONV2D_PERLAYER_QUANT_SRC}
${CONV2D_PERCHANNEL_QUANT_SRC}

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_quantization_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -33,7 +33,8 @@ using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
} // namespace instance
} // namespace device

View File

@@ -0,0 +1,231 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename ELayout>
bool profile_gemm_quantization_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideE,
float requant_scale = 0.03f)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MulClamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using ActivationOp = PassThrough;
using CDEElementOp = MulClamp;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}};
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
ck::Tuple<>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Activation_Mul_Clamp<PassThrough>>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// run reference
if(do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n));
}
}
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
bool pass = true;
// profile device operation instances
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 0>{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init E to zero before profiling a kernel
e_device_buf.SetZero();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
if(do_log)
{
LogRangeAsType<float>(
std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -32,6 +32,7 @@ set(PROFILER_OPS
profile_conv_tensor_rearrange.cpp
profile_transpose.cpp
profile_permute_scale.cpp
profile_gemm_quantization.cpp
)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
@@ -112,6 +113,10 @@ if(DL_KERNELS)
list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp)
endif()
if(CK_ENABLE_INT8)
list(APPEND PROFILER_OPS profile_gemm_quantization.cpp)
endif()
set(PROFILER_SOURCES profiler.cpp)
foreach(SOURCE ${PROFILER_OPS})
string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE})
@@ -248,6 +253,10 @@ if(DL_KERNELS)
list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance)
endif()
if(CK_ENABLE_INT8)
list(APPEND DEVICE_INSTANCES device_quantization_instance)
endif()
set(PROFILER_LIBS utility getopt::getopt)
foreach(LIB ${DEVICE_INSTANCES})
string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB})

View File

@@ -0,0 +1,115 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <cstdio>
#include "profiler/profile_gemm_quantization_impl.hpp"
#include "profiler_operation_registry.hpp"
#define OP_NAME "gemm_quantization"
#define OP_DESC "GEMM Quantization"
using INT8 = int8_t;
using INT32 = int32_t;
int profile_gemm_quantization(int argc, char* argv[])
{
enum struct MatrixLayout
{
MK_KN_MN, // 0:
MK_NK_MN, // 1:
KM_KN_MN, // 2:
KM_NK_MN, // 3:
};
if(argc != 14)
{
// clang-format off
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n");
printf(" 1: E[m, n] = A[m, k] * B[n, k];\n");
printf(" 2: E[m, n] = A[k, m] * B[k, n];\n");
printf(" 3: E[m, n] = A[k, m] * B[n, k])\n");
printf("arg3: verification (0: no; 1: yes)\n");
printf("arg4: initialization (0: no init; default: integer value)\n");
printf("arg5: print tensor value (0: no; 1: yes)\n");
printf("arg6: time kernel (0=no, 1=yes)\n");
printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideE\n");
printf("arg13: requant_scale (float, e.g., 0.03)\n");
// clang-format on
exit(1);
}
const auto layout = static_cast<MatrixLayout>(std::stoi(argv[2]));
const bool do_verification = std::stoi(argv[3]);
const int init_method = std::stoi(argv[4]);
const bool do_log = std::stoi(argv[5]);
const bool time_kernel = std::stoi(argv[6]);
const int M = std::stoi(argv[7]);
const int N = std::stoi(argv[8]);
const int K = std::stoi(argv[9]);
const int StrideA = std::stoi(argv[10]);
const int StrideB = std::stoi(argv[11]);
const int StrideE = std::stoi(argv[12]);
const float requant_scale = std::stof(argv[13]);
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
auto profile = [&](auto a_layout, auto b_layout, auto e_layout) {
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using ELayout = decltype(e_layout);
bool pass = ck::profiler::profile_gemm_quantization_impl<int8_t,
int8_t,
int32_t,
int8_t,
ALayout,
BLayout,
ELayout>(do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
StrideA,
StrideB,
StrideE,
requant_scale);
return pass ? 0 : 1;
};
if(layout == MatrixLayout::MK_KN_MN)
{
return profile(Row{}, Row{}, Row{});
}
else if(layout == MatrixLayout::MK_NK_MN)
{
return profile(Row{}, Col{}, Row{});
}
else if(layout == MatrixLayout::KM_KN_MN)
{
return profile(Col{}, Row{}, Row{});
}
else if(layout == MatrixLayout::KM_NK_MN)
{
return profile(Col{}, Col{}, Row{});
}
else
{
std::cout << "this layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_quantization);

View File

@@ -277,6 +277,7 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory(transpose)
add_subdirectory(permute_scale)
add_subdirectory(wrapper)
add_subdirectory(quantization)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx11")
add_subdirectory(wmma_op)
endif()

View File

@@ -0,0 +1,2 @@
add_custom_target(test_quantization)
add_subdirectory(gemm)

View File

@@ -0,0 +1,9 @@
add_custom_target(test_gemm_quantization_targets)
add_gtest_executable(test_gemm_quantization test_gemm_quantization.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_quantization PRIVATE utility device_quantization_instance)
add_dependencies(test_gemm_quantization_targets test_gemm_quantization)
endif()
add_dependencies(test_quantization test_gemm_quantization_targets)

View File

@@ -0,0 +1,40 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_quantization_impl.hpp"
#include "test_gemm_quantization_util.hpp"
using I8 = int8_t;
using I32 = int32_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <typename Tuple>
class TestGemmQuantization : public ck::test::TestGemmQuantizationCommon<Tuple>
{
protected:
using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float);
ProfileCall GetImpl() override
{
return &ck::profiler::profile_gemm_quantization_impl<
typename ck::test::TestGemmQuantizationCommon<Tuple>::ADataType,
typename ck::test::TestGemmQuantizationCommon<Tuple>::BDataType,
typename ck::test::TestGemmQuantizationCommon<Tuple>::AccDataType,
typename ck::test::TestGemmQuantizationCommon<Tuple>::EDataType,
typename ck::test::TestGemmQuantizationCommon<Tuple>::ALayout,
typename ck::test::TestGemmQuantizationCommon<Tuple>::BLayout,
typename ck::test::TestGemmQuantizationCommon<Tuple>::ELayout>;
}
};
using KernelTypes = ::testing::Types<std::tuple<I8, I8, I32, I8, Row, Row, Row>,
std::tuple<I8, I8, I32, I8, Row, Col, Row>,
std::tuple<I8, I8, I32, I8, Col, Row, Row>,
std::tuple<I8, I8, I32, I8, Col, Col, Row>>;
TYPED_TEST_SUITE(TestGemmQuantization, KernelTypes);
#include "test_gemm_quantization_ut_cases.inc"

View File

@@ -0,0 +1,41 @@
#pragma once
TYPED_TEST(TestGemmQuantization, SmallM)
{
std::vector<int> Ms{1, 2, 3, 4, 5, 6};
constexpr int N = 512;
constexpr int K = 320;
for(int M : Ms)
this->Run({{M, N, K}});
}
TYPED_TEST(TestGemmQuantization, MidLargeM)
{
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
for(int M : Ms)
this->Run({{M, N, K}});
}
TYPED_TEST(TestGemmQuantization, MNKPadded)
{
const std::vector<int> Ms{127, 150, 188, 210};
constexpr int N = 136;
constexpr int K = 280;
for(int M : Ms)
this->Run({{M, N, K}});
}
TYPED_TEST(TestGemmQuantization, Regular)
{
constexpr int M = 512;
constexpr int N = 512;
std::vector<int> Ks{512};
for(int K : Ks)
this->Run({{M, N, K}});
}

View File

@@ -0,0 +1,62 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using I32 = int32_t;
namespace ck {
namespace test {
using TestMatrixSizes = std::vector<std::vector<ck::index_t>>;
static const TestMatrixSizes DefaultTestMatrixSizes = {
{16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}};
template <typename Tuple>
class TestGemmQuantizationCommon : public ::testing::Test
{
protected:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using EDataType = std::tuple_element_t<3, Tuple>;
using ALayout = std::tuple_element_t<4, Tuple>;
using BLayout = std::tuple_element_t<5, Tuple>;
using ELayout = std::tuple_element_t<6, Tuple>;
using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float);
virtual ProfileCall GetImpl() = 0;
void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes)
{
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
float requant_scale = 0.03f;
all_success =
all_success &
GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE, requant_scale);
}
EXPECT_TRUE(all_success);
}
};
} // namespace test
} // namespace ck