Added wmma support for gemm quantization: (#2841)

- profiler for gemm quantization for DL/XDL
- tests for gemm quantization for DL/XDL
- implementation for gemm quantization for WMMA
- profiler/tests for gemm qunatization for WMMA

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
This commit is contained in:
Wojciech Laskowski
2025-09-17 01:23:29 +02:00
committed by GitHub
parent 2723dbd332
commit f97b2a3f5d
21 changed files with 1167 additions and 8 deletions

View File

@@ -20,6 +20,12 @@ list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
list(APPEND GEMM_QUANT_SRC
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp)
add_instance_library(device_quantization_instance
${CONV2D_PERLAYER_QUANT_SRC}
${CONV2D_PERCHANNEL_QUANT_SRC}

View File

@@ -0,0 +1,79 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_quantization_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
template <typename OutElementOp,
BlockGemmPipelineScheduler GemmPipelineScheduler,
BlockGemmPipelineVersion GemmPipeline>
using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB|
//################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | |
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Col,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,41 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Empty_Tuple,
Row,
int8_t,
int8_t,
Empty_Tuple,
int8_t,
PassThrough,
PassThrough,
Mul_Clamp>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v3>{});
add_device_operation_instances(
instances,
device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances<
Mul_Clamp,
BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -33,7 +33,8 @@ using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<
using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<PassThrough>;
using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp<Relu>;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
} // namespace instance
} // namespace device