From 9f87368b4a8956375cd180eeeea4457bae8a2784 Mon Sep 17 00:00:00 2001 From: Aviral Goel Date: Tue, 1 Jul 2025 21:34:52 -0400 Subject: [PATCH] [ckProfiler] Add infrastructure and instances to profile gemm_universal with B preshuffle (#2427) * works on mi300 * fix(profiler): add error message for unsupported type/layout * refactor(preshuffle.inc): add type aliases for code readability [ROCm/composable_kernel commit: 36df1cbd0aa106d2b61fa585935dedfb980e4d40] --- .../gpu/gemm_universal_preshuffle.hpp | 151 ++++++ .../gpu/gemm_universal_preshuffle.inc | 122 +++++ .../gpu/CMakeLists.txt | 4 + .../gemm_universal_preshuffle/CMakeLists.txt | 82 ++++ ...ma16x16_nk_mn_comp_default_instance_p1.cpp | 33 ++ ...ma16x16_nk_mn_comp_default_instance_p2.cpp | 33 ++ ...ma16x16_nk_mn_comp_default_instance_p3.cpp | 33 ++ ...ma16x16_nk_mn_comp_default_instance_p4.cpp | 33 ++ ...ma16x16_nk_mn_comp_default_instance_p5.cpp | 33 ++ ...ma16x16_nk_mn_comp_default_instance_p6.cpp | 32 ++ ...f8_bf16_mk_mfma_mn_p1_default_instance.cpp | 32 ++ ...f8_bf16_mk_mfma_mn_p2_default_instance.cpp | 32 ++ ...f8_bf16_mk_mfma_mn_p3_default_instance.cpp | 32 ++ ...f8_bf16_mk_mfma_mn_p4_default_instance.cpp | 32 ++ ...f8_bf16_mk_mfma_mn_p5_default_instance.cpp | 32 ++ ...mk_mfma_nk_mn_comp_default_instance_p1.cpp | 33 ++ ...mk_mfma_nk_mn_comp_default_instance_p2.cpp | 33 ++ ...iversal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp | 279 +++++++++++ ..._f8_bf16_mk_mfma32x32_mn_comp_instance.cpp | 32 ++ ..._bf16_mk_mfma32x32_mn_default_instance.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p1.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p2.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p3.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p4.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p5.cpp | 31 ++ ...ma16x16_mn_compute_default_instance_p6.cpp | 31 ++ ...al_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp | 332 +++++++++++++ ...mk_mfma_mn_compute_default_instance_p1.cpp | 33 ++ ...mk_mfma_mn_compute_default_instance_p2.cpp | 31 ++ ..._f8_f16_mk_mfma_mn_p1_default_instance.cpp | 30 ++ ..._f16_mk_mfma_mn_p1_default_instance_v2.cpp | 30 ++ ..._f8_f16_mk_mfma_mn_p2_default_instance.cpp | 30 ++ ..._f16_mk_mfma_mn_p2_default_instance_v2.cpp | 30 ++ ..._f8_f16_mk_mfma_mn_p3_default_instance.cpp | 30 ++ ..._f16_mk_mfma_mn_p3_default_instance_v2.cpp | 30 ++ ..._f8_f16_mk_mfma_mn_p4_default_instance.cpp | 30 ++ ..._f16_mk_mfma_mn_p4_default_instance_v2.cpp | 30 ++ ..._f8_f16_mk_mfma_mn_p5_default_instance.cpp | 30 ++ ..._f16_mk_mfma_mn_p5_default_instance_v2.cpp | 30 ++ ...profile_gemm_universal_preshuffle_impl.hpp | 444 ++++++++++++++++++ profiler/src/CMakeLists.txt | 2 + .../src/profile_gemm_universal_preshuffle.cpp | 180 +++++++ 42 files changed, 2632 insertions(+) create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp create mode 100644 profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp create mode 100644 profiler/src/profile_gemm_universal_preshuffle.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp new file mode 100644 index 0000000000..b6acb7bfd8 --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" +#ifdef CK_USE_XDL +#include "gemm_universal_preshuffle.inc" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +struct DeviceOperationInstanceFactory< + // ck::tensor_operation::device::DeviceGemmV2BPreshuffle> +{ + using DeviceOp = DeviceGemmV2BPreshuffle; + + static auto GetInstances() + { +#ifdef CK_USE_XDL + std::vector> op_ptrs; +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p1( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p2( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part5( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part6( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part4( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part3( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part2( + op_ptrs); + add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part1( + op_ptrs); + } + } +#endif +#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5( + op_ptrs); + add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( + op_ptrs); + } + } +#endif +#endif // CK_USE_XDL + + return op_ptrs; + } +}; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc new file mode 100644 index 0000000000..b44d60deaf --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.inc @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +#if(defined(CK_ENABLE_BF16) && defined(CK_ENABLE_FP8)) + +using GemmF8F8BF16InstanceVector = + std::vector>>&; + +using GemmF8F8F16InstanceVector = + std::vector>>&; + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances( + GemmF8F8BF16InstanceVector& instances); +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p1( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p2( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part1( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part2( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part3( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part4( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part5( + GemmF8F8BF16InstanceVector& instances); + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part6( + GemmF8F8BF16InstanceVector& instances); + +#endif +#if(defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8)) +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( + GemmF8F8F16InstanceVector& instances); + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5( + GemmF8F8F16InstanceVector& instances); +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( + GemmF8F8F16InstanceVector& instances); +#endif +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index aea3359aff..d1466206f0 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -283,6 +283,10 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found gemm_multiply_multiply_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() + if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94") AND (NOT INST_TARGETS MATCHES "gfx95") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + message(STATUS "Found gemm_universal_preshuffle_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") + set(add_inst 0) + endif() if ("${cmake_instance}" MATCHES "gemm_bilinear") set(add_inst 0) if((SUPPORTED_GPU_TARGETS MATCHES "gfx9") AND (DTYPES MATCHES "fp16" OR NOT DEFINED DTYPES)) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..5967258789 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/CMakeLists.txt @@ -0,0 +1,82 @@ +# ONLY XDL_KERNELS +set(GEMM_UNIVERSAL_INSTANCES) + +# F8_F8_BF16 +list(APPEND GEMM_UNIVERSAL_INSTANCES +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp +device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp +) + +# F8_F8_F16 +list(APPEND GEMM_UNIVERSAL_INSTANCES + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp + device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp +) + +# F8_F8_F16 +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +# F8_F8_BF16 +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + +add_instance_library(device_gemm_universal_preshuffle_instance ${GEMM_UNIVERSAL_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp new file mode 100644 index 0000000000..d069bfaeb5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part1( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp new file mode 100644 index 0000000000..a03aa265e9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part2( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp new file mode 100644 index 0000000000..135c4ff77a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p3.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part3( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part3< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp new file mode 100644 index 0000000000..e87f5c1e2b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p4.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part4( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part4< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp new file mode 100644 index 0000000000..19ace490ab --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p5.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part5( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part5< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp new file mode 100644 index 0000000000..808b812716 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instance_p6.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +// device_gemm_xdl_universal_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_p6 +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_default_instances_part6( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_part6< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp new file mode 100644 index 0000000000..cd309d528d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_instances( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp new file mode 100644 index 0000000000..f95f77b8cf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_instances( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp new file mode 100644 index 0000000000..a0f42dfa8e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp new file mode 100644 index 0000000000..20d2cf4d5d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp new file mode 100644 index 0000000000..abc304542d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_default_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp new file mode 100644 index 0000000000..77b35506b7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p1( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp new file mode 100644 index 0000000000..d5b3cd95c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_nk_mn_comp_default_instance_p2.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p2( + std::vector>>& instances) +{ + + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp new file mode 100644 index 0000000000..c761e6ad8c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmNKPadding = GemmSpecialization::NKPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances = std::tuple< + // no valid instances available + >; +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances = + std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p1_instances = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Commented out instances are invalid. MRepeat < 4 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + ////DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + ////DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + ////DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + /////DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 128, 16, 16, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p2_instances = std::tuple< + // Commented out instances do not work because MRepeat < 4 + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 512, 16, 16, 16, 16, 4, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 512, 16, 16, 16, 16, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p3_instances = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // N 256 +// //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +// //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 256, 16, 16, 16, 16, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 512, 16, 16, 16, 16, 4, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, +// //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 512, 16, 16, 16, 16, 2, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 256, 16, 16, 16, 16, 4, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +// //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 256, 16, 16, 16, 16, 2, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p4_instances = std::tuple< + // None of the instacnes have MRepeat >= 4 + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 16, 16, 16, 16, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_p5_instances = std::tuple< + // MRepeat < 1, invalid instances + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p1 = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma_mn_compute_instances_p2 = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8 ,BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part1 = + std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part2 = + std::tuple< + // clang-format off + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 14, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 14, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 14, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 14, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part3 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 12, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 12, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 12, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 12, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + //clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part4 = std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 10, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 10, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 10, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma16x16_nk_mn_comp_instances_part5 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_xdl_universal_preshuffle_xdl_f8_f8_bf16_mk_mfma16x16_mn_compute_instances_part6 = + std::tuple< +// clang-format off + //#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | +#if defined(__gfx94__) || defined(CK_USE_GFX94) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8,32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8,32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8,32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8,32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8,32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> +#endif + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp new file mode 100644 index 0000000000..e367cf2c4b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_comp_instance.cpp @@ -0,0 +1,32 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_compute_instances< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp new file mode 100644 index 0000000000..bb34354261 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_bf16/device_gemm_xdl_universal_preshuffle_f8_f8_f8_bf16_mk_mfma32x32_mn_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_nk_mn.hpp" +#include "ck/host_utility/device_prop.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_xdl_universal_preshuffle_f8_f8_bf16_mk_mfma32x32_mn_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp new file mode 100644 index 0000000000..3bdfbf5b0b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p1.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p1( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp new file mode 100644 index 0000000000..e1bb7bfda3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p2.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp new file mode 100644 index 0000000000..68c2bcd3b5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p3.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p3( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p3< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp new file mode 100644 index 0000000000..1581440d9e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p4.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p4( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p4< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp new file mode 100644 index 0000000000..29cfba2e91 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p5.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p5( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p5< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp new file mode 100644 index 0000000000..5da668486a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instance_p6.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_default_instances_p6( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p6< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp new file mode 100644 index 0000000000..df45aa06db --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp @@ -0,0 +1,332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using MultiplyMultiply = element_wise::MultiplyMultiply; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +static constexpr auto v1 = BlockGemmPipelineVersion::v1; +static constexpr auto v2 = BlockGemmPipelineVersion::v2; + +// All commented out instances are invalid because MRepeat < 4. +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma32x32_mn_instances = std::tuple< + // clang-format off + // None of these will work because MRepeat < 4 + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //// p1 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 256 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 512 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 128, 16, 16, 32, 32, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 128, 16, 16, 32, 32, 1, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// p2 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 32, 32, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 512, 16, 16, 32, 32, 2, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// p3 + //// N 256 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 256, 16, 16, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 256, 16, 16, 32, 32, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 512, 16, 16, 32, 32, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 512, 16, 16, 32, 32, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 512 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 256, 16, 16, 32, 32, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 256, 16, 16, 32, 32, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// p4 + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 16, 16, 32, 32, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma32x32_mn_compute_instances = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //p1 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 32, 32, 7, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 32, 32, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 32, 32, 5, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + //// p2 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 32, 32, 7, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 32, 32, 6, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 32, 32, 5, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + //clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances = + std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 256 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 512 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 128, 16, 16, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 256, 16, 16, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 512, 16, 16, 16, 16, 4, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 128, 512, 16, 16, 16, 16, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // N 256 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 256, 16, 16, 16, 16, 4, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 256, 16, 16, 16, 16, 2, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 256, 512, 16, 16, 16, 16, 4, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 256, 512, 16, 16, 16, 16, 2, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //// N 512 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 512, 256, 16, 16, 16, 16, 4, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 512, 256, 16, 16, 16, 16, 2, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances = std::tuple< + // None of these will work because MRepeat < 4 + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 512, 16, 16, 16, 16, 1, 4, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + // + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 32, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 16, 32, 512, 16, 16, 16, 16, 1, 1, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 32, 64, 512, 16, 16, 16, 16, 1, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 64, 512, 16, 16, 16, 16, 2, 2, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 16, 512, 16, 16, 16, 16, 1, 1, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances = std::tuple< + // None of these will work because MRepeat < 4 + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 256, 256, 16, 16, 16, 16, 1, 4, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8>, + //DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 16, 512, 256, 16, 16, 16, 16, 1, 8, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlkGemmPipeVer, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p1 = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 5, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p2 = std::tuple< + // clang-format off + //##########################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //##########################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //##########################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //##########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 7, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 5, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p1 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Compute friendly + // 256x[64, 256, 32]x128 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 256, 128, 16, 16, 16, 16, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 192, 128, 16, 16, 16, 16, 16, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 160, 128, 16, 16, 16, 16, 8, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 128, 128, 16, 16, 16, 16, 16, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 96, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 64, 128, 16, 16, 16, 16, 16, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p2 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // 224x[64, 256, 32]x128 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 14, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 224, 128, 16, 16, 16, 16, 7, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 192, 128, 16, 16, 16, 16, 14, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 160, 128, 16, 16, 16, 16, 7, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 128, 128, 16, 16, 16, 16, 14, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 96, 128, 16, 16, 16, 16, 7, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 64, 128, 16, 16, 16, 16, 14, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p3 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // 192x[64, 256, 32]x128, 192x[64]x256 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 128, 16, 16, 16, 16, 12, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 224, 128, 16, 16, 16, 16, 6, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 192, 128, 16, 16, 16, 16, 12, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 160, 128, 16, 16, 16, 16, 6, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 128, 128, 16, 16, 16, 16, 12, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 96, 128, 16, 16, 16, 16, 6, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 64, 128, 16, 16, 16, 16, 12, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p4 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // 160x[64, 256, 32]x128, 160x[64, 96, 32]x256 + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 256, 128, 16, 16, 16, 16, 10, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 224, 128, 16, 16, 16, 16, 5, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 192, 128, 16, 16, 16, 16, 10, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 160, 128, 16, 16, 16, 16, 5, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 128, 128, 16, 16, 16, 16, 10, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 96, 128, 16, 16, 16, 16, 5, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 160, 64, 128, 16, 16, 16, 16, 10, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p5 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 96, 128, 16, 16, 16, 16, 4, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 128, 16, 16, 16, 16, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 256, 16, 16, 16, 16, 8, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 96, 256, 16, 16, 16, 16, 4, 3, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 64, 256, 16, 16, 16, 16, 8, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma16x16_mn_compute_instances_p6 = + std::tuple< + // clang-format off + //############################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //############################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| + //############################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NWaveNPerXdl| Scheduler| Verision| + //############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 224, 128, 16, 16, 16, 16, 4, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 192, 128, 16, 16, 16, 16, 8, 3, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 160, 128, 16, 16, 16, 16, 4, 5, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemm_Xdl_CShuffleV3_BPreshuffle< Row, Col, Row, F8, F8, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp new file mode 100644 index 0000000000..2a6b98bbe9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p1.cpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p1( + std::vector>>& instances) +{ + printf("add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_" + "instances_p1\n"); + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p1< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp new file mode 100644 index 0000000000..c647aaa4eb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instance_p2.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_default_instances_p2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_compute_instances_p2< + GemmDefault>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp new file mode 100644 index 0000000000..0a2df2887a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances_v2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp new file mode 100644 index 0000000000..27a43ced98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instance_v2.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp new file mode 100644 index 0000000000..a16aed9c22 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp new file mode 100644 index 0000000000..7221768b6d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instance_v2.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances_v2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp new file mode 100644 index 0000000000..3d254a8bf6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp new file mode 100644 index 0000000000..92ac2fa1de --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instance_v2.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances_v2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp new file mode 100644 index 0000000000..76ed9a1ffe --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp new file mode 100644 index 0000000000..096d28f4bb --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instance_v2.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_default_instances_v2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p4_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp new file mode 100644 index 0000000000..c413fa770c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp new file mode 100644 index 0000000000..6df56b5f50 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_universal_preshuffle/device_gemm_xdl_universal_preshuffle_f8_f8_f16/device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instance_v2.cpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_default_instances_v2( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_universal_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p5_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp new file mode 100644 index 0000000000..e218143857 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -0,0 +1,444 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/gemm_universal_preshuffle.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +void preShuffleBuffer(const T* src, T* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K0 = K / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + dst[outputIndex] = src[n * K + k]; + } + } +} + +template +bool profile_gemm_universal_preshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideC, + int KBatch, + int n_warmup, + int n_iter, + uint64_t rotating = 0) +{ + bool pass = true; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // for preshuffle + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::size_t total_gemm_needed = + a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); + int rotating_count = std::max( + 1, + std::min(n_iter, + static_cast(std::ceil(static_cast(rotating) / total_gemm_needed)))); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl; + std::cout << "rotating count: " << rotating_count << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + } + + using AElementOp = ck::tensor_operation::element_wise::PassThrough; + using BElementOp = ck::tensor_operation::element_wise::PassThrough; + using CElementOp = ck::tensor_operation::element_wise::PassThrough; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto c_element_op = CElementOp{}; + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); + DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + + using DeviceOp = ck::tensor_operation::device::DeviceGemmV2BPreshuffle; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // Run reference GEMM + if(do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op); + + ref_invoker.Run(ref_argument); + } + + std::string best_op_name; + std::optional best_op_object_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + float best_kbatch = 0; + + // profile device GEMM instances + for(auto& op_ptr : op_ptrs) + { + const int KPerBlock = op_ptr->GetKPerBlock(); + + if(op_ptr->GetPermuteB()) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + + if constexpr(is_same_v && is_same_v) + { + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + } + } + else + { + b_k_n_permute = b_k_n; + } + int NPerXdl = op_ptr->GetPreShuffleParameters(); + + preShuffleBuffer( + b_k_n_permute.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl); + + b_device_buf.ToDevice(b_preshuffled.mData.data()); + + std::vector kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; + + if(KBatch > 0) + { + kbatch_list = {KBatch}; + } + + for(std::size_t i = 0; i < kbatch_list.size(); i++) + { + auto kbatch_curr = kbatch_list[i]; + + auto argument_ptr = + op_ptr->MakeArgumentPointer(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch_curr, + a_element_op, + b_element_op, + c_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + + // re-init C to zero before profiling next kernel + c_device_buf.SetZero(); + + invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, false, 0, n_warmup, n_iter}); + + if(do_verification) + { + c_device_buf.FromDevice(c_m_n_device_result.mData.data()); + +#if defined CK_ENABLE_FP8 + // set softer tolerances for fp8 + if constexpr(is_same_v || is_same_v || + is_same_v) + { + std::string msg = "Error: Incorrect results!"; + double rtol = 1e-1; + double atol = 1e-1; + pass = pass & ck::utils::check_err( + c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); + } + else + { +#endif + pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result); +#if defined CK_ENABLE_FP8 + } +#endif + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a_m_k.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b_k_n.mData, ",") << std::endl; + LogRangeAsType( + std::cout << "c_host : ", c_m_n_host_result.mData, ",") + << std::endl; + LogRangeAsType( + std::cout << "c_device: ", c_m_n_device_result.mData, ",") + << std::endl; + } + } + + std::string op_name = op_ptr->GetTypeString(); + std::optional op_obj_name = op_ptr->GetObjectName(); + + float ave_time = invoker_ptr->Run(argument_ptr.get(), + StreamConfig{nullptr, + time_kernel, + 0, + n_warmup, + n_iter, + rotating_count > 1, + rotating_count}); + + std::size_t flop = std::size_t(2) * M * N * K; + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + std::size_t num_btype = sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / BPackedSize + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops + << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch " + << kbatch_curr << std::endl; + + if(tflops > best_tflops && ave_time > 1e-10) + { + best_op_name = op_name; + best_op_object_name = op_obj_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + best_kbatch = kbatch_curr; + } + } + else + { + std::cout << op_ptr->GetTypeString() << " does not support this problem" + << std::endl; + } + } + } + + if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f32"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = f16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = bf16"; + } + else if constexpr(is_same::value) + { + std::cout << "Best Perf for datatype = int8"; + } + + if constexpr(is_same::value) + { + std::cout << " ALayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " ALayout = ColumnMajor"; + } + + if constexpr(is_same::value) + { + std::cout << " BLayout = RowMajor"; + } + else if constexpr(is_same::value) + { + std::cout << " BLayout = ColumnMajor"; + } + + std::cout << "M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA + << " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch + << " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec + << " GB/s, " << best_op_name << std::endl; + + if(best_op_object_name) + std::cout << best_op_object_name.value() << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 1e65e9e580..1dc942699f 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -63,6 +63,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp) list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp) list(APPEND PROFILER_OPS profile_gemm_blockscale_wp.cpp) + list(APPEND PROFILER_OPS profile_gemm_universal_preshuffle.cpp) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") list(APPEND PROFILER_OPS profile_gemm_mx.cpp) @@ -171,6 +172,7 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance) list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance) list(APPEND DEVICE_INSTANCES device_gemm_blockscale_wp_instance) + list(APPEND DEVICE_INSTANCES device_gemm_universal_preshuffle_instance) endif() if(SUPPORTED_GPU_TARGETS MATCHES "gfx95") list(APPEND DEVICE_INSTANCES device_gemm_mx_instance) diff --git a/profiler/src/profile_gemm_universal_preshuffle.cpp b/profiler/src/profile_gemm_universal_preshuffle.cpp new file mode 100644 index 0000000000..bc09d7d35d --- /dev/null +++ b/profiler/src/profile_gemm_universal_preshuffle.cpp @@ -0,0 +1,180 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include + +#include "profiler/profile_gemm_universal_preshuffle_impl.hpp" +#include "profiler_operation_registry.hpp" + +enum struct GemmMatrixLayout +{ + // MK_KN_MN, // 0 + MK_NK_MN = 1, // 1 + // KM_KN_MN, // 2 + // KM_NK_MN, // 3 +}; + +enum struct GemmDataType +{ + F8_F8_F16 = 0, // 0 + F8_F8_BF16 = 1, // 1 + // F32_F32_F32, // 0 + // F16_F16_F16, // 1 + // BF16_BF16_BF16, // 2 + // INT8_INT8_INT8, // 3 + // F8_F16_F16, // 4 + // F16_F8_F16, // 5 + // F16_F16_F16_F8, // 6 + // F8_F8_BF16, // 7 + // F16_I4_F16, // 8 + // BF16_I4_BF16, // 9 + +}; + +#define OP_NAME "gemm_universal_preshuffle" +#define OP_DESC "Universal GEMM Preshuffle" + +int profile_gemm_universal_preshuffle(int argc, char* argv[]) +{ + if(argc != 15 && argc != 18) + { + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: data type (0: f8->bf16, 1: f8->f16)\n"); + // printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"); + printf(" 1: A[m, k] * B[n, k] = C[m, n];\n"); + // printf(" 2: A[k, m] * B[k, n] = C[m, n];\n"); + // printf(" 3: A[k, m] * B[n, k] = C[m, n])\n"); + printf("arg4: verification (0: no; 1: yes)\n"); + printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"); + printf("arg6: print tensor value (0: no; 1: yes)\n"); + printf("arg7: time kernel (0=no, 1=yes)\n"); + printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n"); + printf("arg14: split k into mulitiple batch\n"); + printf("optional:\n"); + printf("arg15: number of warm-up cycles (default 1)\n"); + printf("arg16: number of iterations (default 10)\n"); + printf("arg17: memory for rotating buffer (default 0, size in MB)\n"); + exit(1); + } + + int M; + int N; + int StrideA; + int StrideB; + // Analyze the unsupported matrix shapes, switch the M and N number + if(std::stoi(argv[9]) % 8 != 0 && std::stoi(argv[8]) % 8 == 0) + { + M = std::stoi(argv[9]); + StrideA = std::stoi(argv[12]); + N = std::stoi(argv[8]); + StrideB = std::stoi(argv[11]); + } + else + { + M = std::stoi(argv[8]); + StrideA = std::stoi(argv[11]); + N = std::stoi(argv[9]); + StrideB = std::stoi(argv[12]); + } + const auto data_type = static_cast(std::stoi(argv[2])); + const auto layout = static_cast(std::stoi(argv[3])); + const bool do_verification = std::stoi(argv[4]); + const int init_method = std::stoi(argv[5]); + const bool do_log = std::stoi(argv[6]); + const bool time_kernel = std::stoi(argv[7]); + + const int K = std::stoi(argv[10]); + + const int StrideC = std::stoi(argv[13]); + const int KBatch = std::stoi(argv[14]); + + int n_warmup = 1; + int n_iter = 10; + uint64_t rotating = 0; + if(argc == 18) + { + n_warmup = std::stoi(argv[15]); + n_iter = std::stoi(argv[16]); + rotating = std::stoull(argv[17]) * 1024 * 1024; + } + + using F32 = float; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) + using F8 = ck::f8_t; +#endif + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_type, + auto b_type, + auto comp_type, + auto acc_type, + auto c_type, + auto a_layout, + auto b_layout, + auto c_layout) { + using ADataType = decltype(a_type); + using BDataType = decltype(b_type); + using ComputeDataType = decltype(comp_type); + using AccDataType = decltype(acc_type); + using CDataType = decltype(c_type); + + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using CLayout = decltype(c_layout); + + const int DefaultStrideA = ck::is_same_v ? K : M; + const int DefaultStrideB = ck::is_same_v ? N : K; + const int DefaultStrideC = ck::is_same_v ? N : M; + + bool pass = ck::profiler::profile_gemm_universal_preshuffle_impl( + do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + (StrideA < 0) ? DefaultStrideA : StrideA, + (StrideB < 0) ? DefaultStrideB : StrideB, + (StrideC < 0) ? DefaultStrideC : StrideC, + KBatch, + n_warmup, + n_iter, + rotating); + + return pass ? 0 : 1; + }; + + if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); + } +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8) + if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN) + { + return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{}); + } +#endif + else + { + std::cout << "this data_type & layout is not implemented" << std::endl; + + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_universal_preshuffle);