added wmma multiply_multiply instances

This commit is contained in:
Zoltan Lakatos
2025-06-17 15:03:18 +00:00
parent 881bc3fbb0
commit ac60286ed0
13 changed files with 360 additions and 60 deletions

View File

@@ -270,8 +270,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -390,8 +390,8 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
class FloatA,
class FloatB,
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
@@ -793,6 +793,8 @@ struct WmmaGemm
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
// Integer wmma operators need extra input flags to indicate if the input is singed or unsigned.
// At the moment CK supports only singed integer inputs, so these flags are hardcoded.
if constexpr(!TransposeC)
{
wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);

View File

@@ -47,6 +47,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;
using Row_Col_Tuple = ck::Tuple<Row, Col>;
// Conv layout
//

View File

@@ -16,6 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
@@ -280,7 +281,6 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
MultiplyMultiply>>>& instances);
#endif
#endif
#ifdef CK_ENABLE_FP16
void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_comp_default_instances_part1(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
@@ -464,8 +464,7 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_f16_mk_nk_mn_mem_v2_kpadding_in
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
#if (defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_comp_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
@@ -545,6 +544,35 @@ void add_device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_in
MultiplyMultiply>>>& instances);
#endif
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif // CK_USE_WMMA
template <typename ADataType,
typename BDataType,
@@ -553,36 +581,35 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
struct DeviceOperationInstanceFactory<DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
PassThrough,
PassThrough,
MultiplyMultiply>>
{
using DeviceOp =
DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
using DeviceOp = DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
DsDataType,
CDataType,
PassThrough,
PassThrough,
MultiplyMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_USE_XDL
#ifdef CK_ENABLE_FP8
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
@@ -667,7 +694,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
#endif
#endif
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
#if (defined(CK_ENABLE_FP16) || defined(CK_ENABLE_INT8))
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>)
{
@@ -691,6 +718,31 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
}
}
#endif
#endif // CK_USE_XDL
#ifdef CK_USE_WMMA
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
op_ptrs);
}
}
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
op_ptrs);
}
}
#endif // CK_USE_WMMA
return op_ptrs;
}
};

View File

@@ -279,10 +279,10 @@ FOREACH(subdir_path ${dir_list})
message("Found xdl, dl, and wmma instances, but none of those meet the target list. Skipping.")
set(add_inst 0)
endif()
if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
set(add_inst 0)
endif()
# if(("${cmake_instance}" MATCHES "gemm_multiply_multiply" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH))
# message("Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.")
# set(add_inst 0)
# endif()
if ("${cmake_instance}" MATCHES "gemm_bilinear")
set(add_inst 0)
if((SUPPORTED_GPU_TARGETS MATCHES "gfx9") AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES))

View File

@@ -1,4 +1,4 @@
# ONLY XDL_KERNELS
# ONLY XDL_AND_WMMA_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
@@ -38,6 +38,9 @@ list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v1_kpadding_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_default_instance.cpp
device_gemm_multiply_multiply_xdl_i8_i8_f16/device_gemm_multiply_multiply_xdl_i8_i8_f16_mk_nk_mn_mem_v2_kpadding_instance.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_mk_nk_mn.cpp
device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_mk_nk_mn.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F32_F32_Tuple, BF16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_bf16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F32_F32_Tuple,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_bf16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp"
#include "ck/utility/sequence.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
template <index_t... Is>
using S = Sequence<Is...>;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto V3 = BlockGemmPipelineVersion::v3;
static constexpr auto V1 = BlockGemmPipelineVersion::v1;
template <GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances =
std::tuple<
// clang-format off
//##################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm|
//##################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer|
//##################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | |
//##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | |
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 256, 64, 8, 8, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Interwave, V1, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, S<8, 8, 8>, Intrawave, V3, I8, I8>,
DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Row_Col_Tuple, Row, I8, I8, F16_F16_Tuple, F16, I32, I32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, Intrawave, V3, I8, I8>
// clang-format on
>;
void add_device_gemm_multiply_multiply_wmma_c_shuffle_i8_i8_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Row_Col_Tuple,
Row,
I8,
I8,
F16_F16_Tuple,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmDefault>{});
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_wmma_i8_i8_f16_mk_nk_mn_instances<GemmMNKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -69,6 +69,8 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
}
};
std::cout << "cicc: " << StrideD0 << " " << StrideD1 << std::endl;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
@@ -97,8 +99,8 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{1, 3});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{1, 2});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});

View File

@@ -58,7 +58,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
endif()
list(APPEND PROFILER_OPS profile_gemm_multiply_add.cpp)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
endif()
@@ -84,6 +83,9 @@ if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFIN
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
list(APPEND PROFILER_OPS profile_gemm_bilinear.cpp)
endif()
#if((SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") OR (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply.cpp)
#endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
list(APPEND PROFILER_OPS profile_gemm_universal.cpp)
@@ -149,7 +151,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
if(DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)
list(APPEND DEVICE_INSTANCES device_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_gemm_add_add_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
# list(APPEND DEVICE_INSTANCES device_gemm_fastgelu_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_gemm_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_add_relu_gemm_add_instance)
list(APPEND DEVICE_INSTANCES device_grouped_gemm_instance)
@@ -165,7 +167,6 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_batched_gemm_reduce_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_add_instance)
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
endif()
@@ -195,6 +196,9 @@ if((SUPPORTED_GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "fp16" OR NOT DEFIN
(SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)))
list(APPEND DEVICE_INSTANCES device_gemm_bilinear_instance)
endif()
#if((SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]") OR (SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]"))
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_instance)
#endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR SUPPORTED_GPU_TARGETS MATCHES "gfx1[12]")
list(APPEND DEVICE_INSTANCES device_gemm_universal_instance)

