From dfb07b8240fbe87e6b374b04e29615d944dc7d6d Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Fri, 11 Jul 2025 15:32:12 -0600 Subject: [PATCH] MX GEMM - Add FP6 GEMM Test (#2488) * Add F6 GEMM MX Test * Add BF6 GEMM MX Test [ROCm/composable_kernel commit: 25b359d63041636087a9f0d5bdf27632ffe8cf0d] --- .../device_operation_instance_factory.hpp | 2 + .../tensor_operation_instance/gpu/gemm_mx.hpp | 40 +++++++++++ .../gpu/gemm_mx/CMakeLists.txt | 4 ++ ...vice_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp | 66 ++++++++++++++++++ ...bf6_bf6_bf16_mk_nk_mn_default_instance.cpp | 32 +++++++++ .../device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp | 67 +++++++++++++++++++ ...dl_f6_f6_f16_mk_nk_mn_default_instance.cpp | 32 +++++++++ .../include/profiler/profile_gemm_mx_impl.hpp | 20 ++++-- test/gemm_mx/CMakeLists.txt | 1 + test/gemm_mx/test_gemm_mx.cpp | 8 ++- test/gemm_mx/test_gemm_mx_util.hpp | 2 +- 11 files changed, 265 insertions(+), 9 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 022afe7fa4..f6983810be 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -24,6 +24,8 @@ using F8 = ck::f8_t; using BF8 = ck::bf8_t; using I4 = ck::pk_i4_t; using F4 = ck::f4x2_pk_t; +using F6 = ck::f6x16_pk_t; +using BF6 = ck::bf6x16_pk_t; using E8M0 = ck::e8m0_bexp_t; using E8M0PK = int32_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp index ec75a0cfb0..2fe4a5c975 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -87,6 +87,34 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances( PassThrough, PassThrough>>>& instances); +void add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances( + std::vector>>& instances); + +void add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances( + std::vector>>& instances); + template && is_same_v && is_same_v) { + // Row-Col-Row -- one of the two currently supported layouts, another one is + // Row-MFMA-Row if constexpr(is_same_v && is_same_v && is_same_v) { @@ -147,6 +177,16 @@ struct DeviceOperationInstanceFactory< { add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs); } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances(op_ptrs); + } } else if constexpr(is_same_v && is_same_v && is_same_v) diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt index bb67a9edae..67805a86b1 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/CMakeLists.txt @@ -2,6 +2,8 @@ set(GEMM_MX_INSTANCES) list(APPEND GEMM_MX_INSTANCES + device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp + device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp @@ -11,6 +13,8 @@ list(APPEND GEMM_MX_INSTANCES ) +set_source_files_properties(device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..4a3d54e90b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = bhalf_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; +using BF6 = ck::bf6x16_pk_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; +static constexpr auto KPerBlock = 256 / ck::packed_size_v; // 256 bf6 = 16 bf6x16_pk_t + +template +using device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, BF6, E8M0PK, BF6, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + std::nullptr_t + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..bc07b32871 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_bf6_bf6_bf16/device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_bf6_bf6_bf16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..08c8f472c9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F16 = half_t; +using F32 = float; +using E8M0 = ck::e8m0_bexp_t; +using E8M0PK = int32_t; +using F6 = ck::f6x16_pk_t; +using BF6 = ck::bf6x16_pk_t; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto ScaleBlockSize = 32; +static constexpr auto KPerBlock = 256 / ck::packed_size_v; // 256 f6 = 16 f6x16_pk_t + +template +using device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_instances = std::tuple< + // clang-format off + //###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, KPerBlock, 1, 1, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, KPerBlock, 1, 1, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, KPerBlock, 1, 1, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, KPerBlock, 1, 1, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F6, E8M0PK, F6, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, KPerBlock, 1, 1, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + std::nullptr_t + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp new file mode 100644 index 0000000000..d92d0b97fe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f6_f6_f16/device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_mx_xdl_f6_f6_f16_mk_nk_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp index 4df2348700..1fbe60c6cf 100644 --- a/profiler/include/profiler/profile_gemm_mx_impl.hpp +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -216,12 +216,20 @@ bool profile_gemm_mx_impl(int do_verification, auto a_data_element = [](float x) { if constexpr(ck::is_same_v) return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); else return ck::type_convert(x); }; auto b_data_element = [](float x) { if constexpr(ck::is_same_v) return ck::type_convert(ck::float2_t(x)); + else if constexpr(ck::packed_size_v == 32) + return ck::type_convert(ck::float32_t(x)); + else if constexpr(ck::packed_size_v == 16) + return ck::type_convert(ck::float16_t(x)); else return ck::type_convert(x); }; @@ -247,15 +255,17 @@ bool profile_gemm_mx_impl(int do_verification, case 1: - a_m_k.GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4] - b_k_n->GenerateTensorDistr(int_distr{-4, 5}); // Z[-4,4] + a_m_k.GenerateTensorDistr( + int_distr{-4, 4}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-4,4] + b_k_n->GenerateTensorDistr(int_distr{-4, 4}); // Z[-4,4] - a_m_k_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2} + a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2} break; default: - a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0}); + a_m_k.GenerateTensorDistr( + float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f}); b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0}); diff --git a/test/gemm_mx/CMakeLists.txt b/test/gemm_mx/CMakeLists.txt index 71a0a98f2d..7a04d5378f 100644 --- a/test/gemm_mx/CMakeLists.txt +++ b/test/gemm_mx/CMakeLists.txt @@ -1,4 +1,5 @@ add_gtest_executable(test_gemm_mx test_gemm_mx.cpp) if(result EQUAL 0) + target_compile_options(test_gemm_mx PRIVATE -mavx512f) target_link_libraries(test_gemm_mx PRIVATE utility device_gemm_mx_instance) endif() diff --git a/test/gemm_mx/test_gemm_mx.cpp b/test/gemm_mx/test_gemm_mx.cpp index a3449cb1bb..b63fd880c1 100644 --- a/test/gemm_mx/test_gemm_mx.cpp +++ b/test/gemm_mx/test_gemm_mx.cpp @@ -10,8 +10,8 @@ using E8M0 = ck::e8m0_bexp_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; -using F6 = ck::f6_t; -using BF6 = ck::bf6_t; +using F6 = ck::f6x16_pk_t; +using BF6 = ck::bf6x16_pk_t; using F4 = ck::f4x2_pk_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; @@ -58,7 +58,9 @@ using KernelTypes_MK_NK = ::testing::Types< std::tuple< F8, F8, F16, ck::Number<32> >, std::tuple< F8, F8, BF16, ck::Number<32> >, #endif - std::tuple< F4, F4, F16, ck::Number<32> > + std::tuple< F4, F4, F16, ck::Number<32> >, + std::tuple< F6, F6, F16, ck::Number<32> >, + std::tuple< BF6, BF6, BF16, ck::Number<32> > >; using KernelTypes_MK_KN = ::testing::Types< diff --git a/test/gemm_mx/test_gemm_mx_util.hpp b/test/gemm_mx/test_gemm_mx_util.hpp index 675a3de127..c2b56bb01f 100644 --- a/test/gemm_mx/test_gemm_mx_util.hpp +++ b/test/gemm_mx/test_gemm_mx_util.hpp @@ -74,7 +74,7 @@ class TestGemmMX : public testing::Test const int StrideB, const int StrideC, int kbatch = 1, - int n_warmup = 1, + int n_warmup = 10, int n_iter = 10) { bool pass = ck::profiler::profile_gemm_mx_impl