View File

@@ -13,6 +13,8 @@ static void print_helper_message()
int main(int argc, char* argv[])
{
printf("cicc2\n");
if(argc == 1)
{
print_helper_message();

View File

@@ -1,24 +1,29 @@
add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
endif()
# add_gtest_executable(test_gemm_add_xdl test_gemm_add_xdl.cpp)
# if(result EQUAL 0)
# target_link_libraries(test_gemm_add_xdl PRIVATE utility device_gemm_add_instance)
# endif()
add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
endif()
# add_gtest_executable(test_gemm_add_relu_xdl test_gemm_add_relu_xdl.cpp)
# if(result EQUAL 0)
# target_link_libraries(test_gemm_add_relu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_relu_instance)
# endif()
add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
endif()
# add_gtest_executable(test_gemm_add_silu_xdl test_gemm_add_silu_xdl.cpp)
# if(result EQUAL 0)
# target_link_libraries(test_gemm_add_silu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_silu_instance)
# endif()
add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
endif()
# add_gtest_executable(test_gemm_add_fastgelu_xdl test_gemm_add_fastgelu_xdl.cpp)
# if(result EQUAL 0)
# target_link_libraries(test_gemm_add_fastgelu_xdl PRIVATE utility device_gemm_add_instance device_gemm_add_fastgelu_instance)
# endif()
add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
# add_gtest_executable(test_gemm_add_fastgelu_wmma test_gemm_add_fastgelu_wmma.cpp)
# if(result EQUAL 0)
# target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
# endif()
add_gtest_executable(test_gemm_multiply_multiply_wmma test_gemm_multiply_multiply_wmma.cpp)
if(result EQUAL 0)
target_link_libraries(test_gemm_add_fastgelu_wmma PRIVATE utility device_gemm_add_fastgelu_instance)
target_link_libraries(test_gemm_multiply_multiply_wmma PRIVATE utility device_gemm_multiply_multiply_instance)
endif()

View File

@@ -12,6 +12,7 @@ using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I32 = int32_t;
template <typename Tuple>
class TestGemmD0Common : public ::testing::Test

View File

@@ -0,0 +1,82 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/ck.hpp"
#include "profiler/profile_gemm_multiply_multiply_impl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using I8 = int8_t;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
using F32 = float;
using I32 = int32_t;
template <typename Tuple>
class TestGemmMultiplyMultiply : public ::testing::Test
{
private:
using ADataType = std::tuple_element_t<0, Tuple>;
using BDataType = std::tuple_element_t<1, Tuple>;
using AccDataType = std::tuple_element_t<2, Tuple>;
using D0DataType = std::tuple_element_t<3, Tuple>;
using D1DataType = std::tuple_element_t<4, Tuple>;
using EDataType = std::tuple_element_t<5, Tuple>;
using ALayout = std::tuple_element_t<6, Tuple>;
using BLayout = std::tuple_element_t<7, Tuple>;
using D0Layout = std::tuple_element_t<8, Tuple>;
using D1Layout = std::tuple_element_t<9, Tuple>;
using ELayout = std::tuple_element_t<10, Tuple>;
constexpr static auto ProfileGemmMultiplyMultiplyImpl =
ck::profiler::profile_gemm_multiply_multiply_impl<ADataType,
BDataType,
ADataType, // ComputeDataType
AccDataType,
D0DataType,
D1DataType,
EDataType,
ALayout,
BLayout,
D0Layout,
D1Layout,
ELayout>;
public:
void Run()
{
std::vector<std::vector<ck::index_t>> lengths = {{1024, 1024, 128}};
// std::vector<std::vector<ck::index_t>> lengths = {
// {16, 32, 64}, /*{2048, 4096, 8192},*/ {2048, 4096, 128}};
bool all_success = true;
for(auto length : lengths)
{
int M = length[0];
int N = length[1];
int K = length[2];
int StrideA = ck::is_same_v<ALayout, Row> ? K : M;
int StrideB = ck::is_same_v<BLayout, Row> ? N : K;
int StrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
int StrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
int StrideE = ck::is_same_v<ELayout, Row> ? N : M;
all_success =
all_success &
ProfileGemmMultiplyMultiplyImpl(1, 1, false, false, M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE, 1, 1, 1, 0);
}
EXPECT_TRUE(all_success);
}
};
using KernelTypes =
::testing::Types<std::tuple<I8, I8, I32, F16, F16, F16, Row, Col, Row, Col, Row>/*,
std::tuple<I8, I8, I32, F32, F32, BF16, Row, Col, Row, Col, Row>*/>;
TYPED_TEST_SUITE(TestGemmMultiplyMultiply, KernelTypes);
TYPED_TEST(TestGemmMultiplyMultiply, Test_BF16FP16) { this->Run(); }